3232from nemo_curator .utils .semdedup_utils import assign_and_sort_clusters
3333
3434
35- ### Clustering Module
3635def get_embedding_ar (df : "cudf.DataFrame" , embedding_col : str ) -> cp .ndarray :
3736 return df [embedding_col ].list .leaves .values .reshape (len (df ), - 1 )
3837
@@ -47,14 +46,16 @@ def add_dist_to_cents(
4746 return df
4847
4948
49+ # Clustering module
5050class ClusteringModel :
5151 def __init__ (
5252 self ,
53- id_column : str ,
54- max_iter : int ,
55- n_clusters : int ,
56- clustering_output_dir : str ,
53+ id_column : str = "id" ,
54+ max_iter : int = 100 ,
55+ n_clusters : int = 1000 ,
56+ clustering_output_dir : str = "./clustering_results" ,
5757 embedding_column : str = "embeddings" ,
58+ random_state : int = 1234 ,
5859 sim_metric : str = "cosine" ,
5960 which_to_keep : str = "hard" ,
6061 sort_clusters : bool = True ,
@@ -68,25 +69,36 @@ def __init__(
6869
6970 Args:
7071 id_column (str): Column name used as the identifier in the dataset.
71- max_iter (int): Maximum number of iterations for the clustering algorithm.
72- n_clusters (int): The number of clusters to form.
73- clustering_output_dir (str): Directory path where clustering results will be saved.
74- embedding_column (str): Column name where the embeddings are stored.
75- sim_metric (str): Similarity metric to use for clustering, default is "cosine".
76- which_to_keep (str): Strategy to decide which duplicates to keep; default is "hard".
77- sort_clusters (bool): Whether to sort clusters, default is True.
78- kmeans_with_cos_dist (bool): Whether to use KMeans with cosine distance, default is False.
79- clustering_input_partition_size (str): The size of data partition to run kmeans with, default is "2gb".
80- logger (Union[logging.Logger, str]): Logger object or directory path to save logs; default is "./".
81- profile_dir (str): If specified directory to write dask profile. Default is None.
82-
83- This constructor sets up the parameters required for clustering operations.
72+ Default is "id".
73+ max_iter (int): Maximum iterations for clustering. Default is 100.
74+ n_clusters (int): Number of clusters. Default is 1000.
75+ clustering_output_dir (str): Location to save clustering results.
76+ Default is "./clustering_results".
77+ embedding_column (str): The column name that stores the embeddings.
78+ Default is "embeddings".
79+ random_state (int): KMeans random state used for reproducibility.
80+ Default is 1234.
81+ sim_metric (str): Similarity metric for deduplication.
82+ Default is "cosine".
83+ which_to_keep (str): Method to determine which duplicates to keep.
84+ Default is "hard".
85+ sort_clusters (bool): Whether to sort clusters. Default is True.
86+ kmeans_with_cos_dist (bool): Whether or not to use KMeans with cosine distance.
87+ Default is False.
88+ clustering_input_partition_size (str): The size of data partition with which to run KMeans.
89+ Default is "2gb".
90+ logger (Union[logging.Logger, str]): Existing logger to log to, or a path to a log directory.
91+ Default is "./".
92+ profile_dir (Optional[str]): If specified, directory to write Dask profile.
93+ Default is None.
94+
8495 """
8596 self .id_col = id_column
8697 self .max_iter = max_iter
8798 self .n_clusters = n_clusters
8899 self .clustering_output_dir = clustering_output_dir
89100 self .embedding_column = embedding_column
101+ self .random_state = random_state
90102 self .sim_metric = sim_metric
91103 self .keep_hard = which_to_keep == "hard"
92104 self .kmeans_with_cos_dist = kmeans_with_cos_dist
@@ -119,7 +131,7 @@ def __call__(self, embeddings_dataset: DocumentDataset):
119131
120132 if self .embedding_column not in embeddings_df .columns :
121133 raise ValueError (
122- f" Expected embedding column ' { self .embedding_column } '"
134+ f' Expected embedding column " { self .embedding_column } "'
123135 f" to be in dataset. Only found columns { embeddings_df .columns } "
124136 )
125137
@@ -153,18 +165,22 @@ def __call__(self, embeddings_dataset: DocumentDataset):
153165 )
154166 cupy_darr .compute_chunk_sizes ()
155167 t0 = time .time ()
156- kmeans = KMeans (n_clusters = self .n_clusters , max_iter = self .max_iter )
168+ kmeans = KMeans (
169+ n_clusters = self .n_clusters ,
170+ max_iter = self .max_iter ,
171+ random_state = self .random_state ,
172+ )
157173 self .logger .info ("KMeans starting fit" )
158174 kmeans .fit (cupy_darr )
159175 self .logger .info ("KMeans fit complete" )
160- self .logger .info (f"Time taken for KMeans Fit : { time .time () - t0 } " )
176+ self .logger .info (f"Time taken for KMeans fit : { time .time () - t0 } " )
161177
162178 self .logger .info (
163- "Computing nearest centroids + distance to centers using kmeans.predict"
179+ "Computing nearest centroids and distance to centers using kmeans.predict"
164180 )
165181 t0 = time .time ()
166182 nearest_cents = kmeans .predict (cupy_darr )
167- self .logger .info (f"Time taken for KMeans Predict : { time .time () - t0 } " )
183+ self .logger .info (f"Time taken for KMeans predict : { time .time () - t0 } " )
168184
169185 t0 = time .time ()
170186 embeddings_df ["nearest_cent" ] = nearest_cents .astype (np .int32 )
@@ -196,13 +212,11 @@ def __call__(self, embeddings_dataset: DocumentDataset):
196212 shutil .rmtree (clustering_output_dir )
197213
198214 embeddings_df .to_parquet (
199- clustering_output_dir ,
200- index = False ,
201- partition_on = "nearest_cent" ,
215+ clustering_output_dir , index = False , partition_on = "nearest_cent"
202216 )
203217 self .logger .info (
204- f"Time taken for Assigning distance to each embedding : { time .time () - t0 } "
205- f"and output written at { clustering_output_dir } "
218+ f"Time taken for assigning distance to each embedding: { time .time () - t0 } s "
219+ f" and output written at { clustering_output_dir } "
206220 )
207221
208222 del embeddings_df
0 commit comments