Skip to content

Commit f60d1ed

Browse files
Bug fix in removal write kwargs + add input_task_limit in removal (#995)
Signed-off-by: Praateek Mahajan <[email protected]>
1 parent f3b8e00 commit f60d1ed

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

nemo_curator/stages/text/deduplication/removal_workflow.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class TextDuplicatesRemovalWorkflow:
4242
input_files_per_partition: int | None = None
4343
input_blocksize: str | None = None
4444
input_file_extensions: list[str] | None = None
45+
input_task_limit: int | None = None
4546
input_kwargs: dict[str, Any] | None = None
4647

4748
# ids_to_remove args
@@ -83,6 +84,7 @@ def _generate_stages(self, initial_tasks: list[FileGroupTask] | None = None) ->
8384
blocksize=self.input_blocksize,
8485
file_extensions=self.input_file_extensions,
8586
storage_options=(self.input_kwargs or {}).get("storage_options"),
87+
limit=self.input_task_limit,
8688
)
8789
)
8890
else:
@@ -135,7 +137,7 @@ def _generate_stages(self, initial_tasks: list[FileGroupTask] | None = None) ->
135137
write_stage(
136138
path=self.output_path,
137139
**({"file_extension": self.output_file_extension} if self.output_file_extension else {}),
138-
write_kwargs=self.output_kwargs,
140+
write_kwargs=self.output_kwargs or {},
139141
fields=self.output_fields,
140142
**({"mode": self.output_mode} if self.output_mode else {}),
141143
)
@@ -151,6 +153,11 @@ def run(
151153
description="Text duplicates removal workflow",
152154
stages=self._generate_stages(initial_tasks),
153155
)
156+
if self.input_task_limit is not None and len(initial_tasks) > self.input_task_limit:
157+
logger.warning(
158+
f"Initial tasks provided ({len(initial_tasks)}) is greater than input_task_limit ({self.input_task_limit}), truncating to {self.input_task_limit}"
159+
)
160+
initial_tasks = initial_tasks[: self.input_task_limit]
154161

155162
if executor is None:
156163
from nemo_curator.backends.xenna import XennaExecutor

tests/stages/text/deduplication/test_removal_workflow.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def test_initial_tasks_partitioning(self, test_config: "TestTextDuplicateRemoval
230230
output_filetype="parquet",
231231
input_id_field=CURATOR_DEDUP_ID_STR,
232232
ids_to_remove_duplicate_id_field="id",
233+
input_task_limit=10, # truncate to 10 tasks only
233234
input_kwargs={},
234235
ids_to_remove_read_kwargs={},
235236
output_kwargs={},
@@ -239,14 +240,14 @@ def test_initial_tasks_partitioning(self, test_config: "TestTextDuplicateRemoval
239240
output_tasks = workflow.run(executor, initial_tasks=initial_tasks)
240241

241242
# Verify we get 20 output tasks (one per input task)
242-
assert len(output_tasks) == 20, (
243-
f"Expected 20 output tasks, got {len(output_tasks)} for {test_config.executor_cls.__name__}"
243+
assert len(output_tasks) == 10, (
244+
f"Expected 10 output tasks, got {len(output_tasks)} for {test_config.executor_cls.__name__}"
244245
)
245246

246247
# Verify correctness remains the same as other tests
247248
combined_output_df = pd.concat([pd.read_parquet(task.data) for task in output_tasks], ignore_index=True)
248-
assert len(combined_output_df) == 800, (
249-
f"Expected 800 records, got {len(combined_output_df)} for {test_config.executor_cls.__name__}"
249+
assert len(combined_output_df) == 400, (
250+
f"Expected 400 records, got {len(combined_output_df)} for {test_config.executor_cls.__name__}"
250251
)
251252

252253
# Verify no IDs divisible by 5 remain

0 commit comments

Comments
 (0)