@@ -8,18 +8,17 @@
 # pylint: disable=redefined-outer-name,import-outside-toplevel,reimported

 import logging
-import time
 from typing import Any, Callable, Dict, Iterator, List, Optional, Union

 import numpy as np

 # fs required to implicitly trigger S3 subsystem initialization
 import pyarrow.fs  # noqa: F401 pylint: disable=unused-import
-import ray
 from pyarrow.dataset import ParquetFileFragment
 from pyarrow.lib import Schema
 from ray import cloudpickle
 from ray.data._internal.output_buffer import BlockOutputBuffer
+from ray.data._internal.progress_bar import ProgressBar
 from ray.data.block import Block, BlockAccessor
 from ray.data.context import DatasetContext
 from ray.data.datasource import Reader, ReadTask
@@ -189,6 +188,7 @@ class _ArrowParquetDatasourceReader(Reader[Any]): # pylint: disable=too-many-in
     def __init__(
         self,
         paths: Union[str, List[str]],
+        local_uri: bool = False,
         filesystem: Optional["pyarrow.fs.FileSystem"] = None,
         columns: Optional[List[str]] = None,
         schema: Optional[Schema] = None,
@@ -203,6 +203,13 @@ def __init__(
         if len(paths) == 1:
             paths = paths[0]

+        self._local_scheduling = None
+        if local_uri:
+            import ray
+            from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
+
+            self._local_scheduling = NodeAffinitySchedulingStrategy(ray.get_runtime_context().get_node_id(), soft=False)
+
         dataset_kwargs = reader_args.pop("dataset_kwargs", {})
         try:
             pq_ds = pq.ParquetDataset(paths, **dataset_kwargs, filesystem=filesystem, use_legacy_dataset=False)
@@ -230,7 +237,10 @@ def __init__(
             inferred_schema = schema

         try:
-            self._metadata = meta_provider.prefetch_file_metadata(pq_ds.pieces) or []
+            prefetch_remote_args = {}
+            if self._local_scheduling:
+                prefetch_remote_args["scheduling_strategy"] = self._local_scheduling
+            self._metadata = meta_provider.prefetch_file_metadata(pq_ds.pieces, **prefetch_remote_args) or []
         except OSError as e:
             _handle_read_os_error(e, paths)
         self._pq_ds = pq_ds
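Note: the `local_uri` handling above pins remote work to the node that constructed the reader via `NodeAffinitySchedulingStrategy`, so paths that are only valid on that machine still resolve when file metadata is prefetched. Below is a minimal standalone sketch of that pattern, separate from the diff; `read_local_file` and the path are hypothetical, while the Ray APIs are the ones imported above.

import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy


@ray.remote
def read_local_file(path: str) -> int:
    # Hypothetical task: runs on the pinned node, so node-local paths are readable.
    with open(path, "rb") as f:
        return len(f.read())


ray.init()
# soft=False: the task fails rather than falling back to a different node.
local_only = NodeAffinitySchedulingStrategy(ray.get_runtime_context().get_node_id(), soft=False)
ref = read_local_file.options(scheduling_strategy=local_only).remote("/tmp/data.bin")
print(ray.get(ref))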
@@ -307,7 +317,6 @@ def _estimate_files_encoding_ratio(self) -> float:
         # Launch tasks to sample multiple files remotely in parallel.
         # Evenly distributed to sample N rows in i-th row group in i-th file.
         # TODO(ekl/cheng) take into account column pruning.
-        start_time = time.perf_counter()
         num_files = len(self._pq_ds.pieces)
         num_samples = int(num_files * PARQUET_ENCODING_RATIO_ESTIMATE_SAMPLING_RATIO)
         min_num_samples = min(PARQUET_ENCODING_RATIO_ESTIMATE_MIN_NUM_SAMPLES, num_files)
@@ -321,18 +330,25 @@ def _estimate_files_encoding_ratio(self) -> float:
         ]

         futures = []
-        for idx, sample in enumerate(file_samples):
-            # Sample i-th row group in i-th file.
+        sample_piece = ray_remote(scheduling_strategy=self._local_scheduling or "SPREAD")(_sample_piece)
+        for sample in file_samples:
+            # Sample the first batch of rows in the i-th file.
             # Use SPREAD scheduling strategy to avoid packing many sampling tasks on
             # same machine to cause OOM issue, as sampling can be memory-intensive.
-            futures.append(_sample_piece(_SerializedPiece(sample), idx))
-        sample_ratios = ray.get(futures)
+            serialized_sample = _SerializedPiece(sample)
+            futures.append(
+                sample_piece(
+                    self._reader_args,
+                    self._columns,
+                    self._schema,
+                    serialized_sample,
+                )
+            )
+        sample_bar = ProgressBar("Parquet Files Sample", len(futures))
+        sample_ratios = sample_bar.fetch_until_complete(futures)
+        sample_bar.close()  # type: ignore[no-untyped-call]
         ratio = np.mean(sample_ratios)
-
-        sampling_duration = time.perf_counter() - start_time
-        if sampling_duration > 5:
-            _logger.info("Parquet input size estimation took %s seconds.", round(sampling_duration, 2))
-        _logger.debug("Estimated Parquet encoding ratio from sampling is %s.", ratio)
+        _logger.debug(f"Estimated Parquet encoding ratio from sampling is {ratio}.")
         return max(ratio, PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND)  # type: ignore[no-any-return]


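Note: `ProgressBar` comes from Ray Data's internal `ray.data._internal.progress_bar` module; `fetch_until_complete` behaves like `ray.get` over a list of object refs while updating the bar, which is why it replaces the bare `ray.get(futures)` in the hunk above. A small standalone sketch of that usage with a made-up `square` task, assuming the internal API keeps this signature:

import ray
from ray.data._internal.progress_bar import ProgressBar


@ray.remote
def square(x: int) -> int:
    return x * x


ray.init()
futures = [square.remote(i) for i in range(8)]
bar = ProgressBar("Example tasks", len(futures))
results = bar.fetch_until_complete(futures)  # blocks until all refs resolve
bar.close()
print(results)  # [0, 1, 4, ..., 49]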
@@ -389,32 +405,44 @@ def _read_pieces(
         yield output_buffer.next()


-@ray_remote(scheduling_strategy="SPREAD")
 def _sample_piece(
+    reader_args: Any,
+    columns: Optional[List[str]],
+    schema: Optional[Union[type, "pyarrow.lib.Schema"]],
     file_piece: _SerializedPiece,
-    row_group_id: int,
 ) -> float:
-    # Sample the `row_group_id`-th row group from file piece `serialized_piece`.
+    # Sample the first batch of rows from file piece `file_piece`.
     # Return the encoding ratio calculated from the sampled rows.
     piece = _deserialize_pieces_with_retry([file_piece])[0]

-    # If required row group index is out of boundary, sample the last row group.
-    row_group_id = min(piece.num_row_groups - 1, row_group_id)
-    assert (
-        0 <= row_group_id <= piece.num_row_groups - 1
-    ), f"Required row group id {row_group_id} is not in expected bound"
-
-    row_group = piece.subset(row_group_ids=[row_group_id])
-    metadata = row_group.metadata.row_group(0)
-    num_rows = min(PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS, metadata.num_rows)
-    assert num_rows > 0 and metadata.num_rows > 0, (
-        f"Sampled number of rows: {num_rows} and total number of rows: " f"{metadata.num_rows} should be positive"
+    # Only sample the first row group.
+    piece = piece.subset(row_group_ids=[0])
+    batch_size = max(min(piece.metadata.num_rows, PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS), 1)
+    # Use the batch_size calculated above, and ignore the one specified by the user if set.
+    # This is to avoid sampling too few or too many rows.
+    reader_args.pop("batch_size", None)
+    reader_args.pop("path_root", None)
+    batches = piece.to_batches(
+        columns=columns,
+        schema=schema,
+        batch_size=batch_size,
+        **reader_args,
     )
-
-    parquet_size: float = metadata.total_byte_size / metadata.num_rows
-    # Set batch_size to num_rows will instruct Arrow Parquet reader to read exactly
-    # num_rows into memory, o.w. it will read more rows by default in batch manner.
-    in_memory_size: float = row_group.head(num_rows, batch_size=num_rows).nbytes / num_rows
-    ratio: float = in_memory_size / parquet_size
-    _logger.debug("Estimated Parquet encoding ratio is %s for piece %s.", ratio, piece)
-    return in_memory_size / parquet_size
+    # Use the in-memory size of the first batch to estimate the ratio.
+    try:
+        batch = next(batches)
+    except StopIteration:
+        ratio = PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND
+    else:
+        if batch.num_rows > 0:
+            in_memory_size = batch.nbytes / batch.num_rows
+            metadata = piece.metadata
+            total_size = 0
+            for idx in range(metadata.num_row_groups):
+                total_size += metadata.row_group(idx).total_byte_size
+            file_size = total_size / metadata.num_rows
+            ratio = in_memory_size / file_size
+        else:
+            ratio = PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND
+    _logger.debug(f"Estimated Parquet encoding ratio is {ratio} for piece {piece} " f"with batch size {batch_size}.")
+    return ratio
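Note: the encoding ratio returned by `_sample_piece` is in-memory bytes per row (from the first sampled batch) divided by on-disk bytes per row (summed over the file's row groups), and the mean ratio across sampled files is clamped to the lower bound before it is used to scale on-disk sizes into in-memory size estimates. A worked example with illustrative numbers, separate from the diff:

# Illustrative numbers only.
batch_nbytes = 4_000_000     # in-memory size of the first sampled batch
batch_num_rows = 10_000
total_byte_size = 1_500_000  # sum of row-group total_byte_size on disk
file_num_rows = 60_000

in_memory_size = batch_nbytes / batch_num_rows  # 400 bytes per row in memory
file_size = total_byte_size / file_num_rows     # 25 bytes per row on disk
ratio = in_memory_size / file_size              # 16.0: decoded data is ~16x larger

# Stand-in for PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND; the real constant
# is defined in this module and may differ.
lower_bound = 1.0
print(max(ratio, lower_bound))  # 16.0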