1717from dataclasses import dataclass , field
1818from typing import Any , Literal
1919
20- import fsspec
21- from fsspec .utils import infer_storage_options
20+ from fsspec .core import url_to_fs
2221from loguru import logger
2322
2423import nemo_curator .stages .text .io .writer .utils as writer_utils
2524from nemo_curator .stages .base import ProcessingStage
2625from nemo_curator .tasks import DocumentBatch , FileGroupTask
26+ from nemo_curator .utils .client_utils import is_remote_url
2727from nemo_curator .utils .file_utils import check_output_mode
2828
2929
@@ -39,26 +39,14 @@ class BaseWriter(ProcessingStage[DocumentBatch, FileGroupTask], ABC):
3939 file_extension : str
4040 write_kwargs : dict [str , Any ] = field (default_factory = dict )
4141 fields : list [str ] | None = None
42- mode : Literal ["ignore" , "overwrite" , "append" , "error" ] = "ignore"
4342 name : str = "BaseWriter"
44- _fs_path : str = field (init = False , repr = False , default = "" )
45- _protocol : str = field (init = False , repr = False , default = "file" )
46- _has_explicit_protocol : bool = field (init = False , repr = False , default = False )
43+ mode : Literal ["ignore" , "overwrite" , "append" , "error" ] = "ignore"
4744 append_mode_implemented : bool = False
4845
4946 def __post_init__ (self ):
50- # Determine protocol and normalized filesystem path
51- path_opts = infer_storage_options (self .path )
52- protocol = path_opts .get ("protocol" , "file" )
53- self ._protocol = protocol or "file"
54- # Track if the user provided an explicit URL-style protocol in the path
55- self ._has_explicit_protocol = "://" in self .path
56- # Use the filesystem-native path (no protocol) for fs operations
57- self ._fs_path = path_opts .get ("path" , self .path )
58-
59- # Only pass user-provided storage options to fsspec
47+ # Use fsspec's url_to_fs to get both filesystem and normalized path
6048 self .storage_options = (self .write_kwargs or {}).get ("storage_options" , {})
61- self .fs = fsspec . filesystem ( protocol , ** self .storage_options )
49+ self .fs , self . _fs_path = url_to_fs ( self . path , ** self .storage_options )
6250 check_output_mode (self .mode , self .fs , self ._fs_path , append_mode_implemented = self .append_mode_implemented )
6351
6452 def inputs (self ) -> tuple [list [str ], list [str ]]:
@@ -95,17 +83,20 @@ def process(self, task: DocumentBatch) -> FileGroupTask:
9583 file_extension = self .get_file_extension ()
9684 file_path = self .fs .sep .join ([self ._fs_path , f"{ filename } .{ file_extension } " ])
9785
86+ # For remote URLs, restore the protocol prefix so downstream code can infer the filesystem
87+ file_path_with_protocol = self .fs .unstrip_protocol (file_path ) if is_remote_url (self .path ) else file_path
88+
9889 if self .fs .exists (file_path ):
99- logger .debug (f"File { file_path } already exists, overwriting it" )
90+ logger .debug (f"File { file_path_with_protocol } already exists, overwriting it" )
10091
101- self .write_data (task , file_path )
102- logger .debug (f"Written { task .num_items } records to { file_path } " )
92+ self .write_data (task , file_path_with_protocol )
93+ logger .debug (f"Written { task .num_items } records to { file_path_with_protocol } " )
10394
104- # Create FileGroupTask with written files
95+ # Create FileGroupTask with written files using the full protocol-prefixed path
10596 return FileGroupTask (
10697 task_id = task .task_id ,
10798 dataset_name = task .dataset_name ,
108- data = [file_path ],
99+ data = [file_path_with_protocol ],
109100 _metadata = {
110101 ** task ._metadata ,
111102 "format" : self .get_file_extension (),
0 commit comments