Commit 70a3346
Add batched pairwise similarity method for Semantic Dedup (#581)
* add batched similarity
* pre-commit
* increase tolerance
* uncomment
* pr review + sad that we can only run tests on gpu
* pc
* pr suggestions
* another test config
* reduce tolerance for random array

Signed-off-by: Praateek <praateekm@gmail.com>
1 parent 242c7f4 commit 70a3346

11 files changed: +165 −28 lines

config/sem_dedup_config.yaml

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ clustering_save_loc: "clustering_results"
 random_state: 1234
 sim_metric: "cosine"
 which_to_keep: "hard"
+batched_cosine_similarity: 1024
 sort_clusters: true
 kmeans_with_cos_dist: false
 clustering_input_partition_size: "2gb"

docs/user-guide/semdedup.rst

Lines changed: 3 additions & 1 deletion
@@ -57,6 +57,7 @@ Semantic deduplication in NeMo Curator can be configured using a YAML file. Here
 random_state: 1234
 sim_metric: "cosine"
 which_to_keep: "hard"
+batched_cosine_similarity: 1024
 sort_clusters: true
 kmeans_with_cos_dist: false
 clustering_input_partition_size: "2gb"
@@ -209,6 +210,7 @@ Use Individual Components
     id_column="doc_id",
     id_column_type="str",
     which_to_keep="hard",
+    batched_cosine_similarity=1024,
     output_dir="path/to/output/deduped",
     logger="path/to/log/dir"
 )
@@ -257,7 +259,7 @@ Key parameters in the configuration file include:
 - ``n_clusters``: Number of clusters for k-means clustering.
 - ``eps_to_extract``: Deduplication threshold. Higher values result in more aggressive deduplication.
 - ``which_to_keep``: Strategy for choosing which duplicate to keep ("hard" or "soft").
-
+- ``batched_cosine_similarity``: Whether to use batched cosine similarity (has less memory usage, O(N*B) where B is the batch size) or vanilla cosine similarity (O(N^2) memory usage).
 -----------------------------------------
 Output
 -----------------------------------------
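To make the O(N*B) versus O(N^2) claim in the new ``batched_cosine_similarity`` bullet concrete, here is a rough back-of-the-envelope sketch. It is not part of the commit; the cluster size N is hypothetical and fp32 (4-byte) similarity scores are assumed.

# Approximate peak size of the similarity matrix for one cluster,
# assuming 4-byte fp32 scores. Illustrative numbers, not from the commit.
N = 100_000  # items in the cluster (hypothetical)
B = 1024     # the default batch size

full_matrix_gb = N * N * 4 / 1e9    # O(N^2): the entire N x N matrix at once
batched_slab_gb = N * B * 4 / 1e9   # O(N*B): one N x B column block at a time

print(f"full:    {full_matrix_gb:.1f} GB")   # 40.0 GB
print(f"batched: {batched_slab_gb:.2f} GB")  # 0.41 GB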

nemo_curator/modules/config.py

Lines changed: 15 additions & 1 deletion
@@ -14,7 +14,7 @@

 import warnings
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import List, Optional, Union

 import yaml

@@ -178,6 +178,9 @@ class SemDedupConfig(BaseConfig):
             Default is "cosine".
         which_to_keep (str): Method to determine which duplicates to keep.
             Default is "hard".
+        batched_cosine_similarity (Union[bool, int]): Whether to use batched cosine similarity (has less memory usage).
+            Default is 1024. When False or 0, no batching is used and memory requirements are O(N^2) where N is the number of items in the cluster.
+            When True, batch size is set to 1024 and memory requirements are O(N*B) where N is the number of items in the cluster and B is the batch size.
         sort_clusters (bool): Whether to sort clusters. Default is True.
         kmeans_with_cos_dist (bool): Whether or not to use KMeans with cosine distance.
             Default is False.
@@ -199,6 +202,7 @@ class SemDedupConfig(BaseConfig):
     embedding_batch_size: int = 128
     embeddings_save_loc: str = "embeddings"
     embedding_max_mem_gb: Optional[int] = None
+
     # Options: "mean_pooling", "last_token"
     embedding_pooling_strategy: str = "mean_pooling"
     embedding_column: str = "embeddings"
@@ -212,6 +216,7 @@ class SemDedupConfig(BaseConfig):
     random_state: int = 1234
     sim_metric: str = "cosine"
     which_to_keep: str = "hard"
+    batched_cosine_similarity: Union[bool, int] = 1024
     sort_clusters: bool = True
     kmeans_with_cos_dist: bool = False
     clustering_input_partition_size: str = "2gb"
@@ -230,3 +235,12 @@ def __post_init__(self):
             raise ValueError(
                 f"Epsilon to extract {self.eps_to_extract} must be in eps_thresholds {self.eps_thresholds}"
             )
+
+        # Convert bool to int
+        if isinstance(self.batched_cosine_similarity, bool):
+            if self.batched_cosine_similarity:
+                self.batched_cosine_similarity = 1024
+            else:
+                self.batched_cosine_similarity = 0
+        if not isinstance(self.batched_cosine_similarity, int):
+            raise ValueError("batched_cosine_similarity must be an integer")

nemo_curator/modules/semantic_dedup/semanticclusterleveldedup.py

Lines changed: 6 additions & 1 deletion
@@ -23,7 +23,6 @@

 from nemo_curator.datasets import DocumentDataset
 from nemo_curator.log import create_logger
-from nemo_curator.modules.config import SemDedupConfig
 from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
 from nemo_curator.utils.file_utils import expand_outdir_and_mkdir
 from nemo_curator.utils.semdedup_utils import (
@@ -43,6 +42,7 @@ def __init__(
         which_to_keep: str = "hard",
         output_dir: str = "./clustering_results",
         embedding_column: str = "embeddings",
+        batched_cosine_similarity: int = 1024,
         logger: Union[logging.Logger, str] = "./",
         profile_dir: Optional[str] = None,
     ) -> None:
@@ -64,6 +64,9 @@
                 Default is "./clustering_results".
             embedding_column (str): The column name that stores the embeddings.
                 Default is "embeddings".
+            batched_cosine_similarity (int): Whether to use batched cosine similarity (has less memory usage).
+                Default is 1024. When greater than 0, batching is used and memory requirements are O(N*B) where N is the number of items in the cluster and B is the batch size.
+                When less than or equal to 0, no batching is used and memory requirements are O(N^2) where N is the number of items in the cluster.
             logger (Union[logging.Logger, str]): Existing logger to log to, or a path to a log directory.
                 Default is "./".
             profile_dir (Optional[str]): If specified, directory to write Dask profile.
@@ -82,6 +85,7 @@
         )
         self.computed_semantic_match_dfs = False
         self.embedding_column = embedding_column
+        self.batched_cosine_similarity = batched_cosine_similarity
         self.logger = self._setup_logger(logger)
         self.profile_dir = profile_dir

@@ -144,6 +148,7 @@ def compute_semantic_match_dfs(
                     output_dir=self.semdedup_pruning_tables_dir,
                     embedding_col=self.embedding_column,
                     which_to_keep=self.which_to_keep,
+                    batched_cosine_similarity=self.batched_cosine_similarity,
                 )
             )
         tasks.compute()
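For context, a hedged usage sketch of the new argument. The keyword arguments shown mirror the documentation example earlier in this commit; the clustering-related arguments and directory paths are placeholders, not taken from the diff.

from nemo_curator.modules.semantic_dedup.semanticclusterleveldedup import (
    SemanticClusterLevelDedup,
)

semantic_dedup = SemanticClusterLevelDedup(
    n_clusters=1000,                              # placeholder clustering setup
    emb_by_clust_dir="path/to/embs_by_cluster",   # placeholder path
    sorted_clusters_dir="path/to/sorted_clusters",
    id_column="doc_id",
    id_column_type="str",
    which_to_keep="hard",
    batched_cosine_similarity=1024,  # > 0: O(N*B) batched path; <= 0: full O(N^2) matrix
    output_dir="path/to/output/deduped",
    logger="path/to/log/dir",
)
semantic_dedup.compute_semantic_match_dfs([0.01, 0.001])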

nemo_curator/modules/semantic_dedup/semdedup.py

Lines changed: 1 addition & 0 deletions
@@ -91,6 +91,7 @@ def __init__(
             id_column=id_column,
             id_column_type=id_column_type,
             which_to_keep=config.which_to_keep,
+            batched_cosine_similarity=config.batched_cosine_similarity,
             output_dir=os.path.join(cache_dir, config.clustering_save_loc),
             embedding_column=config.embedding_column,
             logger=logger,

nemo_curator/scripts/semdedup/extract_dedup_data.py

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ def main(args):
        id_column=args.id_column,
        id_column_type=args.id_column_type,
        which_to_keep=semdedup_config.which_to_keep,
+       batched_cosine_similarity=semdedup_config.batched_cosine_similarity,
        output_dir=os.path.join(
            semdedup_config.cache_dir, semdedup_config.clustering_save_loc
        ),

nemo_curator/utils/semdedup_utils.py

Lines changed: 67 additions & 25 deletions
@@ -18,7 +18,7 @@
 import random
 import shutil
 import time
-from typing import List, Optional, Tuple
+from typing import List, Literal, Optional, Tuple

 import cudf
 import dask.bag as db
@@ -179,25 +179,60 @@ def rank_within_cluster(
     return len(cluster_ids) - missing_files


-def _semdedup(
-    cluster_reps: torch.Tensor, device: str
+def pairwise_cosine_similarity(
+    cluster_reps: torch.Tensor,
+    device: Literal["cuda", "cpu"],
 ) -> Tuple[torch.Tensor, List[int]]:
-    # compute pairwise cos sim between cluster items,
-    # then replace to diagonal with zeros to ignore self similarity
-    cluster_reps.to(device)
-    pair_w_sim_matrix = cluster_reps @ (cluster_reps.T)
+    """
+    Compute pairwise cosine similarity between cluster items,
+    then replace to diagonal with zeros to ignore self similarity
+    """
+    # Move to device
+    cluster_reps = cluster_reps.to(device)
+    # Compute pairwise cosine similarity
+    pairwise_sim_matrix = torch.mm(cluster_reps, cluster_reps.T)
     del cluster_reps
-    pair_w_sim_matrix.fill_diagonal_(0.0)
-    assert pair_w_sim_matrix.shape[0] == pair_w_sim_matrix.shape[1]
-
-    triu_sim_mat = torch.triu(pair_w_sim_matrix, diagonal=1)
-
-    M = torch.max(triu_sim_mat, dim=0)[0].cpu()
-    M1 = torch.max(triu_sim_mat, dim=0)[1].cpu().numpy().tolist()
-    return M, M1
-
-
-def get_cluster_reps(
+    # Get upper triangular matrix
+    assert pairwise_sim_matrix.shape[0] == pairwise_sim_matrix.shape[1]
+    triu_sim_mat = torch.triu(pairwise_sim_matrix, diagonal=1)
+    # Get max similarity and indices
+    max_values_and_indices = torch.max(triu_sim_mat, dim=0)
+    max_similarity = max_values_and_indices[0].cpu()
+    max_indices = max_values_and_indices[1].cpu().numpy().tolist()
+    return max_similarity, max_indices
+
+
+def pairwise_cosine_similarity_batched(
+    cluster_reps: torch.Tensor,
+    device: Literal["cuda", "cpu"],
+    batch_size: int = 1024,
+) -> Tuple[torch.Tensor, List[int]]:
+    """
+    Computes pairwise cosine similarity between cluster items,
+    then replace to diagonal with zeros to ignore self similarity.
+    This function is useful for large clusters where the pairwise similarity matrix
+    does not fit into memory.
+    We use a batched approach to compute the pairwise similarity matrix in batches.
+    Memory requirements are O(N*B) where N is the number of items in the cluster and B is the batch size
+    instead of O(N^2) for the full matrix.
+    """
+    cluster_reps = cluster_reps.to(device)
+    max_similarity = torch.zeros(cluster_reps.shape[0], device=device)
+    max_indices = torch.zeros(cluster_reps.shape[0], dtype=torch.int64, device=device)
+    for start_idx in range(0, cluster_reps.shape[0], batch_size):
+        end_idx = min(start_idx + batch_size, cluster_reps.shape[0])
+        batch = cluster_reps[start_idx:end_idx]
+        pairwise_sim_matrix = torch.mm(cluster_reps, batch.T)
+        triu_sim_matrix = torch.triu(pairwise_sim_matrix, diagonal=1 - start_idx)
+        del batch, pairwise_sim_matrix
+        max_values_and_indices = torch.max(triu_sim_matrix, dim=0)
+        max_similarity[start_idx:end_idx] = max_values_and_indices[0]
+        max_indices[start_idx:end_idx] = max_values_and_indices[1]

+    return max_similarity.cpu(), max_indices.cpu().numpy().tolist()
+
+
+def get_normalized_cluster_reps(
     cluster_id: int,
     emb_by_clust_dir: str,
     id_col: str,
@@ -220,6 +255,8 @@ def get_cluster_reps(
         cluster_reps[embedding_col].list.leaves.values.reshape(len(cluster_reps), -1),
         device="cuda",
     )
+    # Normalize embeddings
+    cluster_reps = cluster_reps / cluster_reps.norm(dim=1, keepdim=True)
     return cluster_reps


@@ -233,6 +270,7 @@ def get_semantic_matches_per_cluster(
     output_dir: str,
     embedding_col: str,
     which_to_keep: str,
+    batched_cosine_similarity: int = 1024,
 ) -> None:

     output_df_file_path = os.path.join(output_dir, f"cluster_{cluster_id}.parquet")
@@ -266,22 +304,26 @@

     text_ids = cluster_i[:, 0].astype(id_col_type)

-    cluster_reps = get_cluster_reps(
+    cluster_reps = get_normalized_cluster_reps(
         cluster_id, emb_by_clust_dir, id_col, embedding_col, text_ids
     )
-    M, M1 = _semdedup(cluster_reps, "cuda")
+    if batched_cosine_similarity > 0:
+        max_similarity, max_indices = pairwise_cosine_similarity_batched(
+            cluster_reps, "cuda", batched_cosine_similarity
+        )
+    else:
+        max_similarity, max_indices = pairwise_cosine_similarity(cluster_reps, "cuda")
     assert cluster_reps.shape[0] == len(text_ids)
-
-    M1_id = [text_ids[m] for m in M1]
+    max_indices_id = [text_ids[m] for m in max_indices]

     points_to_remove_df = cudf.DataFrame()
     points_to_remove_df["indices"] = clutser_items_indices
     points_to_remove_df["id"] = text_ids
-    points_to_remove_df["max_id"] = M1_id
-    points_to_remove_df["cosine_sim_score"] = M.numpy().tolist()
+    points_to_remove_df["max_id"] = max_indices_id
+    points_to_remove_df["cosine_sim_score"] = max_similarity.numpy().tolist()

     for eps in eps_list:
-        eps_points_to_remove = M > 1 - eps
+        eps_points_to_remove = max_similarity > 1 - eps
         points_to_remove_df[f"eps={eps}"] = eps_points_to_remove

     points_to_remove_df.to_parquet(output_df_file_path)
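The one non-obvious line in pairwise_cosine_similarity_batched is torch.triu(pairwise_sim_matrix, diagonal=1 - start_idx). For the column block starting at start_idx, keeping entries where (local column) - row >= 1 - start_idx is exactly keeping (global column) > row, i.e. the strict upper triangle restricted to those columns. A small self-contained check of that equivalence (CPU-only, not from the commit):

import torch

torch.manual_seed(0)
N, D, B = 7, 4, 3
reps = torch.randn(N, D)
reps = reps / reps.norm(dim=1, keepdim=True)  # row-normalize, as get_normalized_cluster_reps does

# Full strict upper triangle of the N x N similarity matrix
full_triu = torch.triu(reps @ reps.T, diagonal=1)

# Rebuild it column-block by column-block using the shifted diagonal
blocks = []
for start_idx in range(0, N, B):
    block = reps @ reps[start_idx : start_idx + B].T  # N x (<=B) column block
    blocks.append(torch.triu(block, diagonal=1 - start_idx))

assert torch.allclose(torch.cat(blocks, dim=1), full_triu)

Because each column block is still multiplied against all N rows, torch.max(..., dim=0) sees the same full column as the unbatched path, which is why the tests below can require identical indices from both implementations.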

tests/test_semdedup.py

Lines changed: 68 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from typing import TYPE_CHECKING, Literal

 import numpy as np
 import pytest
@@ -29,6 +30,17 @@
 EmbeddingCreator = gpu_only_import_from(
     "nemo_curator.modules.semantic_dedup.embeddings", "EmbeddingCreator"
 )
+pairwise_cosine_similarity = gpu_only_import_from(
+    "nemo_curator.utils.semdedup_utils", "pairwise_cosine_similarity"
+)
+pairwise_cosine_similarity_batched = gpu_only_import_from(
+    "nemo_curator.utils.semdedup_utils", "pairwise_cosine_similarity_batched"
+)
+if TYPE_CHECKING:
+    from nemo_curator.utils.semdedup_utils import (
+        pairwise_cosine_similarity,
+        pairwise_cosine_similarity_batched,
+    )


 @pytest.fixture
@@ -233,3 +245,59 @@ def get_reference_embeddings(
         embs.append(normed_emb)

     return np.array(embs)
+
+
+class TestPairwiseCosineSimilarity:
+    def setup_method(self):
+        # We create a 5x3 array where each row is a unit vector
+        # The second and last two rows are the same
+        input_arr = torch.tensor(
+            np.asarray(
+                [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [1, 2, 3], [1, 2, 3]],
+            ),
+            dtype=torch.float32,
+        )
+        # Normalize the input array
+        self.input_arr = input_arr / torch.norm(input_arr, dim=1, keepdim=True)
+        self.expected_similarity = torch.tensor(
+            [0.0000, 0.974631, 0.998190, 0.999618, 1.0000, 1.0000]
+        )
+        self.expected_indices = [0, 0, 1, 2, 0, 0]
+
+    @pytest.mark.parametrize("device", [pytest.param("cuda", marks=pytest.mark.gpu)])
+    def test_pairwise_cosine_similarity(self, device: Literal["cpu", "cuda"]):
+        max_similarity, max_indices = pairwise_cosine_similarity(
+            self.input_arr.to(device), device
+        )
+        torch.testing.assert_close(
+            max_similarity, self.expected_similarity, rtol=1e-6, atol=1e-6
+        )
+        assert max_indices == self.expected_indices
+
+    @pytest.mark.parametrize("device", [pytest.param("cuda", marks=pytest.mark.gpu)])
+    @pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6])
+    def test_pairwise_cosine_similarity_batched(
+        self, device: Literal["cpu", "cuda"], batch_size: int
+    ):
+        max_similarity, max_indices = pairwise_cosine_similarity_batched(
+            self.input_arr.to(device), device, batch_size
+        )
+        torch.testing.assert_close(max_similarity, self.expected_similarity)
+        assert max_indices == self.expected_indices
+
+    @pytest.mark.parametrize("device", [pytest.param("cuda", marks=pytest.mark.gpu)])
+    @pytest.mark.parametrize("batch_size", [100, 512, 1024, 2048])
+    def test_pairwise_cosine_similarity_batched_rand_array(
+        self, device: Literal["cpu", "cuda"], batch_size: int
+    ):
+        N = 1024
+        D = 512
+        rand_arr = torch.randn(N, D, device=device)
+        max_similarity, max_indices = pairwise_cosine_similarity(rand_arr, device)
+        max_similarity_batched, max_indices_batched = (
+            pairwise_cosine_similarity_batched(rand_arr, device, batch_size=batch_size)
+        )
+        torch.testing.assert_close(
+            max_similarity, max_similarity_batched, rtol=1e-5, atol=1e-5
+        )
+        assert max_indices == max_indices_batched
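The expected values in setup_method are simply pairwise cosines of the fixture's six raw rows: row 0 has no earlier row (the strict upper triangle leaves it at 0.0 with index 0), rows 4 and 5 duplicate row 0 (similarity 1.0, index 0), and the middle entries come from each row's nearest earlier neighbor. For instance, the 0.974631 entry is the cosine between rows 0 and 1. A quick check, not part of the test file:

import numpy as np

a, b = np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0, 6.0])
cosine = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
print(cosine)  # 0.9746318..., matching expected_similarity[1] within the test's 1e-6 tolerance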

tutorials/dapt-curation/code/configs/text_semantic_dedupe_config.yaml

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ clustering_save_loc: "clustering_results"
 random_state: 1234
 sim_metric: "cosine"
 which_to_keep: "hard"
+batched_cosine_similarity: 1024
 sort_clusters: true
 kmeans_with_cos_dist: false
 clustering_input_partition_size: "2gb"

tutorials/image-curation/image-curation.ipynb

Lines changed: 1 addition & 0 deletions
@@ -681,6 +681,7 @@
     "    id_column_type=\"str\",\n",
     "    embedding_col=\"image_embedding\",\n",
     "    which_to_keep=\"hard\",\n",
+    "    batched_cosine_similarity=1024,\n",
     "    output_dir=duplicate_output,\n",
     ")\n",
     "semantic_dedup.compute_semantic_match_dfs([0.01, 0.001])\n",

0 commit comments