Skip to content

Commit 8f2e26d

Browse files
Remove conditions for Minhash #679
1 parent c74b436 commit 8f2e26d

File tree

3 files changed

+12
-60
lines changed

3 files changed

+12
-60
lines changed

nemo_curator/_compat.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,11 @@
3939
except (ImportError, TypeError):
4040
CURRENT_CUDF_VERSION = parse_version("24.10.0")
4141

42-
# TODO: remove this once 25.02 becomes the base version of cudf in nemo-curator
43-
44-
# minhash in < 24.12 used to have a minhash(txt) api which was deprecated in favor of
45-
# minhash(a, b) in 25.02 (in 24.12, minhash_permuted(a,b) was introduced)
46-
MINHASH_DEPRECATED_API = CURRENT_CUDF_VERSION.base_version < parse_version("24.12").base_version
47-
MINHASH_PERMUTED_AVAILABLE = (CURRENT_CUDF_VERSION.major == 24) & (CURRENT_CUDF_VERSION.minor == 12) # noqa: PLR2004
48-
4942
# TODO: remove when dask min version gets bumped
5043
DASK_SHUFFLE_METHOD_ARG = _dask_version > parse_version("2024.1.0")
5144
DASK_P2P_ERROR = _dask_version < parse_version("2023.10.0")
5245
DASK_SHUFFLE_CAST_DTYPE = _dask_version > parse_version("2023.12.0")
53-
DASK_CUDF_PARQUET_READ_INCONSISTENT_SCHEMA = _dask_version > parse_version("2024.12")
46+
DASK_CUDF_PARQUET_READ_INCONSISTENT_SCHEMA = _dask_version > parse_version("2025.2.0")
5447

5548
# Query-planning check (and cache)
5649
_DASK_QUERY_PLANNING_ENABLED = None

nemo_curator/modules/fuzzy_dedup/minhash.py

Lines changed: 11 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import dask_cudf
2828
import numpy as np
2929

30-
from nemo_curator._compat import MINHASH_DEPRECATED_API, MINHASH_PERMUTED_AVAILABLE
3130
from nemo_curator.datasets import DocumentDataset
3231
from nemo_curator.log import create_logger
3332
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
@@ -71,14 +70,11 @@ def __init__( # noqa: PLR0913
7170
self.num_hashes = num_hashes
7271
self.char_ngram = char_ngrams
7372

74-
if MINHASH_DEPRECATED_API:
75-
self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed)
76-
else:
77-
self.seeds = self.generate_hash_permutation_seeds(
78-
bit_width=64 if use_64bit_hash else 32,
79-
n_permutations=self.num_hashes,
80-
seed=seed,
81-
)
73+
self.seeds = self.generate_hash_permutation_seeds(
74+
bit_width=64 if use_64bit_hash else 32,
75+
n_permutations=self.num_hashes,
76+
seed=seed,
77+
)
8278

8379
self.minhash_method = self.minhash64 if use_64bit_hash else self.minhash32
8480
self.id_field = id_field
@@ -98,13 +94,6 @@ def __init__( # noqa: PLR0913
9894
else:
9995
self._logger = logger
10096

101-
def generate_seeds(self, n_seeds: int = 260, seed: int = 0) -> np.ndarray:
102-
"""
103-
Generate seeds for all minhash permutations based on the given seed.
104-
"""
105-
gen = np.random.RandomState(seed)
106-
return gen.randint(0, 1e6, size=n_seeds)
107-
10897
def generate_hash_permutation_seeds(self, bit_width: int, n_permutations: int = 260, seed: int = 0) -> np.ndarray:
10998
"""
11099
Generate seeds for all minhash permutations based on the given seed.
@@ -141,24 +130,10 @@ def minhash32(self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int) -> cud
141130
msg = "Expected data of type cudf.Series"
142131
raise TypeError(msg)
143132

144-
if MINHASH_DEPRECATED_API:
145-
warnings.warn(
146-
"Using an outdated minhash implementation, please update to cuDF version 24.12 "
147-
"or later for improved performance. "
148-
"Install the latest version of cuDF using `pip install curator[cuda12x_nightly]`",
149-
category=FutureWarning,
150-
stacklevel=2,
151-
)
152-
seeds = cudf.Series(seeds, dtype="uint32")
153-
return ser.str.minhash(seeds=seeds, width=char_ngram)
154-
else:
155-
seeds_a = cudf.Series(seeds[:, 0], dtype="uint32")
156-
seeds_b = cudf.Series(seeds[:, 1], dtype="uint32")
133+
seeds_a = cudf.Series(seeds[:, 0], dtype="uint32")
134+
seeds_b = cudf.Series(seeds[:, 1], dtype="uint32")
157135

158-
if MINHASH_PERMUTED_AVAILABLE:
159-
return ser.str.minhash_permuted(a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram)
160-
else:
161-
return ser.str.minhash(a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram)
136+
return ser.str.minhash(a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram)
162137

163138
def minhash64(self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int) -> cudf.Series:
164139
"""
@@ -167,24 +142,9 @@ def minhash64(self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int) -> cud
167142
if not isinstance(ser, cudf.Series):
168143
msg = "Expected data of type cudf.Series"
169144
raise TypeError(msg)
170-
if MINHASH_DEPRECATED_API:
171-
warnings.warn(
172-
"Using an outdated minhash implementation, please update to cuDF version 24.12 "
173-
"or later for improved performance. "
174-
"Install the latest version of cuDF using `pip install curator[cuda12x_nightly]`",
175-
category=FutureWarning,
176-
stacklevel=2,
177-
)
178-
seeds = cudf.Series(seeds, dtype="uint64")
179-
return ser.str.minhash64(seeds=seeds, width=char_ngram)
180-
else:
181-
seeds_a = cudf.Series(seeds[:, 0], dtype="uint64")
182-
seeds_b = cudf.Series(seeds[:, 1], dtype="uint64")
183-
184-
if MINHASH_PERMUTED_AVAILABLE:
185-
return ser.str.minhash64_permuted(a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram)
186-
else:
187-
return ser.str.minhash64(a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram)
145+
seeds_a = cudf.Series(seeds[:, 0], dtype="uint64")
146+
seeds_b = cudf.Series(seeds[:, 1], dtype="uint64")
147+
return ser.str.minhash64(a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram)
188148

189149
def __call__(self, dataset: DocumentDataset) -> str | DocumentDataset:
190150
"""

tests/test_read_data.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,6 @@ def test_read_data_different_columns_files_per_partition(
567567
assert len(df) == NUM_FILES * NUM_RECORDS
568568

569569

570-
@pytest.mark.skip(reason="Parquet tests are failing after upgrading to RAPIDS 25.02")
571570
@pytest.mark.parametrize(
572571
("backend", "file_type"),
573572
[

0 commit comments

Comments
 (0)