Commit f1067d4

Fix minor bugs in fuzzy workflow (#999)
Signed-off-by: Ayush Dattagupta <ayushdg95@gmail.com>
1 parent 8ccd7d4 commit f1067d4

5 files changed, +24 -5 lines changed

nemo_curator/stages/deduplication/fuzzy/identify_duplicates.py

Lines changed: 3 additions & 0 deletions
@@ -95,6 +95,9 @@ def _get_removal_ids(self, df: "cudf.DataFrame") -> "cudf.DataFrame":
         """
         Get the removal ids for the given dataframe.
         """
+        if len(df) == 0:
+            return df[[self.document_id_field]]
+
         removal_ids = df[df[self.duplicate_group_field].duplicated(keep="first")][self.document_id_field]
         removal_ids = removal_ids.sort_values(ignore_index=True)
         return removal_ids.to_frame()
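The guard added above makes _get_removal_ids a no-op on empty partitions, returning an empty frame with just the id column. A minimal sketch of the behavior, using pandas as a stand-in for cudf (the relevant API is shared) and a free function in place of the stage method:

import pandas as pd

def get_removal_ids(df: pd.DataFrame, id_field: str, group_field: str) -> pd.DataFrame:
    # Empty partition: short-circuit with an empty single-column frame,
    # mirroring the early return added in this commit.
    if len(df) == 0:
        return df[[id_field]]
    # Keep every duplicate after the first occurrence within each group.
    removal_ids = df[df[group_field].duplicated(keep="first")][id_field]
    return removal_ids.sort_values(ignore_index=True).to_frame()

empty = pd.DataFrame({"id": [], "group": []})
print(len(get_removal_ids(empty, "id", "group")))  # 0, no exception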

nemo_curator/stages/deduplication/fuzzy/lsh/lsh.py

Lines changed: 2 additions & 0 deletions
@@ -245,6 +245,8 @@ def group_by_bucket(self, df: cudf.DataFrame, include_singles: bool = False) ->
         -------
         DataFrame with bucket IDs and lists of document IDs.
         """
+        if len(df) == 0:
+            return df
         if not include_singles:
             # TODO: Add support for generating LSH index with single-document buckets that can be reused in incremental runs
             # Find bucket_ids that appear more than once (have multiple documents)
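The same pattern protects group_by_bucket: zero input rows means zero buckets, so the frame can be returned unchanged. A hypothetical pandas sketch of the non-singleton grouping described in the diff's comments, with the new guard in place (column names are illustrative):

import pandas as pd

def group_by_bucket(df: pd.DataFrame) -> pd.DataFrame:
    if len(df) == 0:
        return df  # nothing to group
    # Keep only bucket_ids that appear more than once (have multiple documents),
    # then collect each surviving bucket's document ids into a list.
    counts = df["bucket_id"].value_counts()
    multi = df[df["bucket_id"].isin(counts[counts > 1].index)]
    return multi.groupby("bucket_id")["doc_id"].agg(list).reset_index()

print(group_by_bucket(pd.DataFrame({"bucket_id": [], "doc_id": []})))  # empty frame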

nemo_curator/stages/deduplication/fuzzy/lsh/stage.py

Lines changed: 5 additions & 0 deletions
@@ -62,6 +62,9 @@ class LSHStage(ProcessingStage[FileGroupTask, FileGroupTask]):
     bands_per_iteration
         Number of bands to process per shuffle iteration. Between 1 and num_bands.
         Higher values reduce the number of shuffle iterations but increase the memory usage.
+    total_nparts
+        Total number of partitions to write during the shuffle.
+        If None, the number of partitions will be decided automatically by the executor as the closest power of 2 <= number of input tasks.
     """

     _name = "LSHStage"

@@ -84,6 +87,7 @@ class LSHStage(ProcessingStage[FileGroupTask, FileGroupTask]):
     spill_memory_limit: int | Literal["auto"] | None = "auto"
     enable_statistics: bool = False
     bands_per_iteration: int = 5  # number of bands to process in each iteration
+    total_nparts: int | None = None

     def __post_init__(self):
         super().__init__()

@@ -102,6 +106,7 @@ def __post_init__(self):
             "enable_statistics": self.enable_statistics,
             "read_kwargs": self.read_kwargs,
             "write_kwargs": self.write_kwargs,
+            "total_nparts": self.total_nparts,  # Can be None, executor will set it
         }

         if self.bands_per_iteration < 1 or self.bands_per_iteration > self.num_bands:
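The new docstring pins down the fallback when total_nparts is None: the closest power of two that does not exceed the number of input tasks. A small sketch of that rule (the function name is ours, not the executor's):

import math

def default_total_nparts(num_input_tasks: int) -> int:
    # Closest power of 2 <= number of input tasks, per the docstring above.
    return 2 ** int(math.log2(num_input_tasks))

assert default_total_nparts(10) == 8   # 8 <= 10 < 16
assert default_total_nparts(16) == 16  # already a power of two
assert default_total_nparts(1) == 1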

nemo_curator/stages/deduplication/fuzzy/workflow.py

Lines changed: 5 additions & 4 deletions
@@ -221,7 +221,8 @@ def _create_lsh_pipeline(self) -> Pipeline:
                 num_bands=self.num_bands,
                 minhashes_per_band=self.minhashes_per_band,
                 output_path=self.cache_path,
-                read_kwargs=self.read_kwargs,
+                # Reading minhashes from cache_path
+                read_kwargs=self.cache_kwargs,
                 write_kwargs=self.cache_kwargs,
                 bands_per_iteration=self.bands_per_iteration,
                 rmm_pool_size="auto",

@@ -236,17 +237,17 @@ def _create_connected_components_pipeline(self) -> Pipeline:
             stages=[
                 BucketsToEdgesStage(
                     output_path=self.cache_path,
-                    read_kwargs=self.read_kwargs,
+                    read_kwargs=self.cache_kwargs,
                     write_kwargs=self.cache_kwargs,
                 ),
                 ConnectedComponentsStage(
                     output_path=self.cache_path,
-                    read_kwargs=self.read_kwargs,
+                    read_kwargs=self.cache_kwargs,
                     write_kwargs=self.cache_kwargs,
                 ),
                 IdentifyDuplicatesStage(
                     output_path=self.output_path,
-                    read_kwargs=self.read_kwargs,
+                    read_kwargs=self.cache_kwargs,
                     write_kwargs=self.write_kwargs,
                     rmm_pool_size="auto",
                     spill_memory_limit="auto",
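All four read_kwargs -> cache_kwargs swaps follow one rule: a stage that consumes another stage's output from cache_path must read with the cache's storage options, while read_kwargs describes only the user's original input. A hypothetical one-liner capturing that routing (names are illustrative, not the workflow's API):

def reader_kwargs(reads_from_cache: bool, read_kwargs: dict, cache_kwargs: dict) -> dict:
    # LSH, BucketsToEdges, ConnectedComponents and IdentifyDuplicates all read
    # intermediate results from cache_path, so they take cache_kwargs; only the
    # stage that reads the raw input should take read_kwargs.
    return cache_kwargs if reads_from_cache else read_kwargs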

tests/stages/deduplication/fuzzy/test_lsh_stage.py

Lines changed: 9 additions & 1 deletion
@@ -62,12 +62,19 @@ def minhash_data(self, tmp_path: Path) -> FileGroupTask:
         },
     )

-    @pytest.mark.parametrize("bands_per_iteration", [2, 3])
+    @pytest.mark.parametrize(
+        ("bands_per_iteration", "total_nparts"),
+        [
+            (2, 4),
+            (3, None),
+        ],
+    )
     def test_lsh(
         self,
         minhash_data: FileGroupTask,
         tmp_path: Path,
         bands_per_iteration: int,
+        total_nparts: int | None,
     ) -> None:
         # Create LSHStage
         lsh_stage = LSHStage(

@@ -77,6 +84,7 @@ def test_lsh(
             bands_per_iteration=bands_per_iteration,
             minhash_field="_minhash_signature",
             id_field=CURATOR_DEDUP_ID_STR,
+            total_nparts=total_nparts,
         )

         # Create pipeline and executor
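The parametrization now pairs each bands_per_iteration value with a total_nparts value, so the suite exercises both an explicit partition count (4) and the executor-chosen default (None). A standalone sketch of the same pytest pattern:

import pytest

@pytest.mark.parametrize(
    ("bands_per_iteration", "total_nparts"),
    [(2, 4), (3, None)],
)
def test_pairing(bands_per_iteration: int, total_nparts: int | None) -> None:
    # pytest generates one case per tuple: (2, 4) and (3, None).
    assert (total_nparts is None) or total_nparts > 0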
