2121from nemo_curator .stages .base import ProcessingStage
2222from nemo_curator .stages .resources import Resources
2323from nemo_curator .tasks import FileGroupTask , _EmptyTask
24+ from nemo_curator .utils .client_utils import is_remote_url
2425from nemo_curator .utils .file_utils import get_all_file_paths_under , get_fs , infer_dataset_name_from_path
2526
2627if TYPE_CHECKING :
@@ -52,6 +53,7 @@ def __init__(
5253 self .name = "pairwise_file_partitioning"
5354 self .resources = Resources (cpus = 0.5 )
5455 self .fs : AbstractFileSystem | None = None
56+ self .path_normalizer = lambda x : x
5557
5658 def inputs (self ) -> tuple [list [str ], list [str ]]:
5759 return ["data" ], []
@@ -61,6 +63,7 @@ def outputs(self) -> tuple[list[str], list[str]]:
6163
6264 def setup (self , _ : WorkerMetadata | None = None ) -> None :
6365 self .fs = get_fs (self .input_path , storage_options = self .storage_options )
66+ self .path_normalizer = self .fs .unstrip_protocol if is_remote_url (self .input_path ) else (lambda x : x )
6467
6568 def ray_stage_spec (self ) -> dict [str , Any ]:
6669 """Ray stage specification for this stage."""
@@ -83,7 +86,7 @@ def process(self, _: _EmptyTask) -> list[FileGroupTask]:
8386 # Extract centroid ID from directory name (e.g., "centroid=0" -> 0)
8487 if "centroid=" in entry :
8588 centroid_id = int (entry .split ("centroid=" )[- 1 ])
86- centroid_dirs [centroid_id ] = entry
89+ centroid_dirs [centroid_id ] = self . path_normalizer ( entry )
8790
8891 logger .debug (
8992 f"Found { len (centroid_dirs )} centroid directories e.g. { next (iter (centroid_dirs .values ())) if centroid_dirs else None } "
0 commit comments