Skip to content

Commit 0803323

Browse files
Fix ParquetReader and *Writer for Cloud I/O (#1249)
1 parent aab1dcc commit 0803323

File tree

2 files changed

+17
-23
lines changed

2 files changed

+17
-23
lines changed

nemo_curator/stages/text/io/reader/parquet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,10 @@ def read_data(
5959
if "dtype_backend" not in read_kwargs:
6060
update_kwargs["dtype_backend"] = "pyarrow"
6161
read_kwargs.update(update_kwargs)
62-
return pd.read_parquet(paths, **read_kwargs)
62+
return pd.concat(
63+
(pd.read_parquet(path, **read_kwargs) for path in paths),
64+
ignore_index=True,
65+
)
6366

6467

6568
@dataclass

nemo_curator/stages/text/io/writer/base.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
from dataclasses import dataclass, field
1818
from typing import Any, Literal
1919

20-
import fsspec
21-
from fsspec.utils import infer_storage_options
20+
from fsspec.core import url_to_fs
2221
from loguru import logger
2322

2423
import nemo_curator.stages.text.io.writer.utils as writer_utils
2524
from nemo_curator.stages.base import ProcessingStage
2625
from nemo_curator.tasks import DocumentBatch, FileGroupTask
26+
from nemo_curator.utils.client_utils import is_remote_url
2727
from nemo_curator.utils.file_utils import check_output_mode
2828

2929

@@ -39,26 +39,14 @@ class BaseWriter(ProcessingStage[DocumentBatch, FileGroupTask], ABC):
3939
file_extension: str
4040
write_kwargs: dict[str, Any] = field(default_factory=dict)
4141
fields: list[str] | None = None
42-
mode: Literal["ignore", "overwrite", "append", "error"] = "ignore"
4342
name: str = "BaseWriter"
44-
_fs_path: str = field(init=False, repr=False, default="")
45-
_protocol: str = field(init=False, repr=False, default="file")
46-
_has_explicit_protocol: bool = field(init=False, repr=False, default=False)
43+
mode: Literal["ignore", "overwrite", "append", "error"] = "ignore"
4744
append_mode_implemented: bool = False
4845

4946
def __post_init__(self):
50-
# Determine protocol and normalized filesystem path
51-
path_opts = infer_storage_options(self.path)
52-
protocol = path_opts.get("protocol", "file")
53-
self._protocol = protocol or "file"
54-
# Track if the user provided an explicit URL-style protocol in the path
55-
self._has_explicit_protocol = "://" in self.path
56-
# Use the filesystem-native path (no protocol) for fs operations
57-
self._fs_path = path_opts.get("path", self.path)
58-
59-
# Only pass user-provided storage options to fsspec
47+
# Use fsspec's url_to_fs to get both filesystem and normalized path
6048
self.storage_options = (self.write_kwargs or {}).get("storage_options", {})
61-
self.fs = fsspec.filesystem(protocol, **self.storage_options)
49+
self.fs, self._fs_path = url_to_fs(self.path, **self.storage_options)
6250
check_output_mode(self.mode, self.fs, self._fs_path, append_mode_implemented=self.append_mode_implemented)
6351

6452
def inputs(self) -> tuple[list[str], list[str]]:
@@ -95,17 +83,20 @@ def process(self, task: DocumentBatch) -> FileGroupTask:
9583
file_extension = self.get_file_extension()
9684
file_path = self.fs.sep.join([self._fs_path, f"{filename}.{file_extension}"])
9785

86+
# For remote URLs, restore the protocol prefix so downstream code can infer the filesystem
87+
file_path_with_protocol = self.fs.unstrip_protocol(file_path) if is_remote_url(self.path) else file_path
88+
9889
if self.fs.exists(file_path):
99-
logger.debug(f"File {file_path} already exists, overwriting it")
90+
logger.debug(f"File {file_path_with_protocol} already exists, overwriting it")
10091

101-
self.write_data(task, file_path)
102-
logger.debug(f"Written {task.num_items} records to {file_path}")
92+
self.write_data(task, file_path_with_protocol)
93+
logger.debug(f"Written {task.num_items} records to {file_path_with_protocol}")
10394

104-
# Create FileGroupTask with written files
95+
# Create FileGroupTask with written files using the full protocol-prefixed path
10596
return FileGroupTask(
10697
task_id=task.task_id,
10798
dataset_name=task.dataset_name,
108-
data=[file_path],
99+
data=[file_path_with_protocol],
109100
_metadata={
110101
**task._metadata,
111102
"format": self.get_file_extension(),

0 commit comments

Comments
 (0)