@@ -1048,6 +1048,77 @@ def test_load_dataset_with_unsupported_extensions(text_dir_with_unsupported_extension):
     assert ds.num_rows == 4
 
 
+def test_load_dataset_specific_splits(data_dir):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        with load_dataset(data_dir, split="train", cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, Dataset)
+            assert len(dataset) > 0
+
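+        # Arrow files in the builder's cache dir are named after the split they
+        # contain, so only "train" files should have been written at this point.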
+        processed_dataset_dir = load_dataset_builder(data_dir, cache_dir=tmp_dir).cache_dir
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files)
+
+        with load_dataset(data_dir, split="test", cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, Dataset)
+            assert len(dataset) > 0
+
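+        # After also requesting "test", the cache holds files for both splits.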
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith(("train", "test")) for arrow_file in arrow_files)
+
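+        # Requesting a split the dataset does not define raises a ValueError.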
+        with pytest.raises(ValueError):
+            load_dataset(data_dir, split="non-existing-split", cache_dir=tmp_dir)
+
+
+def test_load_dataset_specific_splits_then_full(data_dir):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        with load_dataset(data_dir, split="train", cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, Dataset)
+            assert len(dataset) > 0
+
+        processed_dataset_dir = load_dataset_builder(data_dir, cache_dir=tmp_dir).cache_dir
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files)
+
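+        # Loading without a `split` argument prepares the remaining splits and
+        # returns all of them as a DatasetDict.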
+        with load_dataset(data_dir, cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, DatasetDict)
+            assert len(dataset) > 0
+            assert "train" in dataset
+            assert "test" in dataset
+            dataset_splits = list(dataset)
+
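+        # The cache should now contain arrow files for every split of the dataset.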
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith(tuple(dataset_splits)) for arrow_file in arrow_files)
+
+
 @pytest.mark.integration
 def test_loading_from_the_datasets_hub():
     with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1449,6 +1491,28 @@ def test_loading_from_the_datasets_hub():
             assert len(dataset["validation"]) == 3
 
 
+@pytest.mark.integration
+def test_loading_from_dataset_from_hub_specific_splits():
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        with load_dataset(SAMPLE_DATASET_IDENTIFIER2, split="train", cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, Dataset)
+            assert len(dataset) > 0
+
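+        # Same check as the local variant: only the "train" split is cached so far.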
+        processed_dataset_dir = load_dataset_builder(SAMPLE_DATASET_IDENTIFIER2, cache_dir=tmp_dir).cache_dir
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files)
+
+        with load_dataset(SAMPLE_DATASET_IDENTIFIER2, split="test", cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, Dataset)
+            assert len(dataset) > 0
+
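+        # Both splits of the Hub dataset should now be cached.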
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith(("train", "test")) for arrow_file in arrow_files)
+
+        with pytest.raises(ValueError):
+            load_dataset(SAMPLE_DATASET_IDENTIFIER2, split="non-existing-split", cache_dir=tmp_dir)
+
+
 @pytest.mark.integration
 def test_loading_from_the_datasets_hub_with_token():
     true_request = requests.Session().request