|
1 | 1 | """Distributed ParquetDatasource Module.""" |
2 | 2 |
|
3 | 3 | import logging |
4 | | -from typing import Any, Callable, Iterator, List, Optional, Union |
| 4 | +from typing import Any, Callable, Dict, Iterator, List, Optional, Union |
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 | import pyarrow as pa |
8 | 8 |
|
9 | 9 | # fs required to implicitly trigger S3 subsystem initialization |
10 | 10 | import pyarrow.fs # noqa: F401 pylint: disable=unused-import |
11 | 11 | import pyarrow.parquet as pq |
12 | | -from ray import cloudpickle |
| 12 | +from ray import cloudpickle # pylint: disable=wrong-import-order,ungrouped-imports |
| 13 | +from ray.data.block import Block, BlockAccessor, BlockMetadata |
13 | 14 | from ray.data.context import DatasetContext |
14 | | -from ray.data.datasource.datasource import ReadTask |
15 | | -from ray.data.datasource.file_based_datasource import _resolve_paths_and_filesystem |
| 15 | +from ray.data.datasource import BlockWritePathProvider, DefaultBlockWritePathProvider |
| 16 | +from ray.data.datasource.datasource import ReadTask, WriteResult |
| 17 | +from ray.data.datasource.file_based_datasource import ( |
| 18 | + _resolve_paths_and_filesystem, |
| 19 | + _S3FileSystemWrapper, |
| 20 | + _wrap_s3_serialization_workaround, |
| 21 | +) |
16 | 22 | from ray.data.datasource.file_meta_provider import DefaultParquetMetadataProvider, ParquetMetadataProvider |
17 | 23 | from ray.data.datasource.parquet_datasource import ( |
18 | 24 | _deregister_parquet_file_fragment_serialization, |
19 | 25 | _register_parquet_file_fragment_serialization, |
20 | 26 | ) |
21 | 27 | from ray.data.impl.output_buffer import BlockOutputBuffer |
| 28 | +from ray.data.impl.remote_fn import cached_remote_fn |
| 29 | +from ray.types import ObjectRef |
22 | 30 |
|
23 | 31 | from awswrangler._arrow import _add_table_partitions |
24 | 32 |
|
|
29 | 37 | PARQUET_READER_ROW_BATCH_SIZE = 100000 |
30 | 38 |
|
31 | 39 |
|
| 40 | +class UserProvidedKeyBlockWritePathProvider(BlockWritePathProvider): |
| 41 | + """Block write path provider. |
| 42 | +
|
| 43 | + Used when writing single-block datasets into a user-provided S3 key. |
| 44 | + """ |
| 45 | + |
| 46 | + def _get_write_path_for_block( |
| 47 | + self, |
| 48 | + base_path: str, |
| 49 | + *, |
| 50 | + filesystem: Optional["pyarrow.fs.FileSystem"] = None, |
| 51 | + dataset_uuid: Optional[str] = None, |
| 52 | + block: Optional[ObjectRef[Block[Any]]] = None, |
| 53 | + block_index: Optional[int] = None, |
| 54 | + file_format: Optional[str] = None, |
| 55 | + ) -> str: |
| 56 | + return base_path |
| 57 | + |
| 58 | + |
32 | 59 | class ParquetDatasource: |
33 | 60 | """Parquet datasource, for reading and writing Parquet files.""" |
34 | 61 |
|
| 62 | + def __init__(self) -> None: |
| 63 | + self._write_paths: List[str] = [] |
| 64 | + |
35 | 65 | # Original: https://github.com/ray-project/ray/blob/releases/1.13.0/python/ray/data/datasource/parquet_datasource.py |
36 | 66 | def prepare_read( |
37 | 67 | self, |
@@ -135,3 +165,110 @@ def read_pieces(serialized_pieces: str) -> Iterator[pa.Table]: |
135 | 165 | _deregister_parquet_file_fragment_serialization() # type: ignore |
136 | 166 |
|
137 | 167 | return read_tasks |
| 168 | + |
| 169 | + # Original implementation: |
| 170 | + # https://github.com/ray-project/ray/blob/releases/1.13.0/python/ray/data/datasource/file_based_datasource.py |
| 171 | + def do_write( |
| 172 | + self, |
| 173 | + blocks: List[ObjectRef[Block[Any]]], |
| 174 | + _: List[BlockMetadata], |
| 175 | + path: str, |
| 176 | + dataset_uuid: str, |
| 177 | + filesystem: Optional["pyarrow.fs.FileSystem"] = None, |
| 178 | + try_create_dir: bool = True, |
| 179 | + open_stream_args: Optional[Dict[str, Any]] = None, |
| 180 | + block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(), |
| 181 | + write_args_fn: Callable[[], Dict[str, Any]] = lambda: {}, |
| 182 | + _block_udf: Optional[Callable[[Block[Any]], Block[Any]]] = None, |
| 183 | + ray_remote_args: Optional[Dict[str, Any]] = None, |
| 184 | + **write_args: Any, |
| 185 | + ) -> List[ObjectRef[WriteResult]]: |
| 186 | + """Create write tasks for a parquet file datasource.""" |
| 187 | + paths, filesystem = _resolve_paths_and_filesystem(path, filesystem) |
| 188 | + path = paths[0] |
| 189 | + if try_create_dir: |
| 190 | + filesystem.create_dir(path, recursive=True) |
| 191 | + filesystem = _wrap_s3_serialization_workaround(filesystem) |
| 192 | + |
| 193 | + _write_block_to_file = self._write_block |
| 194 | + |
| 195 | + if open_stream_args is None: |
| 196 | + open_stream_args = {} |
| 197 | + |
| 198 | + if ray_remote_args is None: |
| 199 | + ray_remote_args = {} |
| 200 | + |
| 201 | + def write_block(write_path: str, block: Block[Any]) -> str: |
| 202 | + _logger.debug("Writing %s file.", write_path) |
| 203 | + fs: Optional["pyarrow.fs.FileSystem"] = filesystem |
| 204 | + if isinstance(fs, _S3FileSystemWrapper): |
| 205 | + fs = fs.unwrap() # type: ignore |
| 206 | + if _block_udf is not None: |
| 207 | + block = _block_udf(block) |
| 208 | + |
| 209 | + with fs.open_output_stream(write_path, **open_stream_args) as f: |
| 210 | + _write_block_to_file( |
| 211 | + f, |
| 212 | + BlockAccessor.for_block(block), |
| 213 | + writer_args_fn=write_args_fn, |
| 214 | + **write_args, |
| 215 | + ) |
| 216 | +            # Unlike the original FileBasedDatasource.do_write, return the write path so it can be collected later
| 217 | + return write_path |
| 218 | + |
| 219 | + write_block = cached_remote_fn(write_block).options(**ray_remote_args) |
| 220 | + |
| 221 | + file_format = self._file_format() |
| 222 | + write_tasks = [] |
| 223 | + for block_idx, block in enumerate(blocks): |
| 224 | + write_path = block_path_provider( |
| 225 | + path, |
| 226 | + filesystem=filesystem, |
| 227 | + dataset_uuid=dataset_uuid, |
| 228 | + block=block, |
| 229 | + block_index=block_idx, |
| 230 | + file_format=file_format, |
| 231 | + ) |
| 232 | + write_task = write_block.remote(write_path, block) # type: ignore |
| 233 | + write_tasks.append(write_task) |
| 234 | + |
| 235 | + return write_tasks |
| 236 | + |
| 237 | + def on_write_complete(self, write_results: List[Any], **_: Any) -> None: |
| 238 | + """Execute callback on write complete.""" |
| 239 | + _logger.debug("Write complete %s.", write_results) |
| 240 | +        # Collect the paths returned by the write tasks
| 241 | + self._write_paths.extend(write_results) |
| 242 | + |
| 243 | + def on_write_failed(self, write_results: List[ObjectRef[Any]], error: Exception, **_: Any) -> None: |
| 244 | + """Execute callback on write failed.""" |
| 245 | + _logger.debug("Write failed %s.", write_results) |
| 246 | + raise error |
| 247 | + |
| 248 | + def get_write_paths(self) -> List[str]: |
| 249 | +        """Return the S3 paths where the results have been written."""
| 250 | + return self._write_paths |
| 251 | + |
| 252 | + def _write_block( |
| 253 | + self, |
| 254 | + f: "pyarrow.NativeFile", |
| 255 | + block: BlockAccessor[Any], |
| 256 | + writer_args_fn: Callable[[], Dict[str, Any]] = lambda: {}, |
| 257 | + **writer_args: Any, |
| 258 | + ) -> None: |
| 259 | + """Write a block to S3.""" |
| 260 | + import pyarrow.parquet as pq # pylint: disable=import-outside-toplevel,redefined-outer-name,reimported |
| 261 | + |
| 262 | + writer_args = _resolve_kwargs(writer_args_fn, **writer_args) |
| 263 | + pq.write_table(block.to_arrow(), f, **writer_args) |
| 264 | + |
| 265 | + def _file_format(self) -> str: |
| 266 | + """Return file format.""" |
| 267 | + return "parquet" |
| 268 | + |
| 269 | + |
| 270 | +def _resolve_kwargs(kwargs_fn: Callable[[], Dict[str, Any]], **kwargs: Any) -> Dict[str, Any]: |
| 271 | + if kwargs_fn: |
| 272 | + kwarg_overrides = kwargs_fn() |
| 273 | + kwargs.update(kwarg_overrides) |
| 274 | + return kwargs |
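
For reference, here is a minimal sketch, not part of the commit above, of how the two write path providers decide where a block lands. It assumes Ray 1.13 is installed and that UserProvidedKeyBlockWritePathProvider (defined in this diff) is in scope; the keyword arguments mirror the block_path_provider(...) call inside do_write, and the bucket/key names are purely illustrative.

    from ray.data.datasource import DefaultBlockWritePathProvider

    # The user-key provider ignores dataset_uuid/block_index and returns the
    # base path unchanged, so a single-block dataset is written exactly to the
    # S3 key the caller supplied.
    single_key_provider = UserProvidedKeyBlockWritePathProvider()
    print(
        single_key_provider(
            "s3://bucket/prefix/my_file.snappy.parquet",
            dataset_uuid="d9c7b8a0",
            block_index=0,
            file_format="parquet",
        )
    )
    # -> s3://bucket/prefix/my_file.snappy.parquet

    # DefaultBlockWritePathProvider, the default in do_write, instead derives a
    # per-block file name under the base path (one output file per block).
    default_provider = DefaultBlockWritePathProvider()
    print(
        default_provider(
            "s3://bucket/prefix",
            dataset_uuid="d9c7b8a0",
            block_index=0,
            file_format="parquet",
        )
    )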
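
A second sketch, also not part of the commit, shows what each remote write task does in write_block/_write_block: open an output stream on the resolved filesystem and hand it to pq.write_table. It uses pyarrow.fs.LocalFileSystem and a temporary directory instead of S3 so it runs without Ray or AWS credentials; the table contents, file name, and compression choice are illustrative.

    import os
    import tempfile

    import pyarrow as pa
    import pyarrow.fs
    import pyarrow.parquet as pq

    # One Arrow table standing in for a single Ray block.
    table = pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]})

    base_path = tempfile.mkdtemp()
    fs = pyarrow.fs.LocalFileSystem()

    # A per-block path, analogous to what block_path_provider produces in do_write.
    write_path = os.path.join(base_path, "example-uuid_000000.parquet")

    # Mirror _write_block: open an output stream and write the block's Arrow
    # table with pq.write_table (writer args such as compression are the kind
    # of options forwarded through write_args_fn / **write_args).
    with fs.open_output_stream(write_path) as f:
        pq.write_table(table, f, compression="snappy")

    print(pq.read_table(write_path).num_rows)  # 3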