Commit 709edfe

Update test_load.py
1 parent 761dbb6 commit 709edfe

1 file changed: +71 −0 lines changed

tests/test_load.py

Lines changed: 71 additions & 0 deletions
@@ -1048,6 +1048,77 @@ def test_load_dataset_with_unsupported_extensions(text_dir_with_unsupported_exte
     assert ds.num_rows == 4


+def test_load_dataset_specific_splits(data_dir):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        with load_dataset(data_dir, split="train", cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, Dataset)
+            assert len(dataset) > 0
+
+        processed_dataset_dir = load_dataset_builder(data_dir, cache_dir=tmp_dir).cache_dir
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files)
+
+        with load_dataset(data_dir, split="test", cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, Dataset)
+            assert len(dataset) > 0
+
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith(("train", "test")) for arrow_file in arrow_files)
+
+        with pytest.raises(ValueError):
+            load_dataset(data_dir, split="non-existing-split", cache_dir=tmp_dir)
+
+
+def test_load_dataset_specific_splits_then_full(data_dir):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        with load_dataset(data_dir, split="train", cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, Dataset)
+            assert len(dataset) > 0
+
+        processed_dataset_dir = load_dataset_builder(data_dir, cache_dir=tmp_dir).cache_dir
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files)
+
+        with load_dataset(data_dir, cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, DatasetDict)
+            assert len(dataset) > 0
+            assert "train" in dataset
+            assert "test" in dataset
+            dataset_splits = list(dataset)
+
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith(tuple(dataset_splits)) for arrow_file in arrow_files)
+
+
 @pytest.mark.integration
 def test_loading_from_the_datasets_hub():
     with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1449,6 +1491,28 @@ def test_loading_from_the_datasets_hub():
         assert len(dataset["validation"]) == 3


+@pytest.mark.integration
+def test_loading_from_dataset_from_hub_specific_splits():
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        with load_dataset(SAMPLE_DATASET_IDENTIFIER2, split="train", cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, Dataset)
+            assert len(dataset) > 0
+
+        processed_dataset_dir = load_dataset_builder(SAMPLE_DATASET_IDENTIFIER2, cache_dir=tmp_dir).cache_dir
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files)
+
+        with load_dataset(SAMPLE_DATASET_IDENTIFIER2, split="test", cache_dir=tmp_dir) as dataset:
+            assert isinstance(dataset, Dataset)
+            assert len(dataset) > 0
+
+        arrow_files = Path(processed_dataset_dir).glob("*.arrow")
+        assert all(arrow_file.name.split("-", 1)[1].startswith(("train", "test")) for arrow_file in arrow_files)
+
+        with pytest.raises(ValueError):
+            load_dataset(SAMPLE_DATASET_IDENTIFIER2, split="non-existing-split", cache_dir=tmp_dir)
+
+
 @pytest.mark.integration
 def test_loading_from_the_datasets_hub_with_token():
     true_request = requests.Session().request
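
For readers skimming the diff: the behavior these tests pin down is that load_dataset(..., split=...) materializes Arrow shards only for the requested split, and the shard filenames encode the split name after the first "-" (which is what the name.split("-", 1)[1] checks rely on). A minimal sketch of that pattern outside the test harness, assuming a hypothetical local dataset directory "./my_data":

    # Sketch only (not part of the commit): exercise split-specific caching.
    import tempfile
    from pathlib import Path

    from datasets import Dataset, load_dataset, load_dataset_builder

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Loading a single split returns a Dataset (not a DatasetDict)
        # and should write only that split's shards into the cache.
        ds = load_dataset("./my_data", split="train", cache_dir=tmp_dir)
        assert isinstance(ds, Dataset)

        # The builder exposes the cache directory holding the .arrow shards.
        cache = load_dataset_builder("./my_data", cache_dir=tmp_dir).cache_dir
        shard_names = [p.name for p in Path(cache).glob("*.arrow")]

        # As in the tests, the split name is the part after the first "-"
        # of each shard filename, so only "train" shards should exist here.
        assert all(name.split("-", 1)[1].startswith("train") for name in shard_names)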
