 import pyarrow
 import ray
 from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
-from ray.data.block import Block, BlockAccessor
+from ray.data.block import Block, BlockAccessor, BlockMetadata
 from ray.data.datasource.datasource import WriteResult
 from ray.data.datasource.file_based_datasource import (
     BlockWritePathProvider,
@@ -64,11 +64,24 @@ def __init__(self) -> None:
     def _read_file(self, f: pyarrow.NativeFile, path: str, **reader_args: Any) -> pd.DataFrame:
         raise NotImplementedError()
 
-    def do_write(
+    def do_write(  # pylint: disable=arguments-differ
         self,
         blocks: List[ObjectRef[pd.DataFrame]],
-        *args: Any,
-        **kwargs: Any,
+        metadata: List[BlockMetadata],
+        path: str,
+        dataset_uuid: str,
+        filesystem: Optional[pyarrow.fs.FileSystem] = None,
+        try_create_dir: bool = True,
+        open_stream_args: Optional[Dict[str, Any]] = None,
+        block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(),
+        write_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
+        _block_udf: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
+        ray_remote_args: Optional[Dict[str, Any]] = None,
+        s3_additional_kwargs: Optional[Dict[str, str]] = None,
+        pandas_kwargs: Optional[Dict[str, Any]] = None,
+        compression: Optional[str] = None,
+        mode: str = "wb",
+        **write_args: Any,
     ) -> List[ObjectRef[WriteResult]]:
         """Create and return write tasks for a file-based datasource.
 
@@ -77,21 +90,53 @@ def do_write(
         plan allowing query optimisation ("fuse" with other operations). The change is not backward-compatible
         with earlier versions still attempting to call do_write().
         """
-        write_tasks = []
-        path: str = kwargs.pop("path")
-        dataset_uuid: str = kwargs.pop("dataset_uuid")
-        ray_remote_args: Dict[str, Any] = kwargs.pop("ray_remote_args") or {}
+        _write_block_to_file = self._write_block
+
+        if ray_remote_args is None:
+            ray_remote_args = {}
+
+        if pandas_kwargs is None:
+            pandas_kwargs = {}
 
-        _write = ray_remote(**ray_remote_args)(self.write)
+        if not compression:
+            compression = pandas_kwargs.get("compression")
+
+        def write_block(write_path: str, block: pd.DataFrame) -> str:
+            if _block_udf is not None:
+                block = _block_udf(block)
+
+            with open_s3_object(
+                path=write_path,
+                mode=mode,
+                use_threads=False,
+                s3_additional_kwargs=s3_additional_kwargs,
+                encoding=write_args.get("encoding"),
+                newline=write_args.get("newline"),
+            ) as f:
+                _write_block_to_file(
+                    f,
+                    BlockAccessor.for_block(block),
+                    pandas_kwargs=pandas_kwargs,
+                    compression=compression,
+                    **write_args,
+                )
+            return write_path
+
+        write_block_fn = ray_remote(**ray_remote_args)(write_block)
+
+        file_suffix = self._get_file_suffix(self._FILE_EXTENSION, compression)
+        write_tasks = []
 
         for block_idx, block in enumerate(blocks):
-            write_task = _write(
-                [block],
-                TaskContext(task_idx=block_idx),
+            write_path = block_path_provider(
                 path,
-                dataset_uuid,
-                **kwargs,
+                filesystem=filesystem,
+                dataset_uuid=dataset_uuid,
+                block=block,
+                block_index=block_idx,
+                file_format=file_suffix,
             )
+            write_task = write_block_fn(write_path, block)
             write_tasks.append(write_task)
 
         return write_tasks
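
For context, the fan-out pattern this change introduces (wrap a plain per-block write function with ray_remote(**ray_remote_args), then launch one task per block and collect the resulting ObjectRefs) can be reproduced with vanilla Ray. The sketch below is illustrative only: the local to_csv call stands in for _write_block_to_file, and the /tmp paths stand in for the paths produced by block_path_provider. Note one difference: plain ray.remote requires an explicit .remote() call, whereas the commit's ray_remote helper evidently returns a directly callable wrapper.

from typing import Any, Dict, List

import pandas as pd
import ray

ray.init(ignore_reinit_error=True)


def write_block(write_path: str, block: pd.DataFrame) -> str:
    # Stand-in for _write_block_to_file: serialize one block to its path.
    block.to_csv(write_path, index=False)
    return write_path


# Configure the remote task the same way the commit does, via ray_remote_args.
ray_remote_args: Dict[str, Any] = {"num_cpus": 1}
write_block_fn = ray.remote(**ray_remote_args)(write_block)

blocks: List[pd.DataFrame] = [
    pd.DataFrame({"a": [1, 2]}),
    pd.DataFrame({"a": [3, 4]}),
]

# One write task per block, mirroring the enumerate(blocks) loop above.
write_tasks = [
    write_block_fn.remote(f"/tmp/block_{idx}.csv", block)
    for idx, block in enumerate(blocks)
]
print(ray.get(write_tasks))  # ['/tmp/block_0.csv', '/tmp/block_1.csv']

Launching the tasks eagerly and returning the ObjectRefs, as do_write() does, lets the caller decide when to block on the writes with ray.get().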