import itertools
from dataclasses import dataclass
-from typing import List, Optional
+from typing import List, Optional, Union

import pyarrow as pa
+import pyarrow.dataset as ds
import pyarrow.parquet as pq

import datasets
@@ -19,6 +20,7 @@ class ParquetConfig(datasets.BuilderConfig):
    batch_size: Optional[int] = None
    columns: Optional[List[str]] = None
    features: Optional[datasets.Features] = None
+    filters: Optional[Union[ds.Expression, List[tuple], List[List[tuple]]]] = None

    def __post_init__(self):
        super().__post_init__()
@@ -77,14 +79,25 @@ def _generate_tables(self, files):
                raise ValueError(
                    f"Tried to load parquet data with columns '{self.config.columns}' with mismatching features '{self.info.features}'"
                )
+        filter_expr = (
+            pq.filters_to_expression(self.config.filters)
+            if isinstance(self.config.filters, list)
+            else self.config.filters
+        )
        for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
            with open(file, "rb") as f:
-                parquet_file = pq.ParquetFile(f)
-                if parquet_file.metadata.num_row_groups > 0:
-                    batch_size = self.config.batch_size or parquet_file.metadata.row_group(0).num_rows
+                parquet_fragment = ds.ParquetFileFormat().make_fragment(f)
+                if parquet_fragment.row_groups:
+                    batch_size = self.config.batch_size or parquet_fragment.row_groups[0].num_rows
                    try:
                        for batch_idx, record_batch in enumerate(
-                            parquet_file.iter_batches(batch_size=batch_size, columns=self.config.columns)
+                            parquet_fragment.to_batches(
+                                batch_size=batch_size,
+                                columns=self.config.columns,
+                                filter=filter_expr,
+                                batch_readahead=0,
+                                fragment_readahead=0,
+                            )
                        ):
                            pa_table = pa.Table.from_batches([record_batch])
                            # Uncomment for debugging (will print the Arrow table size and elements)