Commit 01574b8

feat: Sync ray 2.4 parquet datasource (#2300)
* Sync parquet datasource with 2.4 ray version
* [skip ci] Formatting
* Minor fixes & pull s3fs with Modin
* Move read args outside of the for loop
* [skip-ci] Remove s3fs from runtime dependencies
* Revert to poetry.lock from main
1 parent d3650ff commit 01574b8
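
For context, this datasource backs awswrangler's distributed Parquet reads. A minimal sketch of how it is typically exercised (assuming awswrangler 3.x is installed with its ray and modin extras; the S3 path is a placeholder, not from this commit):

import ray
import awswrangler as wr

ray.init()  # with Ray initialized, awswrangler routes S3 reads through its Ray datasources

# Placeholder path; in distributed mode this read would go through the Arrow Parquet datasource changed below.
df = wr.s3.read_parquet(path="s3://my-bucket/my-dataset/", dataset=True)
print(df.shape)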

File tree: 1 file changed, +63 -35 lines


awswrangler/distributed/ray/datasources/arrow_parquet_datasource.py

Lines changed: 63 additions & 35 deletions
@@ -8,18 +8,17 @@
 # pylint: disable=redefined-outer-name,import-outside-toplevel,reimported

 import logging
-import time
 from typing import Any, Callable, Dict, Iterator, List, Optional, Union

 import numpy as np

 # fs required to implicitly trigger S3 subsystem initialization
 import pyarrow.fs  # noqa: F401 pylint: disable=unused-import
-import ray
 from pyarrow.dataset import ParquetFileFragment
 from pyarrow.lib import Schema
 from ray import cloudpickle
 from ray.data._internal.output_buffer import BlockOutputBuffer
+from ray.data._internal.progress_bar import ProgressBar
 from ray.data.block import Block, BlockAccessor
 from ray.data.context import DatasetContext
 from ray.data.datasource import Reader, ReadTask
@@ -189,6 +188,7 @@ class _ArrowParquetDatasourceReader(Reader[Any]):  # pylint: disable=too-many-in
     def __init__(
         self,
         paths: Union[str, List[str]],
+        local_uri: bool = False,
         filesystem: Optional["pyarrow.fs.FileSystem"] = None,
         columns: Optional[List[str]] = None,
         schema: Optional[Schema] = None,
@@ -203,6 +203,13 @@ def __init__(
         if len(paths) == 1:
             paths = paths[0]

+        self._local_scheduling = None
+        if local_uri:
+            import ray
+            from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
+
+            self._local_scheduling = NodeAffinitySchedulingStrategy(ray.get_runtime_context().get_node_id(), soft=False)
+
         dataset_kwargs = reader_args.pop("dataset_kwargs", {})
         try:
             pq_ds = pq.ParquetDataset(paths, **dataset_kwargs, filesystem=filesystem, use_legacy_dataset=False)
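
The local_uri branch above builds a NodeAffinitySchedulingStrategy pinned to the node the reader is constructed on, so the metadata prefetch and sampling tasks below can be pinned to that same node. A standalone sketch of the same Ray API (Ray 2.x assumed; the task body and path are illustrative only, not part of this commit):

import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

ray.init()

# Hard affinity (soft=False): tasks using this strategy only run on the pinned node.
local_strategy = NodeAffinitySchedulingStrategy(ray.get_runtime_context().get_node_id(), soft=False)

@ray.remote(scheduling_strategy=local_strategy)
def stat_local_file(path: str) -> int:
    import os

    # Runs on the pinned node, where a purely local (non-S3) path is actually readable.
    return os.path.getsize(path)

print(ray.get(stat_local_file.remote("/etc/hostname")))  # placeholder local path
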
@@ -230,7 +237,10 @@ def __init__(
         inferred_schema = schema

         try:
-            self._metadata = meta_provider.prefetch_file_metadata(pq_ds.pieces) or []
+            prefetch_remote_args = {}
+            if self._local_scheduling:
+                prefetch_remote_args["scheduling_strategy"] = self._local_scheduling
+            self._metadata = meta_provider.prefetch_file_metadata(pq_ds.pieces, **prefetch_remote_args) or []
         except OSError as e:
             _handle_read_os_error(e, paths)
         self._pq_ds = pq_ds
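
Metadata prefetch now conditionally forwards the scheduling strategy through prefetch_remote_args. A generic sketch of that "conditionally forward remote options" pattern using only public Ray APIs (fetch_footer and the paths are illustrative stand-ins, not awswrangler or Ray Data code):

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def fetch_footer(path: str) -> str:
    # Stand-in for per-file metadata fetching.
    return f"metadata-for-{path}"

remote_args = {}
pin_to_local_node = False  # would be True when something like local_uri is requested
if pin_to_local_node:
    from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

    remote_args["scheduling_strategy"] = NodeAffinitySchedulingStrategy(
        ray.get_runtime_context().get_node_id(), soft=False
    )

refs = [fetch_footer.options(**remote_args).remote(p) for p in ["a.parquet", "b.parquet"]]
print(ray.get(refs))
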
@@ -307,7 +317,6 @@ def _estimate_files_encoding_ratio(self) -> float:
         # Launch tasks to sample multiple files remotely in parallel.
         # Evenly distributed to sample N rows in i-th row group in i-th file.
         # TODO(ekl/cheng) take into account column pruning.
-        start_time = time.perf_counter()
         num_files = len(self._pq_ds.pieces)
         num_samples = int(num_files * PARQUET_ENCODING_RATIO_ESTIMATE_SAMPLING_RATIO)
         min_num_samples = min(PARQUET_ENCODING_RATIO_ESTIMATE_MIN_NUM_SAMPLES, num_files)
@@ -321,18 +330,25 @@ def _estimate_files_encoding_ratio(self) -> float:
         ]

         futures = []
-        for idx, sample in enumerate(file_samples):
-            # Sample i-th row group in i-th file.
+        sample_piece = ray_remote(scheduling_strategy=self._local_scheduling or "SPREAD")(_sample_piece)
+        for sample in file_samples:
+            # Sample the first rows batch in i-th file.
             # Use SPREAD scheduling strategy to avoid packing many sampling tasks on
             # same machine to cause OOM issue, as sampling can be memory-intensive.
-            futures.append(_sample_piece(_SerializedPiece(sample), idx))
-        sample_ratios = ray.get(futures)
+            serialized_sample = _SerializedPiece(sample)
+            futures.append(
+                sample_piece(
+                    self._reader_args,
+                    self._columns,
+                    self._schema,
+                    serialized_sample,
+                )
+            )
+        sample_bar = ProgressBar("Parquet Files Sample", len(futures))
+        sample_ratios = sample_bar.fetch_until_complete(futures)
+        sample_bar.close()  # type: ignore[no-untyped-call]
         ratio = np.mean(sample_ratios)
-
-        sampling_duration = time.perf_counter() - start_time
-        if sampling_duration > 5:
-            _logger.info("Parquet input size estimation took %s seconds.", round(sampling_duration, 2))
-        _logger.debug("Estimated Parquet encoding ratio from sampling is %s.", ratio)
+        _logger.debug(f"Estimated Parquet encoding ratio from sampling is {ratio}.")
         return max(ratio, PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND)  # type: ignore[no-any-return]
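
Sampling tasks are now created by wrapping the plain _sample_piece with ray_remote at call time, so the scheduling strategy can depend on self._local_scheduling, and results are gathered behind Ray Data's ProgressBar instead of a bare ray.get. A sketch of the call-time wrapping pattern using only public Ray APIs (estimate_ratio is an illustrative stand-in; the ProgressBar used in the diff is a Ray-internal helper):

import ray

ray.init(ignore_reinit_error=True)

def estimate_ratio(sample_id: int) -> float:
    # Stand-in for _sample_piece: pretend to compute an encoding ratio.
    return 1.0 + sample_id * 0.1

# Choose the scheduling strategy at call time: SPREAD here, a node-affinity strategy elsewhere.
remote_estimate = ray.remote(scheduling_strategy="SPREAD")(estimate_ratio)

futures = [remote_estimate.remote(i) for i in range(4)]
print(ray.get(futures))  # the commit tracks these futures with a progress bar instead
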
@@ -389,32 +405,44 @@ def _read_pieces(
             yield output_buffer.next()


-@ray_remote(scheduling_strategy="SPREAD")
 def _sample_piece(
+    reader_args: Any,
+    columns: Optional[List[str]],
+    schema: Optional[Union[type, "pyarrow.lib.Schema"]],
     file_piece: _SerializedPiece,
-    row_group_id: int,
 ) -> float:
-    # Sample the `row_group_id`-th row group from file piece `serialized_piece`.
+    # Sample the first rows batch from file piece `serialized_piece`.
     # Return the encoding ratio calculated from the sampled rows.
     piece = _deserialize_pieces_with_retry([file_piece])[0]

-    # If required row group index is out of boundary, sample the last row group.
-    row_group_id = min(piece.num_row_groups - 1, row_group_id)
-    assert (
-        0 <= row_group_id <= piece.num_row_groups - 1
-    ), f"Required row group id {row_group_id} is not in expected bound"
-
-    row_group = piece.subset(row_group_ids=[row_group_id])
-    metadata = row_group.metadata.row_group(0)
-    num_rows = min(PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS, metadata.num_rows)
-    assert num_rows > 0 and metadata.num_rows > 0, (
-        f"Sampled number of rows: {num_rows} and total number of rows: " f"{metadata.num_rows} should be positive"
+    # Only sample the first row group.
+    piece = piece.subset(row_group_ids=[0])
+    batch_size = max(min(piece.metadata.num_rows, PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS), 1)
+    # Use the batch_size calculated above, and ignore the one specified by user if set.
+    # This is to avoid sampling too few or too many rows.
+    reader_args.pop("batch_size", None)
+    reader_args.pop("path_root", None)
+    batches = piece.to_batches(
+        columns=columns,
+        schema=schema,
+        batch_size=batch_size,
+        **reader_args,
     )
-
-    parquet_size: float = metadata.total_byte_size / metadata.num_rows
-    # Set batch_size to num_rows will instruct Arrow Parquet reader to read exactly
-    # num_rows into memory, o.w. it will read more rows by default in batch manner.
-    in_memory_size: float = row_group.head(num_rows, batch_size=num_rows).nbytes / num_rows
-    ratio: float = in_memory_size / parquet_size
-    _logger.debug("Estimated Parquet encoding ratio is %s for piece %s.", ratio, piece)
-    return in_memory_size / parquet_size
+    # Use first batch in-memory size as ratio estimation.
+    try:
+        batch = next(batches)
+    except StopIteration:
+        ratio = PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND
+    else:
+        if batch.num_rows > 0:
+            in_memory_size = batch.nbytes / batch.num_rows
+            metadata = piece.metadata
+            total_size = 0
+            for idx in range(metadata.num_row_groups):
+                total_size += metadata.row_group(idx).total_byte_size
+            file_size = total_size / metadata.num_rows
+            ratio = in_memory_size / file_size
+        else:
+            ratio = PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND
+    _logger.debug(f"Estimated Parquet encoding ratio is {ratio} for piece {piece} " f"with batch size {batch_size}.")
+    return ratio
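
The rewritten _sample_piece estimates the encoding ratio from the first batch of the first row group: decoded bytes per row compared with encoded bytes per row taken from the Parquet metadata. A local pyarrow-only sketch of the same arithmetic, computed here over a whole throwaway file for simplicity (no Ray; example.parquet is a placeholder and 1024 stands in for PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS):

import pyarrow as pa
import pyarrow.parquet as pq

# Write a small throwaway file to sample.
table = pa.table({"a": list(range(10_000)), "b": ["x"] * 10_000})
pq.write_table(table, "example.parquet")

pf = pq.ParquetFile("example.parquet")
metadata = pf.metadata

# In-memory size per row of the first sampled batch.
batch = next(pf.iter_batches(batch_size=max(min(metadata.num_rows, 1024), 1)))
in_memory_size = batch.nbytes / batch.num_rows

# On-disk (encoded) size per row across the file's row groups.
total_size = sum(metadata.row_group(i).total_byte_size for i in range(metadata.num_row_groups))
file_size = total_size / metadata.num_rows

# Ratio > 1 means the data is larger decoded in memory than encoded on disk.
print(in_memory_size / file_size)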
