1717from fsspec .implementations .http import HTTPFileSystem
1818
1919from libcommon .parquet_utils import (
20- Indexer ,
2120 ParquetIndexWithMetadata ,
2221 RowsIndex ,
2322 SchemaMismatchError ,
@@ -346,56 +345,25 @@ def dataset_image_with_config_parquet() -> dict[str, Any]:
346345 return config_parquet_content
347346
348347
348+ # TODO(kszucs): this fixture is used in a single test case, but the tests starts
349+ # to fail if I move the index creation there.
349350@pytest .fixture
350351def rows_index_with_parquet_metadata (
351- indexer : Indexer ,
352352 ds_sharded : Dataset ,
353353 ds_sharded_fs : AbstractFileSystem ,
354354 dataset_sharded_with_config_parquet_metadata : dict [str , Any ],
355- ) -> Generator [RowsIndex , None , None ]:
356- with ds_sharded_fs .open ("default/train/0003.parquet" ) as f :
357- with patch ("libcommon.parquet_utils.HTTPFile" , return_value = f ):
358- yield indexer .get_rows_index ("ds_sharded" , "default" , "train" )
359-
360-
361- @pytest .fixture
362- def rows_index_with_empty_dataset (
363- indexer : Indexer ,
364- ds_empty : Dataset ,
365- ds_empty_fs : AbstractFileSystem ,
366- dataset_empty_with_config_parquet_metadata : dict [str , Any ],
367- ) -> Generator [RowsIndex , None , None ]:
368- with ds_empty_fs .open ("default/train/0000.parquet" ) as f :
369- with patch ("libcommon.parquet_utils.HTTPFile" , return_value = f ):
370- yield indexer .get_rows_index ("ds_empty" , "default" , "train" )
371-
372-
373- @pytest .fixture
374- def rows_index_with_too_big_rows (
375355 parquet_metadata_directory : StrPath ,
376- ds_sharded : Dataset ,
377- ds_sharded_fs : AbstractFileSystem ,
378- dataset_sharded_with_config_parquet_metadata : dict [str , Any ],
379356) -> Generator [RowsIndex , None , None ]:
380- indexer = Indexer (
381- parquet_metadata_directory = parquet_metadata_directory ,
382- httpfs = HTTPFileSystem (),
383- max_arrow_data_in_memory = 1 ,
384- )
385357 with ds_sharded_fs .open ("default/train/0003.parquet" ) as f :
386358 with patch ("libcommon.parquet_utils.HTTPFile" , return_value = f ):
387- yield indexer .get_rows_index ("ds_sharded" , "default" , "train" )
388-
389-
390- @pytest .fixture
391- def indexer (
392- parquet_metadata_directory : StrPath ,
393- ) -> Indexer :
394- return Indexer (
395- parquet_metadata_directory = parquet_metadata_directory ,
396- httpfs = HTTPFileSystem (),
397- max_arrow_data_in_memory = 9999999999 ,
398- )
359+ yield RowsIndex (
360+ dataset = "ds_sharded" ,
361+ config = "default" ,
362+ split = "train" ,
363+ parquet_metadata_directory = parquet_metadata_directory ,
364+ httpfs = HTTPFileSystem (),
365+ max_arrow_data_in_memory = 9999999999 ,
366+ )
399367
400368
401369def test_parquet_export_is_partial () -> None :
@@ -411,11 +379,22 @@ def test_parquet_export_is_partial() -> None:
411379
412380
413381def test_indexer_get_rows_index_with_parquet_metadata (
414- indexer : Indexer , ds : Dataset , ds_fs : AbstractFileSystem , dataset_with_config_parquet_metadata : dict [str , Any ]
382+ ds : Dataset ,
383+ ds_fs : AbstractFileSystem ,
384+ parquet_metadata_directory : StrPath ,
385+ dataset_with_config_parquet_metadata : dict [str , Any ],
415386) -> None :
416387 with ds_fs .open ("default/train/0000.parquet" ) as f :
417388 with patch ("libcommon.parquet_utils.HTTPFile" , return_value = f ):
418- index = indexer .get_rows_index ("ds" , "default" , "train" )
389+ index = RowsIndex (
390+ dataset = "ds" ,
391+ config = "default" ,
392+ split = "train" ,
393+ parquet_metadata_directory = parquet_metadata_directory ,
394+ httpfs = HTTPFileSystem (),
395+ max_arrow_data_in_memory = 9999999999 ,
396+ )
397+
419398 assert isinstance (index .parquet_index , ParquetIndexWithMetadata )
420399 assert index .parquet_index .features == ds .features
421400 assert index .parquet_index .num_rows == [len (ds )]
@@ -429,15 +408,23 @@ def test_indexer_get_rows_index_with_parquet_metadata(
429408
430409
431410def test_indexer_get_rows_index_sharded_with_parquet_metadata (
432- indexer : Indexer ,
433411 ds : Dataset ,
434412 ds_sharded : Dataset ,
435413 ds_sharded_fs : AbstractFileSystem ,
414+ parquet_metadata_directory : StrPath ,
436415 dataset_sharded_with_config_parquet_metadata : dict [str , Any ],
437416) -> None :
438417 with ds_sharded_fs .open ("default/train/0003.parquet" ) as f :
439418 with patch ("libcommon.parquet_utils.HTTPFile" , return_value = f ):
440- index = indexer .get_rows_index ("ds_sharded" , "default" , "train" )
419+ index = RowsIndex (
420+ dataset = "ds_sharded" ,
421+ config = "default" ,
422+ split = "train" ,
423+ parquet_metadata_directory = parquet_metadata_directory ,
424+ httpfs = HTTPFileSystem (),
425+ max_arrow_data_in_memory = 9999999999 ,
426+ )
427+
441428 assert isinstance (index .parquet_index , ParquetIndexWithMetadata )
442429 assert index .parquet_index .features == ds_sharded .features
443430 assert index .parquet_index .num_rows == [len (ds )] * 4
@@ -463,28 +450,67 @@ def test_rows_index_query_with_parquet_metadata(
463450 rows_index_with_parquet_metadata .query (offset = - 1 , length = 2 )
464451
465452
466- def test_rows_index_query_with_too_big_rows (rows_index_with_too_big_rows : RowsIndex , ds_sharded : Dataset ) -> None :
453+ def test_rows_index_query_with_too_big_rows (
454+ parquet_metadata_directory : StrPath ,
455+ ds_sharded : Dataset ,
456+ ds_sharded_fs : AbstractFileSystem ,
457+ dataset_sharded_with_config_parquet_metadata : dict [str , Any ],
458+ ) -> None :
459+ with ds_sharded_fs .open ("default/train/0003.parquet" ) as f :
460+ with patch ("libcommon.parquet_utils.HTTPFile" , return_value = f ):
461+ index = RowsIndex (
462+ dataset = "ds_sharded" ,
463+ config = "default" ,
464+ split = "train" ,
465+ parquet_metadata_directory = parquet_metadata_directory ,
466+ httpfs = HTTPFileSystem (),
467+ max_arrow_data_in_memory = 1 ,
468+ )
469+
467470 with pytest .raises (TooBigRows ):
468- rows_index_with_too_big_rows .query (offset = 0 , length = 3 )
471+ index .query (offset = 0 , length = 3 )
469472
470473
471- def test_rows_index_query_with_empty_dataset (rows_index_with_empty_dataset : RowsIndex , ds_sharded : Dataset ) -> None :
472- assert isinstance (rows_index_with_empty_dataset .parquet_index , ParquetIndexWithMetadata )
473- assert rows_index_with_empty_dataset .query (offset = 0 , length = 1 ).to_pydict () == ds_sharded [:0 ]
474+ def test_rows_index_query_with_empty_dataset (
475+ ds_empty : Dataset ,
476+ ds_empty_fs : AbstractFileSystem ,
477+ dataset_empty_with_config_parquet_metadata : dict [str , Any ],
478+ parquet_metadata_directory : StrPath ,
479+ ) -> None :
480+ with ds_empty_fs .open ("default/train/0000.parquet" ) as f :
481+ with patch ("libcommon.parquet_utils.HTTPFile" , return_value = f ):
482+ index = RowsIndex (
483+ dataset = "ds_empty" ,
484+ config = "default" ,
485+ split = "train" ,
486+ parquet_metadata_directory = parquet_metadata_directory ,
487+ httpfs = HTTPFileSystem (),
488+ max_arrow_data_in_memory = 9999999999 ,
489+ )
490+
491+ assert isinstance (index .parquet_index , ParquetIndexWithMetadata )
492+ assert index .query (offset = 0 , length = 1 ).to_pydict () == ds_empty [:0 ]
474493 with pytest .raises (IndexError ):
475- rows_index_with_empty_dataset .query (offset = - 1 , length = 2 )
494+ index .query (offset = - 1 , length = 2 )
476495
477496
478497def test_indexer_schema_mistmatch_error (
479- indexer : Indexer ,
480498 ds_sharded_fs : AbstractFileSystem ,
481499 ds_sharded_fs_with_different_schema : AbstractFileSystem ,
482500 dataset_sharded_with_config_parquet_metadata : dict [str , Any ],
501+ parquet_metadata_directory : StrPath ,
483502) -> None :
484503 with ds_sharded_fs_with_different_schema .open ("default/train/0000.parquet" ) as first_parquet :
485504 with ds_sharded_fs_with_different_schema .open ("default/train/0001.parquet" ) as second_parquet :
486505 with patch ("libcommon.parquet_utils.HTTPFile" , side_effect = [first_parquet , second_parquet ]):
487- index = indexer .get_rows_index ("ds_sharded" , "default" , "train" )
506+ index = RowsIndex (
507+ dataset = "ds_sharded" ,
508+ config = "default" ,
509+ split = "train" ,
510+ parquet_metadata_directory = parquet_metadata_directory ,
511+ httpfs = HTTPFileSystem (),
512+ max_arrow_data_in_memory = 9999999999 ,
513+ )
488514 with pytest .raises (SchemaMismatchError ):
489515 index .query (offset = 0 , length = 3 )
490516
0 commit comments