
Commit e662ac0

Add KMeans random_state to semantic deduplication configs (#575)
* Add KMeans random_state to semantic deduplication configs

Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>

* edit docstrings

Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>

---------

Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
1 parent e41159b commit e662ac0

13 files changed: +140 -73 lines changed

config/sem_dedup_config.yaml

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ write_to_filename: false
 max_iter: 100
 n_clusters: 1000
 clustering_save_loc: "clustering_results"
+random_state: 1234
 sim_metric: "cosine"
 which_to_keep: "hard"
 sort_clusters: true
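
For reference, a minimal sketch of how this YAML, including the new random_state key, would typically be loaded. It assumes SemDedupConfig.from_yaml is available and parses this file into the dataclass changed below in nemo_curator/modules/config.py; the printed values simply echo the defaults shown in the diff.

# Minimal sketch, assuming SemDedupConfig.from_yaml reads the YAML above.
from nemo_curator.modules.config import SemDedupConfig

config = SemDedupConfig.from_yaml("config/sem_dedup_config.yaml")
print(config.n_clusters)    # 1000
print(config.random_state)  # 1234, the new KMeans seed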

docs/user-guide/semdedup.rst

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ Semantic deduplication in NeMo Curator can be configured using a YAML file. Here
 max_iter: 100
 n_clusters: 1000
 clustering_save_loc: "clustering_results"
+random_state: 1234
 sim_metric: "cosine"
 which_to_keep: "hard"
 sort_clusters: true

nemo_curator/modules/config.py

Lines changed: 2 additions & 0 deletions
@@ -173,6 +173,7 @@ class SemDedupConfig(BaseConfig):
         n_clusters (int): Number of clusters. Default is 1000.
         clustering_save_loc (str): Location to save clustering results.
             Default is "clustering_results".
+        random_state (int): KMeans random state used for reproducibility. Default is 1234.
         sim_metric (str): Similarity metric for deduplication.
             Default is "cosine".
         which_to_keep (str): Method to determine which duplicates to keep.
@@ -208,6 +209,7 @@ class SemDedupConfig(BaseConfig):
     max_iter: int = 100
     n_clusters: int = 1000
     clustering_save_loc: str = "clustering_results"
+    random_state: int = 1234
     sim_metric: str = "cosine"
     which_to_keep: str = "hard"
     sort_clusters: bool = True
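
A short sketch of building the updated dataclass directly and overriding the new field. The cache_dir path is hypothetical, and it is assumed here to be the only field without a default (it is the one attribute the scripts below read unconditionally).

# Sketch: overriding random_state when constructing SemDedupConfig in code.
from nemo_curator.modules.config import SemDedupConfig

config = SemDedupConfig(
    cache_dir="./semdedup_cache",  # hypothetical path
    n_clusters=1000,
    max_iter=100,
    random_state=42,  # any fixed seed makes the KMeans step reproducible
)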

nemo_curator/modules/semantic_dedup/clusteringmodel.py

Lines changed: 42 additions & 28 deletions
@@ -32,7 +32,6 @@
 from nemo_curator.utils.semdedup_utils import assign_and_sort_clusters
 
 
-### Clustering Module
 def get_embedding_ar(df: "cudf.DataFrame", embedding_col: str) -> cp.ndarray:
     return df[embedding_col].list.leaves.values.reshape(len(df), -1)
 
@@ -47,14 +46,16 @@ def add_dist_to_cents(
     return df
 
 
+# Clustering module
 class ClusteringModel:
     def __init__(
         self,
-        id_column: str,
-        max_iter: int,
-        n_clusters: int,
-        clustering_output_dir: str,
+        id_column: str = "id",
+        max_iter: int = 100,
+        n_clusters: int = 1000,
+        clustering_output_dir: str = "./clustering_results",
         embedding_column: str = "embeddings",
+        random_state: int = 1234,
         sim_metric: str = "cosine",
         which_to_keep: str = "hard",
         sort_clusters: bool = True,
@@ -68,25 +69,36 @@ def __init__(
 
         Args:
             id_column (str): Column name used as the identifier in the dataset.
-            max_iter (int): Maximum number of iterations for the clustering algorithm.
-            n_clusters (int): The number of clusters to form.
-            clustering_output_dir (str): Directory path where clustering results will be saved.
-            embedding_column (str): Column name where the embeddings are stored.
-            sim_metric (str): Similarity metric to use for clustering, default is "cosine".
-            which_to_keep (str): Strategy to decide which duplicates to keep; default is "hard".
-            sort_clusters (bool): Whether to sort clusters, default is True.
-            kmeans_with_cos_dist (bool): Whether to use KMeans with cosine distance, default is False.
-            clustering_input_partition_size (str): The size of data partition to run kmeans with, default is "2gb".
-            logger (Union[logging.Logger, str]): Logger object or directory path to save logs; default is "./".
-            profile_dir (str): If specified directory to write dask profile. Default is None.
-
-        This constructor sets up the parameters required for clustering operations.
+                Default is "id".
+            max_iter (int): Maximum iterations for clustering. Default is 100.
+            n_clusters (int): Number of clusters. Default is 1000.
+            clustering_output_dir (str): Location to save clustering results.
+                Default is "./clustering_results".
+            embedding_column (str): The column name that stores the embeddings.
+                Default is "embeddings".
+            random_state (int): KMeans random state used for reproducibility.
+                Default is 1234.
+            sim_metric (str): Similarity metric for deduplication.
+                Default is "cosine".
+            which_to_keep (str): Method to determine which duplicates to keep.
+                Default is "hard".
+            sort_clusters (bool): Whether to sort clusters. Default is True.
+            kmeans_with_cos_dist (bool): Whether or not to use KMeans with cosine distance.
+                Default is False.
+            clustering_input_partition_size (str): The size of data partition with which to run KMeans.
+                Default is "2gb".
+            logger (Union[logging.Logger, str]): Existing logger to log to, or a path to a log directory.
+                Default is "./".
+            profile_dir (Optional[str]): If specified, directory to write Dask profile.
+                Default is None.
+
         """
         self.id_col = id_column
         self.max_iter = max_iter
         self.n_clusters = n_clusters
         self.clustering_output_dir = clustering_output_dir
         self.embedding_column = embedding_column
+        self.random_state = random_state
         self.sim_metric = sim_metric
         self.keep_hard = which_to_keep == "hard"
         self.kmeans_with_cos_dist = kmeans_with_cos_dist
@@ -119,7 +131,7 @@ def __call__(self, embeddings_dataset: DocumentDataset):
 
         if self.embedding_column not in embeddings_df.columns:
             raise ValueError(
-                f"Expected embedding column '{self.embedding_column}'"
+                f'Expected embedding column "{self.embedding_column}"'
                 f" to be in dataset. Only found columns {embeddings_df.columns}"
             )
 
@@ -153,18 +165,22 @@ def __call__(self, embeddings_dataset: DocumentDataset):
             )
             cupy_darr.compute_chunk_sizes()
            t0 = time.time()
-            kmeans = KMeans(n_clusters=self.n_clusters, max_iter=self.max_iter)
+            kmeans = KMeans(
+                n_clusters=self.n_clusters,
+                max_iter=self.max_iter,
+                random_state=self.random_state,
+            )
             self.logger.info("KMeans starting fit")
             kmeans.fit(cupy_darr)
             self.logger.info("KMeans fit complete")
-            self.logger.info(f"Time taken for KMeans Fit: {time.time() - t0}")
+            self.logger.info(f"Time taken for KMeans fit: {time.time() - t0}")
 
             self.logger.info(
-                "Computing nearest centroids + distance to centers using kmeans.predict"
+                "Computing nearest centroids and distance to centers using kmeans.predict"
             )
             t0 = time.time()
             nearest_cents = kmeans.predict(cupy_darr)
-            self.logger.info(f"Time taken for KMeans Predict: {time.time() - t0}")
+            self.logger.info(f"Time taken for KMeans predict: {time.time() - t0}")
 
             t0 = time.time()
             embeddings_df["nearest_cent"] = nearest_cents.astype(np.int32)
@@ -196,13 +212,11 @@ def __call__(self, embeddings_dataset: DocumentDataset):
                 shutil.rmtree(clustering_output_dir)
 
             embeddings_df.to_parquet(
-                clustering_output_dir,
-                index=False,
-                partition_on="nearest_cent",
+                clustering_output_dir, index=False, partition_on="nearest_cent"
             )
             self.logger.info(
-                f"Time taken for Assigning distance to each embedding : {time.time() - t0} "
-                f"and output written at {clustering_output_dir}"
+                f"Time taken for assigning distance to each embedding: {time.time() - t0}s"
+                f" and output written at {clustering_output_dir}"
             )
 
             del embeddings_df
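
A hedged sketch of the updated ClusteringModel API. The embeddings_dataset variable is assumed to be a cuDF-backed DocumentDataset produced by EmbeddingCreator, with the vectors stored in an "embeddings" list column as expected by the code above.

# Sketch: ClusteringModel now accepts random_state and passes it to the cuML KMeans fit.
from nemo_curator.modules.semantic_dedup.clusteringmodel import ClusteringModel

clustering_model = ClusteringModel(
    id_column="id",
    max_iter=100,
    n_clusters=1000,
    clustering_output_dir="./clustering_results",
    random_state=1234,  # new argument
)
clustered_dataset = clustering_model(embeddings_dataset)  # embeddings_dataset assumed above

The seed is now read from the configuration rather than left implicit in cuML, so experiments can pin it for repeatable cluster assignments or vary it deliberately.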

nemo_curator/modules/semantic_dedup/embeddings.py

Lines changed: 28 additions & 27 deletions
@@ -36,7 +36,7 @@
 )
 
 
-# Embedding Creation Module
+# Embedding creation module
 @dataclass
 class EmbeddingConfig:
     model_name_or_path: str
@@ -47,7 +47,7 @@ def __post_init__(self):
         self.max_seq_length = AutoTokenizer.from_pretrained(
             self.model_name_or_path
         ).model_max_length
-        # Gaurd against the HF bug
+        # Guard against Hugging Face bug
         # which sets max_seq_length to max(int) for some models
         if self.max_seq_length > 1e5:
             self.max_seq_length = AutoConfig.from_pretrained(
@@ -133,9 +133,9 @@ def load_tokenizer(self):
 class EmbeddingCreator:
     def __init__(
         self,
-        embedding_model_name_or_path: str,
-        embedding_batch_size: int,
-        embedding_output_dir: str,
+        embedding_model_name_or_path: str = "sentence-transformers/all-MiniLM-L6-v2",
+        embedding_batch_size: int = 128,
+        embedding_output_dir: str = "./embeddings",
         embedding_max_mem_gb: Optional[int] = None,
         embedding_pooling_strategy: str = "mean_pooling",
         input_column: str = "text",
@@ -149,28 +149,29 @@ def __init__(
         Initializes an EmbeddingCreator for generating embeddings using the specified model configurations.
 
         Args:
-            embedding_model_name_or_path (str): The path or identifier for the model used to generate embeddings.
-            embedding_batch_size (int): Number of samples to process in each batch.
-            embedding_output_dir (str): Directory path where embeddings will be saved.
-            embedding_max_mem_gb (int): Maximum memory usage in GB for the embedding process.
-                If None, it defaults to the available GPU memory minus 4 GB.
-            embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean_pooling" or "last_token". Defaults to "mean_pooling".
-            input_column (str): Column name from the data to be used for embedding generation, defaults to "text".
-            write_embeddings_to_disk (bool, optional): If True, saves the embeddings to disk, defaults to True.
-                We recommend setting this to False when you have a delayed pipeline.
-                Setting it to False can lead to more memory overhead.
-            write_to_filename (bool): If True, saves the embeddings to the same filename as input files, defaults to False.
-            logger (Union[logging.Logger, str]): Logger object or path to store logs, defaults to "./".
-            profile_dir (str): If specified directory to write dask profile. Default is None.
-
-        Attributes:
-            embeddings_config (EmbeddingConfig): Configuration for embeddings.
-            batch_size (int): Batch size for embedding generation.
-            logger (logging.Logger): Logger instance for the class.
-            embedding_output_dir (str): Output directory for embeddings.
-            input_column (str): Input column for data processing.
-            model (EmbeddingCrossFitModel): Model instance for embedding generation.
-            write_to_filename (bool): If True, saves the embeddings to the same filename as input files, defaults to False.
+            embedding_model_name_or_path (str): Model name or path for embeddings.
+                Default is "sentence-transformers/all-MiniLM-L6-v2".
+            embedding_batch_size (int): Initial batch size for processing embeddings.
+                Default is 128.
+            embedding_output_dir (str): Location to save embeddings.
+                Default is "./embeddings".
+            embedding_max_mem_gb (int, optional): Maximum memory usage in GB for the embedding process.
+                If None, it defaults to the available GPU memory minus 4 GB.
+            embedding_pooling_strategy: Strategy for pooling embeddings, either "mean_pooling" or "last_token".
+                Default is "mean_pooling".
+            input_column (str): Column name from the data to be used for embedding generation.
+                Default is "text".
+            embedding_column (str): The column name that stores the embeddings. Default is "embeddings".
+            write_embeddings_to_disk (bool): If True, saves the embeddings to disk.
+                We recommend setting this to False when you have a delayed pipeline.
+                Setting it to False can lead to more memory overhead. Default is True.
+            write_to_filename (bool): If True, saves the embeddings to the same filename as input files.
+                Default is False.
+            logger (Union[logging.Logger, str]): Existing logger to log to, or a path to a log directory.
+                Default is "./".
+            profile_dir (Optional[str]): If specified, directory to write Dask profile.
+                Default is None.
+
         """
 
         self.embeddings_config = EmbeddingConfig(
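
A brief sketch of the updated EmbeddingCreator defaults. The explicit arguments below simply restate those defaults; dataset is assumed to be a DocumentDataset of raw text with a "text" column, and the call form matches the usage in compute_embeddings.py further down.

# Sketch: with the new defaults, EmbeddingCreator can be built with no arguments.
from nemo_curator.modules.semantic_dedup.embeddings import EmbeddingCreator

embedding_creator = EmbeddingCreator(
    embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
    embedding_batch_size=128,
    embedding_output_dir="./embeddings",
    input_column="text",
)
embedding_dataset = embedding_creator(dataset=dataset)  # dataset assumed above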

nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py

Lines changed: 24 additions & 14 deletions
@@ -35,13 +35,13 @@
 class SemanticClusterLevelDedup:
     def __init__(
         self,
-        n_clusters: int,
-        emb_by_clust_dir: str,
-        sorted_clusters_dir: str,
-        id_column: str,
-        id_column_type: str,
-        which_to_keep: str,
-        output_dir: str,
+        n_clusters: int = 1000,
+        emb_by_clust_dir: str = "./clustering_results/embs_by_nearest_center",
+        sorted_clusters_dir: str = "./clustering_results/sorted",
+        id_column: str = "id",
+        id_column_type: str = "int",
+        which_to_keep: str = "hard",
+        output_dir: str = "./clustering_results",
         embedding_column: str = "embeddings",
         logger: Union[logging.Logger, str] = "./",
         profile_dir: Optional[str] = None,
@@ -50,16 +50,25 @@ def __init__(
         Initialize the SemanticClusterLevelDedup class.
 
         Args:
-            n_clusters (int): Number of clusters.
+            n_clusters (int): Number of clusters. Default is 1000.
             emb_by_clust_dir (str): Directory containing embeddings by cluster.
+                Default is "./clustering_results/embs_by_nearest_center".
             sorted_clusters_dir (str): Directory containing sorted clusters.
-            id_column (str): Column name for IDs.
-            id_column_type (str): Data type of the ID column.
-            which_to_keep (str): Strategy for which duplicate to keep.
+                Default is "./clustering_results/sorted".
+            id_column (str): Column name used as the identifier in the dataset.
+                Default is "id".
+            id_column_type (str): Data type of id_column. Default is "int".
+            which_to_keep (str): Method to determine which duplicates to keep.
+                Default is "hard".
             output_dir (str): Directory to save output files.
-            embedding_column (str): Column where the embeddings are stored.
-            logger (Union[logging.Logger, str]): Logger instance or path to the log file directory.
-            profile_dir (str): If specified directory to write dask profile. Default is None.
+                Default is "./clustering_results".
+            embedding_column (str): The column name that stores the embeddings.
+                Default is "embeddings".
+            logger (Union[logging.Logger, str]): Existing logger to log to, or a path to a log directory.
+                Default is "./".
+            profile_dir (Optional[str]): If specified, directory to write Dask profile.
+                Default is None.
+
         """
         self.n_clusters = n_clusters
         self.emb_by_clust_dir = emb_by_clust_dir
@@ -118,6 +127,7 @@ def compute_semantic_match_dfs(
             shutil.rmtree(self.semdedup_pruning_tables_dir)
         expand_outdir_and_mkdir(self.semdedup_pruning_tables_dir)
         t0 = time.time()
+
        with performance_report_if_with_ts_suffix(
             self.profile_dir, "semantic-match-compute"
         ):
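
A hedged sketch of the updated SemanticClusterLevelDedup defaults. The directory arguments restate the new default values, which line up with ClusteringModel's "./clustering_results" output layout.

# Sketch: the new defaults point at ClusteringModel's standard output directories.
from nemo_curator.modules.semantic_dedup.semanticclusterleveldedup import (
    SemanticClusterLevelDedup,
)

cluster_dedup = SemanticClusterLevelDedup(
    n_clusters=1000,
    emb_by_clust_dir="./clustering_results/embs_by_nearest_center",
    sorted_clusters_dir="./clustering_results/sorted",
    id_column="id",
    id_column_type="int",
    which_to_keep="hard",
    output_dir="./clustering_results",
)
# cluster_dedup.compute_semantic_match_dfs(...) then builds the pruning tables
# touched in the last hunk above (its argument list is not shown in this diff).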

nemo_curator/modules/semantic_dedup/semdedup.py

Lines changed: 7 additions & 1 deletion
@@ -41,7 +41,13 @@ def __init__(
 
         Args:
             config (SemDedupConfig): Configuration for SemDedup.
-            logger (Union[logging.Logger, str]): Logger instance or path to the log file directory.
+            input_column (str): Column name from the data to be used for embedding generation.
+                Default is "text".
+            id_column (str): Column name used as the identifier in the dataset.
+                Default is "id".
+            id_column_type (str): Data type of id_column. Default is "int".
+            logger (Union[logging.Logger, str]): Existing logger to log to, or a path to a log directory.
+                Default is "./".
         """
         super().__init__(input_backend="cudf")
         self.config = config
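
For orientation, a hedged end-to-end sketch of the SemDedup module, which drives embedding creation, clustering, and cluster-level dedup from a single SemDedupConfig, so the new random_state reaches the clustering step. The cache_dir path is hypothetical, and the final call is indicative only.

# Sketch: running the whole semantic dedup pipeline from one config object.
from nemo_curator.modules.config import SemDedupConfig
from nemo_curator.modules.semantic_dedup.semdedup import SemDedup

config = SemDedupConfig(cache_dir="./semdedup_cache", random_state=1234)
sem_dedup = SemDedup(
    config=config,
    input_column="text",
    id_column="id",
    id_column_type="int",
)
# result = sem_dedup(dataset)  # dataset: cuDF-backed DocumentDataset, per input_backend="cudf"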

nemo_curator/scripts/semdedup/clustering.py

Lines changed: 7 additions & 0 deletions
@@ -66,6 +66,13 @@ def main(args):
         max_iter=semdedup_config.max_iter,
         n_clusters=semdedup_config.n_clusters,
         clustering_output_dir=clustering_output_dir,
+        embedding_column=semdedup_config.embedding_column,
+        random_state=semdedup_config.random_state,
+        sim_metric=semdedup_config.sim_metric,
+        which_to_keep=semdedup_config.which_to_keep,
+        sort_clusters=semdedup_config.sort_clusters,
+        kmeans_with_cos_dist=semdedup_config.kmeans_with_cos_dist,
+        clustering_input_partition_size=semdedup_config.clustering_input_partition_size,
         logger=logger,
     )
 
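
The script now forwards random_state (along with the other clustering options) from the YAML config to ClusteringModel. The reproducibility this provides can be illustrated with a small standalone check; this sketch assumes a GPU with cuML and CuPy installed and is not part of the script itself.

# Sketch: with a fixed random_state, two independent cuML KMeans fits over the
# same data agree on every label, which is what seeding the clustering step buys.
import cupy as cp
from cuml.cluster import KMeans

x = cp.random.RandomState(0).standard_normal((1000, 16), dtype=cp.float32)
labels_a = KMeans(n_clusters=8, max_iter=100, random_state=1234).fit_predict(x)
labels_b = KMeans(n_clusters=8, max_iter=100, random_state=1234).fit_predict(x)
assert bool((labels_a == labels_b).all())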

nemo_curator/scripts/semdedup/compute_embeddings.py

Lines changed: 6 additions & 3 deletions
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -41,7 +41,7 @@ def main(args):
         semdedup_config.cache_dir, semdedup_config.embeddings_save_loc
     )
 
-    # Some time jsonl files are stored as .json
+    # Sometimes JSONL files are stored as .json
     # So to handle that case we can pass the input_file_extension
     if args.input_file_extension is not None:
         input_file_extension = args.input_file_extension
@@ -79,10 +79,13 @@ def main(args):
         embedding_output_dir=os.path.join(
             semdedup_config.cache_dir, semdedup_config.embeddings_save_loc
         ),
+        embedding_max_mem_gb=semdedup_config.embedding_max_mem_gb,
+        embedding_pooling_strategy=semdedup_config.embedding_pooling_strategy,
         input_column=args.input_text_field,
+        embedding_column=semdedup_config.embedding_column,
         write_embeddings_to_disk=semdedup_config.write_embeddings_to_disk,
+        write_to_filename=semdedup_config.write_to_filename,
         logger=logger,
-        write_to_filename=True,
     )
 
     embedding_dataset = embedding_creator(dataset=dataset)
