 from __future__ import annotations
 
 import logging
+from dataclasses import dataclass
 from typing import (
     TYPE_CHECKING,
     Any,
 )
 from ray.data.datasource.datasource import ReadTask
 from ray.data.datasource.file_meta_provider import (
-    DefaultParquetMetadataProvider,
-    ParquetMetadataProvider,
     _handle_read_os_error,
 )
+from ray.data.datasource.parquet_meta_provider import (
+    ParquetMetadataProvider,
+)
 from ray.data.datasource.partitioning import PathPartitionFilter
 from ray.data.datasource.path_util import (
     _has_file_extension,
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
-FRAGMENTS_PER_META_FETCH = 6
-PARALLELIZE_META_FETCH_THRESHOLD = 24
 
 # The number of rows to read per batch. This is sized to generate 10MiB batches
 # for rows about 1KiB in size.
 PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS = 1024
 
 
+@dataclass(frozen=True)
+class _SampleInfo:
+    actual_bytes_per_row: int | None
+    estimated_bytes_per_row: int | None
+
+
 # TODO(ekl) this is a workaround for a pyarrow serialization bug, where serializing a
 # raw pyarrow file fragment causes S3 network calls.
 class _SerializedFragment:
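
The `_SampleInfo` record added above replaces the bare per-fragment ratio floats: it carries both the observed in-memory bytes per row and the on-disk bytes per row for a sampled fragment. A minimal sketch of how such a record is derived from a sampled Arrow batch and the fragment's Parquet metadata; the `build_sample_info` helper and its `batch`/`metadata` arguments are illustrative stand-ins, not part of this commit:

def build_sample_info(batch, metadata) -> _SampleInfo:
    # No rows sampled: leave both fields unset so downstream estimators
    # fall back to their defaults.
    if batch.num_rows == 0:
        return _SampleInfo(actual_bytes_per_row=None, estimated_bytes_per_row=None)
    # Approximate the on-disk size from the total byte size of all row groups.
    total_byte_size = sum(
        metadata.row_group(i).total_byte_size for i in range(metadata.num_row_groups)
    )
    return _SampleInfo(
        actual_bytes_per_row=batch.nbytes / batch.num_rows,
        estimated_bytes_per_row=total_byte_size / metadata.num_rows,
    )
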
@@ -117,38 +123,6 @@ def _deserialize_fragments(
     return [p.deserialize() for p in serialized_fragments]
 
 
-class _ParquetFileFragmentMetaData:
-    """Class to store metadata of a Parquet file fragment.
-
-    This includes all attributes from `pyarrow.parquet.FileMetaData` except for `schema`,
-    which is stored in `self.schema_pickled` as a pickled object from
-    `cloudpickle.loads()`, used in deduplicating schemas across multiple fragments.
-    """
-
-    def __init__(self, fragment_metadata: "pyarrow.parquet.FileMetaData"):
-        self.created_by = fragment_metadata.created_by
-        self.format_version = fragment_metadata.format_version
-        self.num_columns = fragment_metadata.num_columns
-        self.num_row_groups = fragment_metadata.num_row_groups
-        self.num_rows = fragment_metadata.num_rows
-        self.serialized_size = fragment_metadata.serialized_size
-        # This is a pickled schema object, to be set later with
-        # `self.set_schema_pickled()`. To get the underlying schema, use
-        # `cloudpickle.loads(self.schema_pickled)`.
-        self.schema_pickled: bytes | None = None
-
-        # Calculate the total byte size of the file fragment using the original
-        # object, as it is not possible to access row groups from this class.
-        self.total_byte_size = 0
-        for row_group_idx in range(fragment_metadata.num_row_groups):
-            row_group_metadata = fragment_metadata.row_group(row_group_idx)
-            self.total_byte_size += row_group_metadata.total_byte_size
-
-    def set_schema_pickled(self, schema_pickled: bytes) -> None:
-        """Note: to get the underlying schema, use `cloudpickle.loads(self.schema_pickled)`."""
-        self.schema_pickled = schema_pickled
-
-
 # This retry helps when the upstream datasource is not able to handle
 # overloaded read request or failed with some retriable failures.
 # For example when reading data from HA hdfs service, hdfs might
@@ -213,7 +187,7 @@ def __init__( # noqa: PLR0912,PLR0915
         arrow_parquet_args: dict[str, Any] | None = None,
         _block_udf: Callable[[Block], Block] | None = None,
         filesystem: "pyarrow.fs.FileSystem" | None = None,
-        meta_provider: ParquetMetadataProvider = DefaultParquetMetadataProvider(),
+        meta_provider: ParquetMetadataProvider = ParquetMetadataProvider(),
         partition_filter: PathPartitionFilter | None = None,
         shuffle: Literal["files"] | None = None,
         include_paths: bool = False,
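
With `DefaultParquetMetadataProvider` removed, the plain `ParquetMetadataProvider` becomes the default. A minimal usage sketch, assuming `ray.data.read_parquet` still exposes a `meta_provider` argument and that a local `example.parquet` file exists; passing the provider explicitly is equivalent to relying on the new default:

import ray
from ray.data.datasource.parquet_meta_provider import ParquetMetadataProvider

# Explicitly pass the new default metadata provider (same as omitting it).
ds = ray.data.read_parquet(
    "example.parquet",
    meta_provider=ParquetMetadataProvider(),
)
print(ds.schema())
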
@@ -299,8 +273,7 @@ def __init__( # noqa: PLR0912,PLR0915
             prefetch_remote_args = {}
             if self._local_scheduling:
                 prefetch_remote_args["scheduling_strategy"] = self._local_scheduling
-            raw_metadata = meta_provider.prefetch_file_metadata(pq_ds.fragments, **prefetch_remote_args) or []
-            self._metadata = self._dedupe_metadata(raw_metadata)
+            self._metadata = meta_provider.prefetch_file_metadata(pq_ds.fragments, **prefetch_remote_args) or []
         except OSError as e:
             _handle_read_os_error(e, paths)
         except pa.ArrowInvalid as ex:
@@ -319,43 +292,15 @@ def __init__( # noqa: PLR0912,PLR0915
         self._columns = columns
         self._schema = schema
         self._arrow_parquet_args = arrow_parquet_args
-        self._encoding_ratio = self._estimate_files_encoding_ratio()
         self._file_metadata_shuffler = None
         self._include_paths = include_paths
         self._path_root = path_root
        if shuffle == "files":
             self._file_metadata_shuffler = np.random.default_rng()
 
-    def _dedupe_metadata(
-        self,
-        raw_metadatas: list["pyarrow.parquet.FileMetaData"],
-    ) -> list[_ParquetFileFragmentMetaData]:
-        """Deduplicate schemas to reduce memory usage.
-
-        For datasets with a large number of columns, the FileMetaData
-        (in particular the schema) can be very large. We can reduce the
-        memory usage by only keeping unique schema objects across all
-        file fragments. This method deduplicates the schemas and returns
-        a list of `_ParquetFileFragmentMetaData` objects.
-        """
-        schema_to_id: dict[int, Any] = {}  # schema_id -> serialized_schema
-        id_to_schema: dict[Any, bytes] = {}  # serialized_schema -> schema_id
-        stripped_metadatas = []
-        for fragment_metadata in raw_metadatas:
-            stripped_md = _ParquetFileFragmentMetaData(fragment_metadata)
-
-            schema_ser = cloudpickle.dumps(fragment_metadata.schema.to_arrow_schema())  # type: ignore[no-untyped-call]
-            if schema_ser not in schema_to_id:
-                schema_id: int | None = len(schema_to_id)
-                schema_to_id[schema_ser] = schema_id
-                id_to_schema[schema_id] = schema_ser
-                stripped_md.set_schema_pickled(schema_ser)
-            else:
-                schema_id = schema_to_id.get(schema_ser)
-                existing_schema_ser = id_to_schema[schema_id]
-                stripped_md.set_schema_pickled(existing_schema_ser)
-            stripped_metadatas.append(stripped_md)
-        return stripped_metadatas
+        sample_infos = self._sample_fragments()
+        self._encoding_ratio = _estimate_files_encoding_ratio(sample_infos)
+        self._default_read_batch_size_rows = _estimate_default_read_batch_size_rows(sample_infos)
 
     def estimate_inmemory_data_size(self) -> int | None:
         """Return an estimate of the Parquet files encoding ratio.
@@ -414,25 +359,18 @@ def get_read_tasks(self, parallelism: int) -> list[ReadTask]:
             if meta.size_bytes is not None:
                 meta.size_bytes = int(meta.size_bytes * self._encoding_ratio)
 
-            if meta.num_rows and meta.size_bytes:
-                # Make sure the batches read are small enough to enable yielding of
-                # output blocks incrementally during the read.
-                row_size = meta.size_bytes / meta.num_rows
-                # Make sure the row batch size is small enough that block splitting
-                # is still effective.
-                max_parquet_reader_row_batch_size_bytes = DataContext.get_current().target_max_block_size // 10
-                default_read_batch_size_rows = max(
-                    1,
-                    min(
-                        PARQUET_READER_ROW_BATCH_SIZE,
-                        max_parquet_reader_row_batch_size_bytes // row_size,
-                    ),
-                )
-            else:
-                default_read_batch_size_rows = PARQUET_READER_ROW_BATCH_SIZE
-            block_udf, arrow_parquet_args, columns, schema, path_root, include_paths = (
+            (
+                block_udf,
+                arrow_parquet_args,
+                default_read_batch_size_rows,
+                columns,
+                schema,
+                path_root,
+                include_paths,
+            ) = (
                 self._block_udf,
                 self._arrow_parquet_args,
+                self._default_read_batch_size_rows,
                 self._columns,
                 self._schema,
                 self._path_root,
@@ -456,14 +394,7 @@ def get_read_tasks(self, parallelism: int) -> list[ReadTask]:
 
         return read_tasks
 
-    def _estimate_files_encoding_ratio(self) -> float:
-        """Return an estimate of the Parquet files encoding ratio.
-
-        To avoid OOMs, it is safer to return an over-estimate than an underestimate.
-        """
-        if not DataContext.get_current().decoding_size_estimation:
-            return PARQUET_ENCODING_RATIO_ESTIMATE_DEFAULT
-
+    def _sample_fragments(self) -> list[_SampleInfo]:
         # Sample a few rows from Parquet files to estimate the encoding ratio.
         # Launch tasks to sample multiple files remotely in parallel.
         # Evenly distributed to sample N rows in i-th row group in i-th file.
@@ -495,11 +426,10 @@ def _estimate_files_encoding_ratio(self) -> float:
                 )
             )
         sample_bar = ProgressBar("Parquet Files Sample", len(futures))
-        sample_ratios = sample_bar.fetch_until_complete(futures)
+        sample_infos = sample_bar.fetch_until_complete(futures)
         sample_bar.close()  # type: ignore[no-untyped-call]
-        ratio = np.mean(sample_ratios)
-        _logger.debug(f"Estimated Parquet encoding ratio from sampling is {ratio}.")
-        return max(ratio, PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND)  # type: ignore[no-any-return]
+
+        return sample_infos
 
     def get_name(self) -> str:
         """Return a human-readable name for this datasource.
@@ -577,33 +507,12 @@ def _read_fragments(
             yield table
 
 
-def _fetch_metadata_serialization_wrapper(
-    fragments: list[_SerializedFragment],
-) -> list["pyarrow.parquet.FileMetaData"]:
-    fragments: list["pyarrow._dataset.ParquetFileFragment"] = _deserialize_fragments_with_retry(fragments)  # type: ignore[no-redef]
-
-    return _fetch_metadata(fragments)
-
-
-def _fetch_metadata(
-    fragments: list["pyarrow.dataset.ParquetFileFragment"],
-) -> list["pyarrow.parquet.FileMetaData"]:
-    fragment_metadata = []
-    for f in fragments:
-        try:
-            fragment_metadata.append(f.metadata)
-        except AttributeError:
-            break
-    return fragment_metadata
-
-
 def _sample_fragment(
     columns: list[str] | None,
     schema: type | "pyarrow.lib.Schema" | None,
     file_fragment: _SerializedFragment,
-) -> float:
+) -> _SampleInfo:
     # Sample the first rows batch from file fragment `serialized_fragment`.
-    # Return the encoding ratio calculated from the sampled rows.
     fragment = _deserialize_fragments_with_retry([file_fragment])[0]
 
     # Only sample the first row group.
@@ -616,23 +525,57 @@ def _sample_fragment(
         schema=schema,
         batch_size=batch_size,
     )
-    # Use first batch in-memory size as ratio estimation.
+    # Use first batch in-memory size for estimation.
     try:
         batch = next(batches)
     except StopIteration:
-        ratio = PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND
+        sample_data = _SampleInfo(actual_bytes_per_row=None, estimated_bytes_per_row=None)
     else:
         if batch.num_rows > 0:
-            in_memory_size = batch.nbytes / batch.num_rows
             metadata = fragment.metadata
             total_size = 0
             for idx in range(metadata.num_row_groups):
                 total_size += metadata.row_group(idx).total_byte_size
-            file_size = total_size / metadata.num_rows
-            ratio = in_memory_size / file_size
+            sample_data = _SampleInfo(
+                actual_bytes_per_row=batch.nbytes / batch.num_rows,
+                estimated_bytes_per_row=total_size / metadata.num_rows,
+            )
         else:
-            ratio = PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND
-    _logger.debug(
-        f"Estimated Parquet encoding ratio is {ratio} for fragment {fragment} " f"with batch size {batch_size}."
-    )
-    return ratio
+            sample_data = _SampleInfo(actual_bytes_per_row=None, estimated_bytes_per_row=None)
+    return sample_data
+
+
+def _estimate_files_encoding_ratio(sample_infos: list[_SampleInfo]) -> float:
+    """Return an estimate of the Parquet files encoding ratio.
+
+    To avoid OOMs, it is safer to return an over-estimate than an underestimate.
+    """
+    if not DataContext.get_current().decoding_size_estimation:
+        return PARQUET_ENCODING_RATIO_ESTIMATE_DEFAULT
+
+    def compute_encoding_ratio(sample_info: _SampleInfo) -> float:
+        if sample_info.actual_bytes_per_row is None or sample_info.estimated_bytes_per_row is None:
+            return PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND
+        else:
+            return sample_info.actual_bytes_per_row / sample_info.estimated_bytes_per_row
+
+    ratio = np.mean(list(map(compute_encoding_ratio, sample_infos)))
+    _logger.debug(f"Estimated Parquet encoding ratio from sampling is {ratio}.")
+    return max(ratio, PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND)  # type: ignore[return-value]
+
+
+def _estimate_default_read_batch_size_rows(sample_infos: list[_SampleInfo]) -> int:
+    def compute_batch_size_rows(sample_info: _SampleInfo) -> int:
+        if sample_info.actual_bytes_per_row is None:
+            return PARQUET_READER_ROW_BATCH_SIZE
+        else:
+            max_parquet_reader_row_batch_size_bytes = DataContext.get_current().target_max_block_size // 10
+            return max(
+                1,
+                min(
+                    PARQUET_READER_ROW_BATCH_SIZE,
+                    max_parquet_reader_row_batch_size_bytes // sample_info.actual_bytes_per_row,
+                ),
+            )
+
+    return np.mean(list(map(compute_batch_size_rows, sample_infos)))  # type: ignore[return-value]
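
A small worked example (not part of the commit) showing how the two module-level estimators added above combine the samples, assuming `decoding_size_estimation` is enabled in the current `DataContext`; the numbers are made up to illustrate the math:

# One fragment decodes to 2x its on-disk size; the other sample is empty, so it
# contributes the lower-bound ratio and the default batch size.
samples = [
    _SampleInfo(actual_bytes_per_row=2048, estimated_bytes_per_row=1024),
    _SampleInfo(actual_bytes_per_row=None, estimated_bytes_per_row=None),
]

# mean(2048 / 1024, PARQUET_ENCODING_RATIO_ESTIMATE_LOWER_BOUND), floored at the
# lower bound, so in-memory size is over- rather than under-estimated.
ratio = _estimate_files_encoding_ratio(samples)

# mean of PARQUET_READER_ROW_BATCH_SIZE (empty sample) and
# (target_max_block_size // 10) // 2048, each clamped to [1, PARQUET_READER_ROW_BATCH_SIZE].
batch_size_rows = _estimate_default_read_batch_size_rows(samples)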