@@ -1048,6 +1048,48 @@ def test_load_dataset_with_unsupported_extensions(text_dir_with_unsupported_exte
     assert ds.num_rows == 4
 
 
+def test_load_dataset_specific_splits(data_dir):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        with load_dataset(data_dir, split="train", cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, Dataset)
+            assert len(dataset) > 0
+
+        processed_dataset_dir = load_dataset_builder(data_dir, cache_dir=tmp_dir).cache_dir
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files)
+
+        with load_dataset(data_dir, split="test", cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, Dataset)
+            assert len(dataset) > 0
+
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith(("train", "test")) for arrow_file in arrow_files)
+
+        with pytest.raises(ValueError):
+            load_dataset(data_dir, split="non-existing-split", cache_dir=tmp_dir)
+
+
+def test_load_dataset_specific_splits_then_full(data_dir):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        with load_dataset(data_dir, split="train", cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, Dataset)
+            assert len(dataset) > 0
+
+        processed_dataset_dir = load_dataset_builder(data_dir, cache_dir=tmp_dir).cache_dir
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files)
+
+        with load_dataset(data_dir, cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, DatasetDict)
+            assert len(dataset) > 0
+            assert "train" in dataset
+            assert "test" in dataset
+            dataset_splits = list(dataset)
+
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith(tuple(dataset_splits)) for arrow_file in arrow_files)
+
+
 @pytest.mark.integration
 def test_loading_from_the_datasets_hub():
     with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1449,6 +1491,28 @@ def test_loading_from_the_datasets_hub():
     assert len(dataset["validation"]) == 3
 
 
+@pytest.mark.integration
+def test_loading_from_dataset_from_hub_specific_splits():
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        with load_dataset(SAMPLE_DATASET_IDENTIFIER2, split="train", cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, Dataset)
+            assert len(dataset) > 0
+
+        processed_dataset_dir = load_dataset_builder(SAMPLE_DATASET_IDENTIFIER2, cache_dir=tmp_dir).cache_dir
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files)
+
+        with load_dataset(SAMPLE_DATASET_IDENTIFIER2, split="test", cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, Dataset)
+            assert len(dataset) > 0
+
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith(("train", "test")) for arrow_file in arrow_files)
+
+        with pytest.raises(ValueError):
+            load_dataset(SAMPLE_DATASET_IDENTIFIER2, split="non-existing-split", cache_dir=tmp_dir)
+
+
 @pytest.mark.integration
 def test_loading_from_the_datasets_hub_with_token():
     true_request = requests.Session().request
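Note: a minimal sketch of the caching behavior the new tests assert, runnable outside the test suite. It uses only public `datasets` APIs (`load_dataset`, `load_dataset_builder`, and the builder's `cache_dir` attribute); the local data directory path is a hypothetical stand-in, not something from the diff above.

import tempfile
from pathlib import Path

from datasets import Dataset, DatasetDict, load_dataset, load_dataset_builder

data_dir = "path/to/local/data"  # hypothetical folder containing train/test files

with tempfile.TemporaryDirectory() as tmp_dir:
    # Requesting a single split returns a Dataset, and only that split
    # should be materialized as Arrow files in the cache.
    train_ds = load_dataset(data_dir, split="train", cache_dir=tmp_dir)
    assert isinstance(train_ds, Dataset)

    cache = load_dataset_builder(data_dir, cache_dir=tmp_dir).cache_dir
    print(sorted(f.name for f in Path(cache).glob("*.arrow")))  # train shard(s) only

    # Omitting `split` returns a DatasetDict with every available split,
    # after which the cache holds Arrow files for all of them.
    full = load_dataset(data_dir, cache_dir=tmp_dir)
    assert isinstance(full, DatasetDict)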