| 1 | +"""Ray PandasFileBasedDatasource Module.""" |
| 2 | +import logging |
| 3 | +from typing import Any, Callable, Dict, List, Optional |
| 4 | + |
| 5 | +import pandas as pd |
| 6 | +import pyarrow |
| 7 | +from ray.data._internal.pandas_block import PandasBlockAccessor |
| 8 | +from ray.data._internal.remote_fn import cached_remote_fn |
| 9 | +from ray.data.block import Block, BlockAccessor, BlockMetadata |
| 10 | +from ray.data.datasource.datasource import WriteResult |
| 11 | +from ray.data.datasource.file_based_datasource import ( |
| 12 | + BlockWritePathProvider, |
| 13 | + DefaultBlockWritePathProvider, |
| 14 | + FileBasedDatasource, |
| 15 | +) |
| 16 | +from ray.types import ObjectRef |
| 17 | + |
| 18 | +from awswrangler.s3._fs import open_s3_object |
| 19 | +from awswrangler.s3._write import _COMPRESSION_2_EXT |
| 20 | + |
| 21 | +_logger: logging.Logger = logging.getLogger(__name__) |
| 22 | + |
| 23 | + |
| 24 | +class UserProvidedKeyBlockWritePathProvider(BlockWritePathProvider): |
| 25 | + """Block write path provider. |
| 26 | +
|
| 27 | + Used when writing single-block datasets into a user-provided S3 key. |
| 28 | + """ |
| 29 | + |
| 30 | + def _get_write_path_for_block( |
| 31 | + self, |
| 32 | + base_path: str, |
| 33 | + *, |
| 34 | + filesystem: Optional["pyarrow.fs.FileSystem"] = None, |
| 35 | + dataset_uuid: Optional[str] = None, |
| 36 | + block: Optional[ObjectRef[Block[Any]]] = None, |
| 37 | + block_index: Optional[int] = None, |
| 38 | + file_format: Optional[str] = None, |
| 39 | + ) -> str: |
| 40 | + return base_path |
| 41 | + |
| 42 | + |
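# Illustrative usage sketch (an added example, not part of the original module): because
# _get_write_path_for_block() echoes ``base_path`` unchanged, passing this provider as the
# ``block_path_provider`` write argument makes a single-block write land at the exact
# user-supplied key, instead of a name generated under that prefix as
# DefaultBlockWritePathProvider would do. The bucket and key below are hypothetical.
#
#     provider = UserProvidedKeyBlockWritePathProvider()
#     provider(
#         "s3://my-bucket/exact-key.csv",
#         dataset_uuid="uuid",
#         block_index=0,
#         file_format="csv",
#     )
#     # -> "s3://my-bucket/exact-key.csv"
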
class PandasFileBasedDatasource(FileBasedDatasource):  # pylint: disable=abstract-method
    """Pandas file-based datasource, for reading and writing Pandas blocks."""

    _FILE_EXTENSION: Optional[str] = None

    def __init__(self) -> None:
        super().__init__()

        self._write_paths: List[str] = []

    def _read_file(self, f: pyarrow.NativeFile, path: str, **reader_args: Any) -> pd.DataFrame:
        raise NotImplementedError()

    def do_write(  # type: ignore  # pylint: disable=arguments-differ
        self,
        blocks: List[ObjectRef[pd.DataFrame]],
        metadata: List[BlockMetadata],
        path: str,
        dataset_uuid: str,
        filesystem: Optional[pyarrow.fs.FileSystem] = None,
        try_create_dir: bool = True,
        open_stream_args: Optional[Dict[str, Any]] = None,
        block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(),
        write_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
        _block_udf: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
        ray_remote_args: Optional[Dict[str, Any]] = None,
        s3_additional_kwargs: Optional[Dict[str, str]] = None,
        pandas_kwargs: Optional[Dict[str, Any]] = None,
        compression: Optional[str] = None,
        mode: str = "wb",
        **write_args: Any,
    ) -> List[ObjectRef[WriteResult]]:
        """Create and return write tasks for a file-based datasource."""
        _write_block_to_file = self._write_block

        if ray_remote_args is None:
            ray_remote_args = {}

        if pandas_kwargs is None:
            pandas_kwargs = {}

        # Fall back to any compression requested through pandas_kwargs
        if not compression:
            compression = pandas_kwargs.get("compression")

        def write_block(write_path: str, block: pd.DataFrame) -> str:
            _logger.debug("Writing %s file.", write_path)

            if _block_udf is not None:
                block = _block_udf(block)

            with open_s3_object(
                path=write_path,
                mode=mode,
                use_threads=False,
                s3_additional_kwargs=s3_additional_kwargs,
                encoding=write_args.get("encoding"),
                newline=write_args.get("newline"),
            ) as f:
                _write_block_to_file(
                    f,
                    PandasBlockAccessor(block),
                    pandas_kwargs=pandas_kwargs,
                    compression=compression,
                    **write_args,
                )
            return write_path

        # Wrap the writer as a cached Ray remote function with the requested resources
        write_block_fn = cached_remote_fn(write_block).options(**ray_remote_args)

        file_format = self._FILE_EXTENSION
        write_tasks = []

        # Launch one remote write task per block
        for block_idx, block in enumerate(blocks):
            write_path = block_path_provider(
                path,
                filesystem=filesystem,
                dataset_uuid=dataset_uuid,
                block=block,
                block_index=block_idx,
                file_format=f"{file_format}{_COMPRESSION_2_EXT.get(compression)}",
            )
            write_task = write_block_fn.remote(write_path, block)
            write_tasks.append(write_task)

        return write_tasks

    def _write_block(
        self,
        f: "pyarrow.NativeFile",
        block: BlockAccessor[Any],
        writer_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
        **writer_args: Any,
    ) -> None:
        raise NotImplementedError("Subclasses of PandasFileBasedDatasource must implement _write_block().")

    def on_write_complete(self, write_results: List[Any], **_: Any) -> None:
        """Execute callback on write complete."""
        _logger.debug("Write complete %s.", write_results)

        # Collect the paths reported back by the write tasks
        self._write_paths.extend(write_results)

    def on_write_failed(self, write_results: List[ObjectRef[Any]], error: Exception, **_: Any) -> None:
        """Execute callback on write failed."""
        _logger.debug("Write failed %s.", write_results)
        raise error

    def get_write_paths(self) -> List[str]:
        """Return the S3 paths where the results have been written."""
        return self._write_paths
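
# --- Illustrative sketch (an added example, not part of the original module) -----------
# A minimal hypothetical subclass showing how the two abstract hooks above are meant to
# be filled in. The class name, the CSV format choice and the pandas calls are assumptions
# for illustration only; this is not awswrangler's actual CSV datasource.
class _HypotheticalCSVDatasource(PandasFileBasedDatasource):
    _FILE_EXTENSION = "csv"

    def _read_file(self, f: pyarrow.NativeFile, path: str, **reader_args: Any) -> pd.DataFrame:
        # pyarrow.NativeFile is file-like, so pandas can read from it directly.
        return pd.read_csv(f, **reader_args)

    def _write_block(  # pylint: disable=arguments-differ
        self,
        f: "pyarrow.NativeFile",
        block: BlockAccessor[Any],
        writer_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
        pandas_kwargs: Optional[Dict[str, Any]] = None,
        compression: Optional[str] = None,  # accepted to mirror do_write()'s call; unused in this sketch
        **writer_args: Any,
    ) -> None:
        # Convert the Ray block back to a Pandas DataFrame and serialize it into the
        # already-open S3 file handle supplied by do_write() (binary handles require
        # pandas >= 1.2 for to_csv).
        block.to_pandas().to_csv(f, **(pandas_kwargs or {}))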