Skip to content

Commit d377992

Browse files
Fix Pairwise IO and IdentifyDuplicates in SemDedup for Cloud I/O (#1253)
1 parent 0803323 commit d377992

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

nemo_curator/stages/deduplication/semantic/identify_duplicates.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,23 @@ def process_batch(self, tasks: list[FileGroupTask]) -> list[FileGroupTask]:
8888
return []
8989

9090
all_files = [file for task in tasks for file in task.data]
91-
# Read using filters
92-
df: pd.DataFrame = pd.read_parquet(
93-
all_files,
94-
storage_options=self.input_storage_options,
95-
**self.read_kwargs,
96-
filters=[("cosine_sim_score", ">=", 1.0 - self.eps)],
97-
engine="pyarrow",
98-
)[["id"]] # TODO: If we want we can add other columns
91+
# We read using filters
92+
# We read file by file because passing a list of files can fail when the files are remote URLs
93+
# See https://github.com/pandas-dev/pandas/issues/62922
94+
df: pd.DataFrame = pd.concat(
95+
(
96+
pd.read_parquet(
97+
f,
98+
storage_options=self.input_storage_options,
99+
**self.read_kwargs,
100+
filters=[("cosine_sim_score", ">=", 1.0 - self.eps)],
101+
columns=["id"],
102+
engine="pyarrow",
103+
)
104+
for f in all_files
105+
),
106+
ignore_index=True,
107+
)
99108
# Write out sorted and with multiple row groups
100109
df.sort_values("id", inplace=True) # noqa: PD002
101110

nemo_curator/stages/deduplication/semantic/pairwise_io.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from nemo_curator.stages.base import ProcessingStage
2222
from nemo_curator.stages.resources import Resources
2323
from nemo_curator.tasks import FileGroupTask, _EmptyTask
24+
from nemo_curator.utils.client_utils import is_remote_url
2425
from nemo_curator.utils.file_utils import get_all_file_paths_under, get_fs, infer_dataset_name_from_path
2526

2627
if TYPE_CHECKING:
@@ -52,6 +53,7 @@ def __init__(
5253
self.name = "pairwise_file_partitioning"
5354
self.resources = Resources(cpus=0.5)
5455
self.fs: AbstractFileSystem | None = None
56+
self.path_normalizer = lambda x: x
5557

5658
def inputs(self) -> tuple[list[str], list[str]]:
5759
return ["data"], []
@@ -61,6 +63,7 @@ def outputs(self) -> tuple[list[str], list[str]]:
6163

6264
def setup(self, _: WorkerMetadata | None = None) -> None:
6365
self.fs = get_fs(self.input_path, storage_options=self.storage_options)
66+
self.path_normalizer = self.fs.unstrip_protocol if is_remote_url(self.input_path) else (lambda x: x)
6467

6568
def ray_stage_spec(self) -> dict[str, Any]:
6669
"""Ray stage specification for this stage."""
@@ -83,7 +86,7 @@ def process(self, _: _EmptyTask) -> list[FileGroupTask]:
8386
# Extract centroid ID from directory name (e.g., "centroid=0" -> 0)
8487
if "centroid=" in entry:
8588
centroid_id = int(entry.split("centroid=")[-1])
86-
centroid_dirs[centroid_id] = entry
89+
centroid_dirs[centroid_id] = self.path_normalizer(entry)
8790

8891
logger.debug(
8992
f"Found {len(centroid_dirs)} centroid directories e.g. {next(iter(centroid_dirs.values())) if centroid_dirs else None}"

0 commit comments

Comments
 (0)