Skip to content

Commit 896616c

Browse files
fix(webdataset): don't .lower() field_name (#7726)
* wds: lower everywhere
* better: just use lower for checks
* make style
1 parent b47e71c commit 896616c

File tree

2 files changed

+74
-14
lines changed

2 files changed

+74
-14
lines changed

src/datasets/packaged_modules/webdataset/webdataset.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,16 @@ def _get_pipeline_from_tar(cls, tar_path, tar_iterator):
4141
current_example = {}
4242
current_example["__key__"] = example_key
4343
current_example["__url__"] = tar_path
44-
current_example[field_name.lower()] = f.read()
45-
if field_name.split(".")[-1] in SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL:
46-
fs.write_bytes(filename, current_example[field_name.lower()])
44+
current_example[field_name] = f.read()
45+
if field_name.split(".")[-1].lower() in SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL:
46+
fs.write_bytes(filename, current_example[field_name])
4747
extracted_file_path = streaming_download_manager.extract(f"memory://{filename}")
4848
with fsspec.open(extracted_file_path) as f:
49-
current_example[field_name.lower()] = f.read()
49+
current_example[field_name] = f.read()
5050
fs.delete(filename)
51-
data_extension = xbasename(extracted_file_path).split(".")[-1]
51+
data_extension = xbasename(extracted_file_path).split(".")[-1].lower()
5252
else:
53-
data_extension = field_name.split(".")[-1]
53+
data_extension = field_name.split(".")[-1].lower()
5454
if data_extension in cls.DECODERS:
5555
current_example[field_name] = cls.DECODERS[data_extension](current_example[field_name])
5656
if current_example:
@@ -91,19 +91,15 @@ def _split_generators(self, dl_manager):
9191
inferred_arrow_schema = pa.concat_tables(pa_tables, promote_options="default").schema
9292
features = datasets.Features.from_arrow_schema(inferred_arrow_schema)
9393

94-
# Set Image types
9594
for field_name in first_examples[0]:
96-
extension = field_name.rsplit(".", 1)[-1]
95+
extension = field_name.rsplit(".", 1)[-1].lower()
96+
# Set Image types
9797
if extension in self.IMAGE_EXTENSIONS:
9898
features[field_name] = datasets.Image()
99-
# Set Audio types
100-
for field_name in first_examples[0]:
101-
extension = field_name.rsplit(".", 1)[-1]
99+
# Set Audio types
102100
if extension in self.AUDIO_EXTENSIONS:
103101
features[field_name] = datasets.Audio()
104-
# Set Video types
105-
for field_name in first_examples[0]:
106-
extension = field_name.rsplit(".", 1)[-1]
102+
# Set Video types
107103
if extension in self.VIDEO_EXTENSIONS:
108104
features[field_name] = datasets.Video()
109105
self.info.features = features

tests/packaged_modules/test_webdataset.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,27 @@ def image_wds_file(tmp_path, image_file):
3939
return str(filename)
4040

4141

42+
@pytest.fixture
43+
def upper_lower_case_file(tmp_path):
44+
tar_path = tmp_path / "file.tar"
45+
num_examples = 3
46+
variants = [
47+
("INFO1", "json"),
48+
("info2", "json"),
49+
("info3", "JSON"),
50+
("info3", "json"), # should probably remove if testing on a case insensitive filesystem
51+
]
52+
with tarfile.open(tar_path, "w") as tar:
53+
for example_idx in range(num_examples):
54+
example_name = f"{example_idx:05d}_{'a' if example_idx % 2 else 'A'}"
55+
for tag, ext in variants:
56+
caption_path = tmp_path / f"{example_name}.{tag}.{ext}"
57+
caption_text = {"caption": f"caption for {example_name}.{tag}.{ext}"}
58+
caption_path.write_text(json.dumps(caption_text), encoding="utf-8")
59+
tar.add(caption_path, arcname=f"{example_name}.{tag}.{ext}")
60+
return str(tar_path)
61+
62+
4263
@pytest.fixture
4364
def audio_wds_file(tmp_path, audio_file):
4465
json_file = tmp_path / "data.json"
@@ -133,6 +154,49 @@ def test_image_webdataset(image_wds_file):
133154
assert isinstance(decoded["jpg"], PIL.Image.Image)
134155

135156

157+
def test_upper_lower_case(upper_lower_case_file):
158+
variants = [
159+
("INFO1", "json"),
160+
("info2", "json"),
161+
("info3", "JSON"),
162+
("info3", "json"),
163+
]
164+
165+
data_files = {"train": [upper_lower_case_file]}
166+
webdataset = WebDataset(data_files=data_files)
167+
split_generators = webdataset._split_generators(DownloadManager())
168+
169+
variant_keys = [f"{tag}.{ext}" for tag, ext in variants]
170+
assert webdataset.info.features == Features(
171+
{
172+
"__key__": Value("string"),
173+
"__url__": Value("string"),
174+
**{k: {"caption": Value("string")} for k in variant_keys},
175+
}
176+
)
177+
178+
assert len(split_generators) == 1
179+
split_generator = split_generators[0]
180+
assert split_generator.name == "train"
181+
generator = webdataset._generate_examples(**split_generator.gen_kwargs)
182+
_, examples = zip(*generator)
183+
184+
assert len(examples) == 3
185+
for example_idx, example in enumerate(examples):
186+
example_name = example["__key__"]
187+
expected_example_name = f"{example_idx:05d}_{'a' if example_idx % 2 else 'A'}"
188+
189+
assert example_name == expected_example_name
190+
for key in variant_keys:
191+
assert isinstance(example[key], dict)
192+
assert example[key]["caption"] == f"caption for {example_name}.{key}"
193+
194+
encoded = webdataset.info.features.encode_example(example)
195+
decoded = webdataset.info.features.decode_example(encoded)
196+
for key in variant_keys:
197+
assert decoded[key]["caption"] == example[key]["caption"]
198+
199+
136200
@require_pil
137201
def test_image_webdataset_missing_keys(image_wds_file):
138202
import PIL.Image

0 commit comments

Comments
 (0)