 from functools import partial
 from typing import Dict, List, Optional, Set, Tuple, Type
 import os
-from datasets import Dataset, load_dataset, concatenate_datasets
+from datasets import Dataset
 from tqdm.auto import tqdm
 
 from datasketch import MinHash, MinHashLSH
 from dpu_utils.utils.iterators import ThreadedIterator
 from sklearn.metrics.pairwise import cosine_similarity
 
-NON_ALPHA = re.compile("[^\u0080-\u00FF\u0100-\u017F\u0600-\u06FF\u07C0-\u07FF\u0900-\u097F\u1200-\u137F\u2D30-\u2D7F\uA500-\uA63FA-Za-z_0-9]")
+NON_ALPHA = re.compile(
+    "[^\u0080-\u00FF\u0100-\u017F\u0600-\u06FF\u07C0-\u07FF\u0900-\u097F\u1200-\u137F\u2D30-\u2D7F\uA500-\uA63FA-Za-z_0-9]"
+)
 # parameters used in DuplicationIndex
 MIN_NUM_TOKENS = 10
 NUM_PERM = 256
 
 
 def get_model(model_name="sentence-transformers/all-MiniLM-L12-v2"):
-    from sentence_transformers import SentenceTransformer
-    model = SentenceTransformer(model_name)
-    return model
+    from sentence_transformers import SentenceTransformer
+
+    model = SentenceTransformer(model_name)
+    return model
+
 
 def get_min_hash(tokens: List[str]) -> Optional[MinHash]:
     """Compute the MinHash of a code snippet."""
@@ -37,10 +41,10 @@ def get_tokens(code: str) -> Set[str]:
3741 """Tokenize a code snippet."""
3842 return set (get_data (code ))
3943
40-
44+
4145def get_data (data ):
42- return [t for t in NON_ALPHA .split (data ) if len (t .strip ()) > 0 ]
43- # return tokenizer(data)
46+ return [t for t in NON_ALPHA .split (data ) if len (t .strip ()) > 0 ]
47+ # return tokenizer(data)
4448
4549
4650class DuplicationIndex :
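
To make the tokenizer concrete: get_data splits on every character outside the identifier ranges listed in NON_ALPHA, so on plain ASCII source it behaves like a word tokenizer. A quick illustration (expected output shown as comments):

get_data("def add(a, b):\n    return a + b")
# -> ['def', 'add', 'a', 'b', 'return', 'a', 'b']
get_tokens("def add(a, b):\n    return a + b")
# -> {'def', 'add', 'a', 'b', 'return'}
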
@@ -72,7 +76,6 @@ def add(self, code_key: Tuple, min_hash: MinHash) -> None:
 
         self._index.insert(code_key, min_hash)
         if len(close_duplicates) > 0:
-
             for base_duplicate in close_duplicates:
                 if base_duplicate in self._duplicate_clusters:
                     self._duplicate_clusters[base_duplicate].add(code_key)
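
The behaviour of add hinges on datasketch's MinHashLSH: query returns the keys of previously inserted fingerprints whose estimated Jaccard similarity clears the index threshold, and those hits seed the duplicate clusters. A standalone sketch of that mechanism:

from datasketch import MinHash, MinHashLSH

lsh = MinHashLSH(threshold=0.85, num_perm=256)
m1, m2 = MinHash(num_perm=256), MinHash(num_perm=256)
for token in ["def", "add", "first", "second", "return", "result"]:
    m1.update(token.encode())
    m2.update(token.encode())
m2.update(b"extra")  # one differing token out of seven
lsh.insert("row-0", m1)
print(lsh.query(m2))  # typically ['row-0']: the true Jaccard (6/7) clears the 0.85 threshold
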
@@ -104,7 +107,7 @@ def save(self, filepath) -> None:
 
 def _compute_min_hash(element):
     index, data = element
-    min_hash = get_min_hash(get_data(data['text']))
+    min_hash = get_min_hash(get_data(data["text"]))
     if min_hash is not None:
         return index, min_hash
 
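
minhash_iter, consumed by make_duplicate_clusters below, sits outside this diff's context. In comparable deduplication scripts it is a thin multiprocessing wrapper around _compute_min_hash, roughly like this sketch (not the file's exact body):

import multiprocessing as mp

def minhash_iter(dataset_iterator):
    # Fan (index, row) pairs out to worker processes and yield only
    # the pairs for which a MinHash could actually be computed.
    with mp.Pool() as pool:
        for data in pool.imap_unordered(
            _compute_min_hash,
            ThreadedIterator(dataset_iterator, max_queue_size=10000),
        ):
            if data is not None:
                yield data
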
@@ -129,7 +132,11 @@ def make_duplicate_clusters(dataset_iterator: Type[Dataset], jaccard_threshold:
129132 """
130133 di = DuplicationIndex (duplication_jaccard_threshold = jaccard_threshold )
131134
132- for filename , min_hash in tqdm (ThreadedIterator (minhash_iter (enumerate (dataset_iterator )), max_queue_size = 10000 ), total = len (dataset_iterator ), desc = "Deduplication" ):
135+ for filename , min_hash in tqdm (
136+ ThreadedIterator (minhash_iter (enumerate (dataset_iterator )), max_queue_size = 10000 ),
137+ total = len (dataset_iterator ),
138+ desc = "Deduplication" ,
139+ ):
133140 di .add (filename , min_hash )
134141
135142 # Returns a List[Cluster] where Cluster is List[str] with the filenames.
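
A hypothetical end-to-end call on a toy dataset (the "text" column matches _compute_min_hash above; note that rows with fewer than MIN_NUM_TOKENS tokens are silently skipped, so the snippets must be long enough to be indexed at all):

from datasets import Dataset

snippet = "def add(first, second):\n    result = first + second\n    print(result)\n    return result"
toy = Dataset.from_dict({"text": [snippet, snippet + "\n", "import os"]})
clusters = make_duplicate_clusters(toy, jaccard_threshold=0.85)
# -> one cluster grouping the two near-identical copies of `snippet`
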
@@ -142,8 +149,10 @@ def jaccard_similarity(code1: str, code2: str) -> float:
     tokens2 = get_tokens(code2)
     return len(tokens1 & tokens2) / len(tokens1 | tokens2)
 
+
 _shared_dataset = None
 
+
 def _find_cluster_extremes_shared(cluster, jaccard_threshold):
     """Find a reduced cluster such that each code in the origin cluster is similar to at least one code in the reduced cluster.
     Two codes are similar if their Jaccard similarity is above the threshold.
@@ -175,13 +184,15 @@ def _find_cluster_extremes_shared(cluster, jaccard_threshold):
             extremes.append(element1)
     return extremes
 
+
 def _find_cluster_extremes_shared_semantic(cluster, jaccard_threshold):
     extremes = []
     # Convert code snippets to embeddings
     model = get_model()
-    code_embeddings = {element["base_index"]: model.encode(_shared_dataset[element["base_index"]]["text"])
-                       for element in cluster}
-
+    code_embeddings = {
+        element["base_index"]: model.encode(_shared_dataset[element["base_index"]]["text"]) for element in cluster
+    }
+
     for element1 in cluster:
         embedding1 = code_embeddings[element1["base_index"]]
         for element2 in extremes:
@@ -195,6 +206,7 @@ def _find_cluster_extremes_shared_semantic(cluster, jaccard_threshold):
             extremes.append(element1)
     return extremes
 
+
 def find_extremes(cluster_list, dataset, jaccard_threshold):
     """Call the _find_cluster_extremes_shared function in a parallel fashion.
     Args:
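
The comparison elided from the semantic hunk above presumably gates on cosine similarity between the two embeddings; note the parameter is still named jaccard_threshold even though it now bounds a cosine score. With sentence-transformers vectors and the sklearn import at the top of the file, that check looks roughly like (threshold value assumed, mirroring the Jaccard default):

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

embedding1 = np.random.rand(1, 384)  # all-MiniLM-L12-v2 emits 384-dim vectors
embedding2 = np.random.rand(1, 384)
score = cosine_similarity(embedding1, embedding2)[0][0]
is_similar = score >= 0.85
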
@@ -224,11 +236,12 @@ def find_extremes(cluster_list, dataset, jaccard_threshold):
                 cluster_list,
             ),
             total=len(cluster_list),
-            desc="Finding overlaps"
+            desc="Finding overlaps",
         ):
             extremes_list.append(extremes)
     return extremes_list
 
+
 def find_extremes_semantic(cluster_list, dataset, jaccard_threshold):
     """Call the _find_cluster_extremes_shared_semantic function in a parallel fashion.
     Args:
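
Both find_extremes variants lean on the module-level _shared_dataset declared earlier: the parent process assigns the dataset to the global before spawning workers, so forked children inherit it rather than pickling it with every task. A sketch of that pattern, consistent with the loop fragments shown in these hunks:

import multiprocessing as mp
from functools import partial

def find_extremes_sketch(cluster_list, dataset, jaccard_threshold):
    global _shared_dataset
    _shared_dataset = dataset  # inherited by forked workers; avoids pickling the dataset per task
    extremes_list = []
    reduce_cluster = partial(_find_cluster_extremes_shared, jaccard_threshold=jaccard_threshold)
    with mp.Pool() as pool:
        for extremes in pool.imap_unordered(reduce_cluster, cluster_list):
            extremes_list.append(extremes)
    return extremes_list
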
@@ -258,18 +271,19 @@ def find_extremes_semantic(cluster_list, dataset, jaccard_threshold):
                 cluster_list,
             ),
             total=len(cluster_list),
-            desc="Finding overlaps semantically"
+            desc="Finding overlaps semantically",
         ):
             extremes_list.append(extremes)
     return extremes_list
 
+
 def display_dataset_entries(dataset, duplicate_clusters):
     """
     Fetches and displays the dataset entries for given base_index values in duplicate clusters.
 
     Args:
         dataset (Dataset): The dataset object containing the code snippets.
-        duplicate_clusters (List[List[Dict]]): List of duplicate clusters, each containing dictionaries 
+        duplicate_clusters (List[List[Dict]]): List of duplicate clusters, each containing dictionaries
             with 'base_index', 'is_extreme', and 'copies' keys.
 
     Returns:
@@ -278,12 +292,13 @@ def display_dataset_entries(dataset, duplicate_clusters):
     for cluster in duplicate_clusters:
         print("Cluster:")
         for item in cluster:
-            base_index = item['base_index']
+            base_index = item["base_index"]
             data_entry = dataset[base_index]  # Assuming the dataset can be accessed by index
             print(f"Base Index: {base_index}, Data: {data_entry}")
 
         print("\n")  # Separate clusters by a newline for clarity
 
+
 def deduplicate_dataset(
     dataset: Type[Dataset], jaccard_threshold: float = 0.85
 ) -> Tuple[Type[Dataset], List[List[Dict]]]:
@@ -343,4 +358,4 @@ def deduplicate_dataset(
     print(f"Unique files in duplicate cluster: {len(extreme_dict)}")
     print(f"Filtered dataset size: {len(ds_filter)}")
 
-    return ds_filter, duplicate_clusters
\ No newline at end of file
+    return ds_filter, duplicate_clusters