Commit e7d7ef6

Changes in semdedup scripts (#785)
* Changes in semdedup scripts

  Signed-off-by: abdr17 <abdulrahmanejaz19@gmail.com>

* Made required changes in semdedup tests

  Signed-off-by: abdr17 <abdulrahmanejaz19@gmail.com>

---------

Signed-off-by: abdr17 <abdulrahmanejaz19@gmail.com>
1 parent c772cd2 commit e7d7ef6

File tree

5 files changed: 24 additions & 18 deletions


nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -149,7 +149,7 @@ def extract_dedup_data(self, eps_to_extract: float) -> DocumentDataset:
             msg = "eps_to_extract must be a float"
             self.logger.error(msg)
             raise TypeError(msg)
-        output_parquet_path = os.path.join(self.output_dir, f"unique_ids_{eps_to_extract}.parquet")
+        output_parquet_path = os.path.join(self.output_dir, f"duplicate_ids_{eps_to_extract}.parquet")

         t0 = time.time()
         with performance_report_if_with_ts_suffix(
```
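The renamed f-string means each extraction threshold gets its own output file. A minimal sketch of the resulting naming scheme (the `output_dir` value here is hypothetical; in practice it comes from the clustering save location in the config):

```python
import os

def duplicate_ids_path(output_dir: str, eps_to_extract: float) -> str:
    # Mirrors the renamed f-string above: the eps value is embedded in the
    # filename, so extractions at different thresholds do not overwrite
    # each other.
    return os.path.join(output_dir, f"duplicate_ids_{eps_to_extract}.parquet")

print(duplicate_ids_path("clustering_results", 0.01))
# clustering_results/duplicate_ids_0.01.parquet
```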

nemo_curator/scripts/semdedup/README.md

Lines changed: 3 additions & 3 deletions
````diff
@@ -28,10 +28,10 @@ Please edit `config/sem_dedup_config.yaml` to configure the pipeline and run it
 - `embs_by_nearest_center` directory, containing `nearest_cent={x}` where x ranges from 0 to `num_clusters - 1`
 - Parquet files within `embs_by_nearest_center/nearest_cent={x}` containing the data points in each cluster

-4) Extract deduplicated data
+4) Extract duplicate data IDs to remove
 ```sh
-semdedup_extract_unique_ids --id-column "my_id" --id-column-type "str" --config-file "$CONFIG_FILE"
+semdedup_extract_duplicate_ids --id-column "my_id" --config-file "$CONFIG_FILE"
 ```
 **Input:** Output from step (3) and YAML file from step (1)

-**Output:** `{config.cache_dir}/{config.clustering_save_loc}/unique_ids_{}.parquet`
+**Output:** `{config.cache_dir}/{config.clustering_save_loc}/duplicate_ids_{}.parquet`
````
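Because the output now lists IDs to remove rather than IDs to keep, a downstream consumer filters its dataset with an anti-join. A minimal sketch, assuming a pandas workflow and hypothetical file paths (the `id` column corresponds to the `--id-column` passed above):

```python
import pandas as pd

# Hypothetical paths; the duplicate-IDs file is the output of the step above.
duplicate_ids = pd.read_parquet("clustering_results/duplicate_ids_0.01.parquet")
dataset = pd.read_parquet("input_dataset.parquet")

# Anti-join: keep only rows whose id is NOT flagged as a duplicate.
deduped = dataset[~dataset["id"].isin(duplicate_ids["id"])]
deduped.to_parquet("deduped_dataset.parquet")
```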

nemo_curator/scripts/semdedup/extract_dedup_data.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -58,7 +58,14 @@ def main(args: argparse.Namespace) -> None:

     semantic_dedup.compute_semantic_match_dfs()
     dedup_id_dataset = semantic_dedup.extract_dedup_data(eps_to_extract=semdedup_config.eps_to_extract)
-    print(dedup_id_dataset.df.head(10))
+
+    len_dedup_id_dataset = len(dedup_id_dataset.df.index)
+
+    # Check whether duplicates are found or not
+    if len_dedup_id_dataset == 0:
+        logger.info("No semantic duplicates found!")
+    else:
+        print(dedup_id_dataset.df.head(10, npartitions=-1))

     dt2 = time.perf_counter()
     logger.info(f"End: {dt2}")
@@ -77,7 +84,6 @@ def attach_args() -> argparse.ArgumentParser:
         "earlier using semdedup_extract_embeddings and semdedup_cluster_embeddings."
         "Input arguments include: "
         "--id-column for the identifier in the dataset, "
-        "--id-column-type for the data type of ID column, "
         "--config-file for the path to the semantic deduplication configuration file. "
         "Important configuration parameters include:"
         " cache_dir for the directory to store cache"
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -166,7 +166,7 @@ fineweb_nemotron_edu_classifier_inference = "nemo_curator.scripts.classifiers.fi
 blend_datasets = "nemo_curator.scripts.blend_datasets:console_script"
 semdedup_extract_embeddings = "nemo_curator.scripts.semdedup.compute_embeddings:console_script"
 semdedup_clustering = "nemo_curator.scripts.semdedup.clustering:console_script"
-semdedup_extract_unique_ids = "nemo_curator.scripts.semdedup.extract_dedup_data:console_script"
+semdedup_extract_duplicate_ids = "nemo_curator.scripts.semdedup.extract_dedup_data:console_script"
 async_llm_pii_redaction = "nemo_curator.scripts.async_llm_pii_redaction:console_script"
 llm_pii_redaction = "nemo_curator.scripts.llm_pii_redaction:console_script"

```
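The rename only changes the installed command name; both the old and new entry points resolve to the same `console_script` callable in `extract_dedup_data.py`. One way to check which semdedup commands an environment actually exposes (a sketch assuming Python 3.10+, where `entry_points` accepts a `group` keyword):

```python
from importlib.metadata import entry_points

# After this commit, the listing should show semdedup_extract_duplicate_ids
# rather than semdedup_extract_unique_ids.
for ep in entry_points(group="console_scripts"):
    if ep.name.startswith("semdedup"):
        print(f"{ep.name} -> {ep.value}")
```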

tests/test_semdedup.py

Lines changed: 11 additions & 11 deletions
```diff
@@ -690,22 +690,22 @@ def test_semantic_cluster_level_dedup(

     # Call extract_dedup_data
     semantic_cluster_level_dedup.extract_dedup_data(eps_to_extract=0.01)
-    # Check content of unique_ids
-    unique_ids_path = os.path.join(semantic_extraction_output_dir, "unique_ids_0.01.parquet")
-    assert os.path.exists(unique_ids_path)
-    unique_ids_df = pd.read_parquet(unique_ids_path)
-    assert unique_ids_df.columns.tolist() == [
+    # Check content of duplicate_ids
+    duplicate_ids_path = os.path.join(semantic_extraction_output_dir, "duplicate_ids_0.01.parquet")
+    assert os.path.exists(duplicate_ids_path)
+    duplicate_ids_df = pd.read_parquet(duplicate_ids_path)
+    assert duplicate_ids_df.columns.tolist() == [
         "id",
         "cosine_dist_to_cent",
         "cluster",
     ]

-    # Check that the filtered content of semdedup_pruning_table matches the unique_ids
+    # Check that the filtered content of semdedup_pruning_table matches the duplicate_ids
     semdedup_pruning_tables_df_filtered = semdedup_pruning_tables_df[
         semdedup_pruning_tables_df["cosine_sim_score"] >= 1 - 0.01
     ]
-    assert len(semdedup_pruning_tables_df_filtered) == len(unique_ids_df)
-    assert set(semdedup_pruning_tables_df_filtered["id"].to_list()) == set(unique_ids_df["id"].to_list())
+    assert len(semdedup_pruning_tables_df_filtered) == len(duplicate_ids_df)
+    assert set(semdedup_pruning_tables_df_filtered["id"].to_list()) == set(duplicate_ids_df["id"].to_list())

     # Check content of summary file
     summary_path = os.path.join(semantic_extraction_output_dir, "dedup_summary_0.01.csv")
@@ -729,6 +729,6 @@ def test_semantic_cluster_level_dedup(
             }
         ),
     )
-    # Ensure that the unique_ids are also correct (this implicitly checks for semdedup_pruning_tables output)
-    assert len(unique_ids_df) == _removed
-    assert len(set(unique_ids_df["id"].to_list())) == _removed
+    # Ensure that the duplicate_ids are also correct (this implicitly checks for semdedup_pruning_tables output)
+    assert len(duplicate_ids_df) == _removed
+    assert len(set(duplicate_ids_df["id"].to_list())) == _removed
```
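The filter the test applies encodes the pruning rule as I read it: a row in the pruning table counts as a duplicate at threshold `eps` when its `cosine_sim_score` is at least `1 - eps`, i.e. its cosine distance is at most `eps`. A minimal sketch of that arithmetic:

```python
def is_duplicate(cosine_sim_score: float, eps: float) -> bool:
    # cosine distance = 1 - cosine similarity, so sim >= 1 - eps
    # is the same condition as dist <= eps.
    return cosine_sim_score >= 1 - eps

assert is_duplicate(0.995, eps=0.01)     # dist 0.005 <= 0.01 -> pruned
assert not is_duplicate(0.98, eps=0.01)  # dist 0.02  >  0.01 -> kept
```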
