@@ -138,11 +138,30 @@ def is_list_pa_type(parquet_file_path: Path, feature_name: str) -> bool:
     return is_list
 
 
+def truncate_binary_columns(table: pa.Table, max_binary_length: int, features: Features) -> tuple[pa.Table, list[str]]:
+    # truncate binary columns in the Arrow table to the specified maximum length
+    # return a new Arrow table and the list of truncated columns
+    if max_binary_length < 0:
+        return table, []
+
+    columns: dict[str, pa.Array] = {}
+    truncated_column_names: list[str] = []
+    for field_idx, field in enumerate(table.schema):  # noqa: F402
+        if features[field.name] == Value("binary") and table[field_idx].nbytes > max_binary_length:
+            truncated_array = pc.binary_slice(table[field_idx], 0, max_binary_length // len(table))
+            columns[field.name] = truncated_array
+            truncated_column_names.append(field.name)
+        else:
+            columns[field.name] = table[field_idx]
+
+    return pa.table(columns), truncated_column_names
+
+
 @dataclass
 class RowGroupReader:
     parquet_file: pq.ParquetFile
     group_id: int
-    features: Features
+    schema: pa.Schema
 
     def read(self, columns: list[str]) -> pa.Table:
         if not set(self.parquet_file.schema_arrow.names) <= set(columns):
@@ -151,18 +170,7 @@ def read(self, columns: list[str]) -> pa.Table:
             )
         pa_table = self.parquet_file.read_row_group(i=self.group_id, columns=columns)
         # cast_table_to_schema adds null values to missing columns
-        return cast_table_to_schema(pa_table, self.features.arrow_schema)
-
-    def read_truncated_binary(self, columns: list[str], max_binary_length: int) -> tuple[pa.Table, list[str]]:
-        pa_table = self.parquet_file.read_row_group(i=self.group_id, columns=columns)
-        truncated_columns: list[str] = []
-        if max_binary_length:
-            for field_idx, field in enumerate(pa_table.schema):
-                if self.features[field.name] == Value("binary") and pa_table[field_idx].nbytes > max_binary_length:
-                    truncated_array = pc.binary_slice(pa_table[field_idx], 0, max_binary_length // len(pa_table))
-                    pa_table = pa_table.set_column(field_idx, field, truncated_array)
-                    truncated_columns.append(field.name)
-        return cast_table_to_schema(pa_table, self.features.arrow_schema), truncated_columns
+        return cast_table_to_schema(pa_table, self.schema)
 
     def read_size(self, columns: Optional[Iterable[str]] = None) -> int:
         if columns is None:
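For reference, here is a minimal usage sketch of the new `truncate_binary_columns` helper added above. It is not part of the diff: it assumes the helper is importable in the current scope, and the table and column names are invented for illustration.

```python
import pyarrow as pa
from datasets import Features, Value

# truncate_binary_columns is assumed to be in scope (defined in the module this PR edits)

# hypothetical table: one binary column, one string column
table = pa.table(
    {
        "blob": pa.array([b"x" * 1000, b"y" * 1000], type=pa.binary()),
        "label": pa.array(["a", "b"]),
    }
)
features = Features({"blob": Value("binary"), "label": Value("string")})

# a 100-byte budget for the "blob" column: each of the 2 rows is sliced to
# max_binary_length // len(table) = 50 bytes; "label" is left untouched
truncated_table, truncated_columns = truncate_binary_columns(table, max_binary_length=100, features=features)
assert truncated_columns == ["blob"]
```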
@@ -179,32 +187,33 @@ def read_size(self, columns: Optional[Iterable[str]] = None) -> int:
 
 @dataclass
 class ParquetIndexWithMetadata:
+    files: list[ParquetFileMetadataItem]
     features: Features
-    parquet_files_urls: list[str]
-    metadata_paths: list[str]
-    num_bytes: list[int]
-    num_rows: list[int]
     httpfs: HTTPFileSystem
     max_arrow_data_in_memory: int
     partial: bool
+    metadata_dir: Path
 
+    file_offsets: np.ndarray = field(init=False)
     num_rows_total: int = field(init=False)
 
     def __post_init__(self) -> None:
         if self.httpfs._session is None:
             self.httpfs_session = asyncio.run(self.httpfs.set_session())
         else:
             self.httpfs_session = self.httpfs._session
-        self.num_rows_total = sum(self.num_rows)
 
-    def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
+        num_rows = np.array([f["num_rows"] for f in self.files])
+        self.file_offsets = np.cumsum(num_rows)
+        self.num_rows_total = np.sum(num_rows)
+
+    def query(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
         """Query the parquet files
 
         Note that this implementation will always read at least one row group, to get the list of columns and always
         have the same schema, even if the requested rows are invalid (out of range).
 
-        This is the same as query() except that:
-
+        If binary columns are present, then:
         - it computes a maximum size to allocate to binary data in step "parquet_index_with_metadata.row_groups_size_check_truncated_binary"
         - it uses `read_truncated_binary()` in step "parquet_index_with_metadata.query_truncated_binary".
 
@@ -219,27 +228,19 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, li
             `pa.Table`: The requested rows.
             `list[str]`: List of truncated columns.
         """
-        all_columns = set(self.features)
-        binary_columns = set(column for column, feature in self.features.items() if feature == Value("binary"))
-        if not binary_columns:
-            return self.query(offset=offset, length=length), []
         with StepProfiler(
             method="parquet_index_with_metadata.query", step="get the parquet files than contain the requested rows"
         ):
-            parquet_file_offsets = np.cumsum(self.num_rows)
-
-            last_row_in_parquet = parquet_file_offsets[-1] - 1
+            last_row_in_parquet = self.file_offsets[-1] - 1
             first_row = min(offset, last_row_in_parquet)
             last_row = min(offset + length - 1, last_row_in_parquet)
             first_parquet_file_id, last_parquet_file_id = np.searchsorted(
-                parquet_file_offsets, [first_row, last_row], side="right"
+                self.file_offsets, [first_row, last_row], side="right"
             )
             parquet_offset = (
-                offset - parquet_file_offsets[first_parquet_file_id - 1] if first_parquet_file_id > 0 else offset
+                offset - self.file_offsets[first_parquet_file_id - 1] if first_parquet_file_id > 0 else offset
             )
-            urls = self.parquet_files_urls[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
-            metadata_paths = self.metadata_paths[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
-            num_bytes = self.num_bytes[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
+            files_to_scan = self.files[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
 
         with StepProfiler(
             method="parquet_index_with_metadata.query", step="load the remote parquet files using metadata from disk"
@@ -248,17 +249,17 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, li
                 pq.ParquetFile(
                     HTTPFile(
                         self.httpfs,
-                        url,
+                        f["url"],
                         session=self.httpfs_session,
-                        size=size,
+                        size=f["size"],
                         loop=self.httpfs.loop,
                         cache_type=None,
                         **self.httpfs.kwargs,
                     ),
-                    metadata=pq.read_metadata(metadata_path),
+                    metadata=pq.read_metadata(self.metadata_dir / f["parquet_metadata_subpath"]),
                     pre_buffer=True,
                 )
-                for url, metadata_path, size in zip(urls, metadata_paths, num_bytes)
+                for f in files_to_scan
             ]
 
         with StepProfiler(
@@ -272,7 +273,7 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, li
                 ]
             )
             row_group_readers = [
-                RowGroupReader(parquet_file=parquet_file, group_id=group_id, features=self.features)
+                RowGroupReader(parquet_file=parquet_file, group_id=group_id, schema=self.features.arrow_schema)
                 for parquet_file in parquet_files
                 for group_id in range(parquet_file.metadata.num_row_groups)
             ]
@@ -290,6 +291,28 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, li
290291 row_group_offsets , [first_row , last_row ], side = "right"
291292 )
292293
294+ all_columns = set (self .features )
295+ binary_columns = set (column for column , feature in self .features .items () if feature == Value ("binary" ))
296+ if binary_columns :
297+ pa_table , truncated_columns = self ._read_with_binary (
298+ row_group_readers , first_row_group_id , last_row_group_id , all_columns , binary_columns
299+ )
300+ else :
301+ pa_table , truncated_columns = self ._read_without_binary (
302+ row_group_readers , first_row_group_id , last_row_group_id
303+ )
304+
305+ first_row_in_pa_table = row_group_offsets [first_row_group_id - 1 ] if first_row_group_id > 0 else 0
306+ return pa_table .slice (parquet_offset - first_row_in_pa_table , length ), truncated_columns
307+
308+ def _read_with_binary (
309+ self ,
310+ row_group_readers : list [RowGroupReader ],
311+ first_row_group_id : int ,
312+ last_row_group_id : int ,
313+ all_columns : set [str ],
314+ binary_columns : set [str ],
315+ ) -> tuple [pa .Table , list [str ]]:
293316 with StepProfiler (
294317 method = "parquet_index_with_metadata.row_groups_size_check_truncated_binary" ,
295318 step = "check if the rows can fit in memory" ,
@@ -329,100 +352,21 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, li
                 columns = list(self.features.keys())
                 truncated_columns: set[str] = set()
                 for i in range(first_row_group_id, last_row_group_id + 1):
-                    rg_pa_table, rg_truncated_columns = row_group_readers[i].read_truncated_binary(
-                        columns, max_binary_length=max_binary_length
+                    rg_pa_table = row_group_readers[i].read(columns)
+                    rg_pa_table, rg_truncated_columns = truncate_binary_columns(
+                        rg_pa_table, max_binary_length, self.features
                     )
                     pa_tables.append(rg_pa_table)
                     truncated_columns |= set(rg_truncated_columns)
                 pa_table = pa.concat_tables(pa_tables)
             except ArrowInvalid as err:
                 raise SchemaMismatchError("Parquet files have different schema.", err)
-        first_row_in_pa_table = row_group_offsets[first_row_group_id - 1] if first_row_group_id > 0 else 0
-        return pa_table.slice(parquet_offset - first_row_in_pa_table, length), list(truncated_columns)
-
-    def query(self, offset: int, length: int) -> pa.Table:
-        """Query the parquet files
-
-        Note that this implementation will always read at least one row group, to get the list of columns and always
-        have the same schema, even if the requested rows are invalid (out of range).
-
-        Args:
-            offset (`int`): The first row to read.
-            length (`int`): The number of rows to read.
-
-        Raises:
-            [`TooBigRows`]: if the arrow data from the parquet row groups is bigger than max_arrow_data_in_memory
-
-        Returns:
-            `pa.Table`: The requested rows.
-        """
-        with StepProfiler(
-            method="parquet_index_with_metadata.query", step="get the parquet files than contain the requested rows"
-        ):
-            parquet_file_offsets = np.cumsum(self.num_rows)
 
-            last_row_in_parquet = parquet_file_offsets[-1] - 1
-            first_row = min(offset, last_row_in_parquet)
-            last_row = min(offset + length - 1, last_row_in_parquet)
-            first_parquet_file_id, last_parquet_file_id = np.searchsorted(
-                parquet_file_offsets, [first_row, last_row], side="right"
-            )
-            parquet_offset = (
-                offset - parquet_file_offsets[first_parquet_file_id - 1] if first_parquet_file_id > 0 else offset
-            )
-            urls = self.parquet_files_urls[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
-            metadata_paths = self.metadata_paths[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
-            num_bytes = self.num_bytes[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
-
-        with StepProfiler(
-            method="parquet_index_with_metadata.query", step="load the remote parquet files using metadata from disk"
-        ):
-            parquet_files = [
-                pq.ParquetFile(
-                    HTTPFile(
-                        self.httpfs,
-                        url,
-                        session=self.httpfs_session,
-                        size=size,
-                        loop=self.httpfs.loop,
-                        cache_type=None,
-                        **self.httpfs.kwargs,
-                    ),
-                    metadata=pq.read_metadata(metadata_path),
-                    pre_buffer=True,
-                )
-                for url, metadata_path, size in zip(urls, metadata_paths, num_bytes)
-            ]
-
-        with StepProfiler(
-            method="parquet_index_with_metadata.query", step="get the row groups than contain the requested rows"
-        ):
-            row_group_offsets = np.cumsum(
-                [
-                    parquet_file.metadata.row_group(group_id).num_rows
-                    for parquet_file in parquet_files
-                    for group_id in range(parquet_file.metadata.num_row_groups)
-                ]
-            )
-            row_group_readers = [
-                RowGroupReader(parquet_file=parquet_file, group_id=group_id, features=self.features)
-                for parquet_file in parquet_files
-                for group_id in range(parquet_file.metadata.num_row_groups)
-            ]
-
-            if len(row_group_offsets) == 0 or row_group_offsets[-1] == 0:  # if the dataset is empty
-                if offset < 0:
-                    raise IndexError("Offset must be non-negative")
-                return cast_table_to_schema(parquet_files[0].read(), self.features.arrow_schema)
-
-            last_row_in_parquet = row_group_offsets[-1] - 1
-            first_row = min(parquet_offset, last_row_in_parquet)
-            last_row = min(parquet_offset + length - 1, last_row_in_parquet)
-
-            first_row_group_id, last_row_group_id = np.searchsorted(
-                row_group_offsets, [first_row, last_row], side="right"
-            )
+        return pa_table, list(truncated_columns)
 
+    def _read_without_binary(
+        self, row_group_readers: list[RowGroupReader], first_row_group_id: int, last_row_group_id: int
+    ) -> tuple[pa.Table, list[str]]:
         with StepProfiler(
             method="parquet_index_with_metadata.row_groups_size_check", step="check if the rows can fit in memory"
         ):
@@ -443,8 +387,8 @@ def query(self, offset: int, length: int) -> pa.Table:
                 )
             except ArrowInvalid as err:
                 raise SchemaMismatchError("Parquet files have different schema.", err)
-        first_row_in_pa_table = row_group_offsets[first_row_group_id - 1] if first_row_group_id > 0 else 0
-        return pa_table.slice(parquet_offset - first_row_in_pa_table, length)
+
+        return pa_table, []
 
     @staticmethod
     def from_parquet_metadata_items(
@@ -458,40 +402,31 @@ def from_parquet_metadata_items(
458402 raise EmptyParquetMetadataError ("No parquet files found." )
459403
460404 partial = parquet_export_is_partial (parquet_file_metadata_items [0 ]["url" ])
405+ metadata_dir = Path (parquet_metadata_directory )
461406
462407 with StepProfiler (
463408 method = "parquet_index_with_metadata.from_parquet_metadata_items" ,
464409 step = "get the index from parquet metadata" ,
465410 ):
466411 try :
467- parquet_files_metadata = sorted (
468- parquet_file_metadata_items , key = lambda parquet_file_metadata : parquet_file_metadata ["filename" ]
469- )
470- parquet_files_urls = [parquet_file_metadata ["url" ] for parquet_file_metadata in parquet_files_metadata ]
471- metadata_paths = [
472- os .path .join (parquet_metadata_directory , parquet_file_metadata ["parquet_metadata_subpath" ])
473- for parquet_file_metadata in parquet_files_metadata
474- ]
475- num_bytes = [parquet_file_metadata ["size" ] for parquet_file_metadata in parquet_files_metadata ]
476- num_rows = [parquet_file_metadata ["num_rows" ] for parquet_file_metadata in parquet_files_metadata ]
412+ files = sorted (parquet_file_metadata_items , key = lambda f : f ["filename" ])
477413 except Exception as e :
478414 raise ParquetResponseFormatError (f"Could not parse the list of parquet files: { e } " ) from e
479415
480416 with StepProfiler (
481417 method = "parquet_index_with_metadata.from_parquet_metadata_items" , step = "get the dataset's features"
482418 ):
483419 if features is None : # config-parquet version<6 didn't have features
484- features = Features .from_arrow_schema (pq .read_schema (metadata_paths [0 ]))
420+ first_arrow_schema = pq .read_schema (metadata_dir / files [0 ]["parquet_metadata_subpath" ])
421+ features = Features .from_arrow_schema (first_arrow_schema )
485422
486423 return ParquetIndexWithMetadata (
424+ files = files ,
487425 features = features ,
488- parquet_files_urls = parquet_files_urls ,
489- metadata_paths = metadata_paths ,
490- num_bytes = num_bytes ,
491- num_rows = num_rows ,
492426 httpfs = httpfs ,
493427 max_arrow_data_in_memory = max_arrow_data_in_memory ,
494428 partial = partial ,
429+ metadata_dir = metadata_dir ,
495430 )
496431
497432
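To make the new `files` field concrete: each entry is a `ParquetFileMetadataItem` mapping, and this diff only relies on the keys shown below. The item here is hypothetical, with invented values; the real type may carry additional fields.

```python
# hypothetical entry of ParquetIndexWithMetadata.files, limited to the keys this diff reads
example_item = {
    "filename": "0000.parquet",                 # sort key in from_parquet_metadata_items
    "url": "https://example.com/0000.parquet",  # opened through HTTPFile
    "size": 123_456_789,                        # passed to HTTPFile as the remote file size
    "num_rows": 1_000_000,                      # feeds file_offsets and num_rows_total
    "parquet_metadata_subpath": "default/train/0000.parquet",  # joined with metadata_dir
}
```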
@@ -551,28 +486,7 @@ def _init_parquet_index(
 
     # note that this cache size is global for the class, not per instance
     @lru_cache(maxsize=1)
-    def query(self, offset: int, length: int) -> pa.Table:
-        """Query the parquet files
-
-        Note that this implementation will always read at least one row group, to get the list of columns and always
-        have the same schema, even if the requested rows are invalid (out of range).
-
-        Args:
-            offset (`int`): The first row to read.
-            length (`int`): The number of rows to read.
-
-        Returns:
-            `pa.Table`: The requested rows.
-        """
-        logging.info(
-            f"Query {type(self.parquet_index).__name__} for dataset={self.dataset}, config={self.config},"
-            f" split={self.split}, offset={offset}, length={length}"
-        )
-        return self.parquet_index.query(offset=offset, length=length)
-
-    # note that this cache size is global for the class, not per instance
-    @lru_cache(maxsize=1)
-    def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
+    def query(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
         """Query the parquet files
 
         Note that this implementation will always read at least one row group, to get the list of columns and always
@@ -590,4 +504,4 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, li
590504 f"Query { type (self .parquet_index ).__name__ } for dataset={ self .dataset } , config={ self .config } ,"
591505 f" split={ self .split } , offset={ offset } , length={ length } , with truncated binary"
592506 )
593- return self .parquet_index .query_truncated_binary (offset = offset , length = length )
507+ return self .parquet_index .query (offset = offset , length = length )
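With `query_truncated_binary` gone, callers go through the single `query()` entry point, which always returns the table together with the (possibly empty) list of truncated binary columns. A hedged before/after sketch of a call site follows; the `rows_index` object is assumed to be an already-built `RowsIndex` and is not constructed in this diff.

```python
# before this change, callers picked one of two methods with different return types:
#   pa_table = rows_index.query(offset=0, length=100)
#   pa_table, truncated_columns = rows_index.query_truncated_binary(offset=0, length=100)

# after this change there is a single method and a single return shape:
pa_table, truncated_columns = rows_index.query(offset=0, length=100)
if truncated_columns:
    print(f"binary columns truncated to fit in memory: {truncated_columns}")
```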