
Commit 3622bbd

Added datasets related to issue #6832
1 parent b36f257 commit 3622bbd

8 files changed: +73 additions, -17 deletions

src/datasets/arrow_reader.py

Lines changed: 1 addition & 0 deletions
@@ -325,6 +325,7 @@ def read_table(filename, in_memory=False) -> Table:
     Returns:
         pyarrow.Table
     """
+    os.makedirs(os.path.dirname(filename), exist_ok=True)
     table_cls = InMemoryTable if in_memory else MemoryMappedTable
     return table_cls.from_file(filename)

src/datasets/builder.py

Lines changed: 51 additions & 8 deletions
@@ -19,7 +19,9 @@
 import contextlib
 import copy
 import fnmatch
+import hashlib
 import inspect
+import json
 import os
 import posixpath
 import shutil
@@ -89,14 +91,26 @@
 from .utils.sharding import _number_of_shards_in_gen_kwargs, _split_gen_kwargs
 from .utils.track import tracked_list

-
 if TYPE_CHECKING:
     from .load import DatasetModule


 logger = logging.get_logger(__name__)


+def hash_dict(d):
+    """Hash a dictionary into a short hex string (8 characters)."""
+    def sanitize(obj):
+        if isinstance(obj, dict):
+            return {str(k): sanitize(v) for k, v in obj.items()}
+        elif isinstance(obj, (list, tuple)):
+            return [sanitize(i) for i in obj]
+        else:
+            return str(obj)
+    normalized = json.dumps(sanitize(d), sort_keys=True)
+    return hashlib.sha256(normalized.encode("utf-8")).hexdigest()[:8]
+
+
 class InvalidConfigName(ValueError):
     pass
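Note: the helper above produces an 8-character hex id that is stable under key reordering and list-vs-tuple differences, since keys are sorted and every leaf value is stringified. A minimal, self-contained sketch (a standalone copy of the function with made-up inputs):

    import hashlib
    import json

    def hash_dict(d):
        """Hash a dictionary into a short hex string (8 characters)."""
        def sanitize(obj):
            if isinstance(obj, dict):
                return {str(k): sanitize(v) for k, v in obj.items()}
            elif isinstance(obj, (list, tuple)):
                return [sanitize(i) for i in obj]
            else:
                return str(obj)
        normalized = json.dumps(sanitize(d), sort_keys=True)
        return hashlib.sha256(normalized.encode("utf-8")).hexdigest()[:8]

    # Key order and list-vs-tuple differences do not change the id.
    print(hash_dict({"sep": ";", "files": ["a.csv"]}) ==
          hash_dict({"files": ("a.csv",), "sep": ";"}))  # True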

@@ -391,7 +405,7 @@ def __init__(
         if not is_remote_url(self._cache_dir_root):
             os.makedirs(self._cache_dir_root, exist_ok=True)
         lock_path = os.path.join(
-            self._cache_dir_root, Path(self._cache_dir).as_posix().replace("/", "_") + ".lock"
+            self._cache_dir_root, Path(self._relative_data_dir()).as_posix().replace("/", "_") + ".lock"
         )
         with FileLock(lock_path):
             if os.path.exists(self._cache_dir):  # check if data exist
@@ -577,11 +591,27 @@ def _create_builder_config(
                 download_config=DownloadConfig(token=self.token, storage_options=self.storage_options),
             )

-        # compute the config id that is going to be used for caching
+        runtime_only_config_keys = {"drop_metadata", "drop_labels", "drop_audio", "drop_text", "drop_images"}
+        hashable_config_kwargs = {k: v for k, v in config_kwargs.items() if k not in runtime_only_config_keys}
+
         config_id = builder_config.create_config_id(
-            config_kwargs,
+            hashable_config_kwargs,
             custom_features=custom_features,
         )
+
+        if (
+            builder_config.name in self.builder_configs
+            and builder_config != self.builder_configs[builder_config.name]
+        ):
+            builder_config.name = f"custom-{hash_dict(hashable_config_kwargs)}"
+            while builder_config.name in self.builder_configs:
+                builder_config.name += "-x"
+            config_id = builder_config.create_config_id(
+                hashable_config_kwargs,
+                custom_features=custom_features,
+            )
+            logger.info(f"Renamed conflicting config to: {builder_config.name}")
+
         is_custom = (config_id not in self.builder_configs) and config_id != "default"
         if is_custom:
             logger.info(f"Using custom data configuration {config_id}")
@@ -1659,15 +1689,19 @@ def _prepare_split_single(
         shard_id = 0
         num_examples_progress_update = 0
         try:
+            path = fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}")
+            logger.debug("Creating directory: %s", os.path.dirname(path))
+            os.makedirs(os.path.dirname(path), exist_ok=True)
             writer = writer_class(
                 features=self.info.features,
-                path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
+                path=path,
                 writer_batch_size=self._writer_batch_size,
                 hash_salt=split_info.name,
                 check_duplicates=check_duplicate_keys,
                 storage_options=self._fs.storage_options,
                 embed_local_files=embed_local_files,
             )
+
             try:
                 _time = time.time()
                 for key, record in generator:
@@ -1678,9 +1712,12 @@ def _prepare_split_single(
                         total_num_examples += num_examples
                         total_num_bytes += num_bytes
                         shard_id += 1
+                        path = fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}")
+                        logger.debug("Creating directory: %s", os.path.dirname(path))
+                        os.makedirs(os.path.dirname(path), exist_ok=True)
                         writer = writer_class(
                             features=writer._features,
-                            path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
+                            path=path,
                             writer_batch_size=self._writer_batch_size,
                             hash_salt=split_info.name,
                             check_duplicates=check_duplicate_keys,
@@ -1908,9 +1945,12 @@ def _prepare_split_single(
         shard_id = 0
         num_examples_progress_update = 0
         try:
+            path = fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}")
+            logger.debug("Creating directory: %s", os.path.dirname(path))
+            os.makedirs(os.path.dirname(path), exist_ok=True)
             writer = writer_class(
                 features=self.info.features,
-                path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
+                path=path,
                 writer_batch_size=self._writer_batch_size,
                 storage_options=self._fs.storage_options,
                 embed_local_files=embed_local_files,
@@ -1925,9 +1965,12 @@ def _prepare_split_single(
                         total_num_examples += num_examples
                         total_num_bytes += num_bytes
                         shard_id += 1
+                        path = fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}")
+                        logger.debug("Creating directory: %s", os.path.dirname(path))
+                        os.makedirs(os.path.dirname(path), exist_ok=True)
                         writer = writer_class(
                             features=writer._features,
-                            path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
+                            path=path,
                             writer_batch_size=self._writer_batch_size,
                             storage_options=self._fs.storage_options,
                             embed_local_files=embed_local_files,
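Note: all four hunks above repeat the same pattern: expand the SSSSS (shard index) and JJJJJ (job index) placeholders in the shard filename template, then create the parent directory before opening the writer. A minimal sketch with a made-up template path:

    import os

    # Illustrative template; the real cache layout and file names come from the builder.
    fpath = "/tmp/demo_cache/demo-train-SSSSS-of-NNNNN_JJJJJ.arrow"
    shard_id, job_id = 3, 0

    path = fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}")
    os.makedirs(os.path.dirname(path), exist_ok=True)  # ensure the directory exists before the writer opens the file
    print(path)  # /tmp/demo_cache/demo-train-00003-of-NNNNN_00000.arrow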

src/datasets/load.py

Lines changed: 5 additions & 2 deletions
@@ -1161,8 +1161,11 @@ def load_dataset_builder(
             error_msg += f'\nFor example `data_files={{"train": "path/to/data/train/*.{example_extensions[0]}"}}`'
         raise ValueError(error_msg)

+    runtime_only_config_keys = {"drop_metadata", "drop_labels", "drop_audio", "drop_text", "drop_images"}
+    hashable_config_kwargs = {k: v for k, v in config_kwargs.items() if k not in runtime_only_config_keys}
+    full_config_kwargs = config_kwargs.copy()
+    config_kwargs_for_config = hashable_config_kwargs.copy()
     builder_cls = get_dataset_builder_class(dataset_module, dataset_name=dataset_name)
-    # Instantiate the dataset builder
     builder_instance: DatasetBuilder = builder_cls(
         cache_dir=cache_dir,
         dataset_name=dataset_name,

@@ -1175,7 +1178,7 @@ def load_dataset_builder(
         token=token,
         storage_options=storage_options,
         **builder_kwargs,
-        **config_kwargs,
+        **full_config_kwargs,
     )
     builder_instance._use_legacy_cache_dir_if_possible(dataset_module)
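Note: the split above separates kwargs that only affect runtime behaviour (the drop_* flags) from those that define the cached dataset, while the builder itself still receives everything. A small sketch with made-up values:

    # Names follow the diff above; the example values are made up.
    runtime_only_config_keys = {"drop_metadata", "drop_labels", "drop_audio", "drop_text", "drop_images"}
    config_kwargs = {"sep": ";", "drop_metadata": True}

    hashable_config_kwargs = {k: v for k, v in config_kwargs.items() if k not in runtime_only_config_keys}
    full_config_kwargs = config_kwargs.copy()

    print(hashable_config_kwargs)  # {'sep': ';'}  -> used to compute the config id
    print(full_config_kwargs)      # {'sep': ';', 'drop_metadata': True}  -> passed to the builder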

src/datasets/packaged_modules/cache/cache.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 import shutil
 import time
 from pathlib import Path
-from typing import Optional, Union
+from typing import List, Optional, Union

 import pyarrow as pa

src/datasets/packaged_modules/csv/csv.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 import itertools
 from dataclasses import dataclass
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, List, Optional, Union

 import pandas as pd
 import pyarrow as pa

src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py

Lines changed: 10 additions & 2 deletions
@@ -3,7 +3,7 @@
 import itertools
 import os
 from dataclasses import dataclass
-from typing import Any, Callable, Iterator, Optional, Union
+from typing import Any, Callable, Iterator, List, Optional, Union

 import pandas as pd
 import pyarrow as pa

@@ -71,6 +71,7 @@ def _available_splits(self) -> Optional[List[str]]:
         return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None

     def _split_generators(self, dl_manager, splits: Optional[List[str]] = None):
+        data_files = self.config.data_files
         if not self.config.data_files:
             raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
         dl_manager.download_config.extract_on_the_fly = True

@@ -248,7 +249,14 @@ def _set_feature(feature):
         # before building the features
         if self.config.features is None:
             if add_metadata:
-                self.info.features = metadata_features
+                if self.config.drop_metadata and isinstance(metadata_features, dict):
+                    filtered = {
+                        k: v for k, v in metadata_features.items()
+                        if k == self.BASE_COLUMN_NAME  # e.g. "image"
+                    }
+                    self.info.features = datasets.Features(filtered)
+                else:
+                    self.info.features = metadata_features
             elif add_labels:
                 self.info.features = datasets.Features(
                     {
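Note: the last hunk keeps only the base data column when drop_metadata is set. A toy sketch of that filtering, using plain dicts instead of datasets.Features (BASE_COLUMN_NAME would be "image" for an image folder; the feature specs below are made up):

    BASE_COLUMN_NAME = "image"  # base data column of the folder-based builder
    metadata_features = {"image": "Image()", "caption": "Value('string')"}  # illustrative only
    drop_metadata = True

    if drop_metadata:
        features = {k: v for k, v in metadata_features.items() if k == BASE_COLUMN_NAME}
    else:
        features = metadata_features
    print(features)  # {'image': 'Image()'}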

src/datasets/packaged_modules/parquet/parquet.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 import itertools
 from dataclasses import dataclass
-from typing import Optional, Union
+from typing import List, Optional, Union

 import pyarrow as pa
 import pyarrow.dataset as ds

tests/test_load.py

Lines changed: 3 additions & 2 deletions
@@ -49,6 +49,7 @@
 )


+SAMPLE_DATASET_IDENTIFIER = "hf-internal-testing/librispeech_asr_dummy"
 SAMPLE_DATASET_IDENTIFIER2 = "hf-internal-testing/dataset_with_data_files"  # only has data files
 SAMPLE_DATASET_IDENTIFIER3 = "hf-internal-testing/multi_dir_dataset"  # has multiple data directories
 SAMPLE_DATASET_IDENTIFIER4 = "hf-internal-testing/imagefolder_with_metadata"  # imagefolder with a metadata file inside the train/test directories

@@ -1093,8 +1094,8 @@ def test_load_dataset_specific_splits_then_full(data_dir):
 @pytest.mark.integration
 def test_loading_from_the_datasets_hub():
     with tempfile.TemporaryDirectory() as tmp_dir:
-        @@ -1449,6 +1491,28 @@ def test_loading_from_the_datasets_hub():
-        assert len(dataset["validation"]) == 3
+        dataset = load_dataset(SAMPLE_DATASET_IDENTIFIER, cache_dir=tmp_dir)
+        assert len(dataset["validation"]) >= 3


 @pytest.mark.integration
