From e4d5becbebc2f250ca38aa2007e3ff25fad72efb Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Fri, 18 Jul 2025 17:05:40 -0400 Subject: [PATCH 01/16] initial hdf5 support --- src/datasets/packaged_modules/__init__.py | 4 + .../packaged_modules/hdf5/__init__.py | 0 src/datasets/packaged_modules/hdf5/hdf5.py | 192 ++++++++++++++++++ 3 files changed, 196 insertions(+) create mode 100644 src/datasets/packaged_modules/hdf5/__init__.py create mode 100644 src/datasets/packaged_modules/hdf5/hdf5.py diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py index f61c6ddd3de..863d7d31d81 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -8,6 +8,7 @@ from .audiofolder import audiofolder from .cache import cache from .csv import csv +from .hdf5 import hdf5 from .imagefolder import imagefolder from .json import json from .pandas import pandas @@ -47,6 +48,7 @@ def _hash_python_lines(lines: list[str]) -> str: "pdffolder": (pdffolder.__name__, _hash_python_lines(inspect.getsource(pdffolder).splitlines())), "webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())), "xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())), + "hdf5": (hdf5.__name__, _hash_python_lines(inspect.getsource(hdf5).splitlines())), } # get importable module names and hash for caching @@ -85,6 +87,8 @@ def _hash_python_lines(lines: list[str]) -> str: _EXTENSION_TO_MODULE.update({ext.upper(): ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext: ("pdffolder", {}) for ext in pdffolder.PdfFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext.upper(): ("pdffolder", {}) for ext in pdffolder.PdfFolder.EXTENSIONS}) +_EXTENSION_TO_MODULE.update({ext: ("hdf5", {}) for ext in hdf5.EXTENSIONS}) +_EXTENSION_TO_MODULE.update({ext.upper(): ("hdf5", {}) for ext in hdf5.EXTENSIONS}) # Used to filter data files based on extensions given a module name _MODULE_TO_EXTENSIONS: dict[str, list[str]] = {} diff --git a/src/datasets/packaged_modules/hdf5/__init__.py b/src/datasets/packaged_modules/hdf5/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py new file mode 100644 index 00000000000..29b37442bf1 --- /dev/null +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -0,0 +1,192 @@ +import itertools +from dataclasses import dataclass +from typing import Dict, List, Optional + +import h5py +import numpy as np +import pyarrow as pa + +import datasets +from datasets.table import table_cast + + +logger = datasets.utils.logging.get_logger(__name__) + +EXTENSIONS = [".h5", ".hdf5"] + + +@dataclass +class HDF5Config(datasets.BuilderConfig): + """BuilderConfig for HDF5.""" + + batch_size: Optional[int] = None + columns: Optional[List[str]] = None + features: Optional[datasets.Features] = None + + def __post_init__(self): + super().__post_init__() + + +class HDF5(datasets.ArrowBasedBuilder): + """ArrowBasedBuilder that converts HDF5 files to Arrow tables using the HF extension types.""" + + BUILDER_CONFIG_CLASS = HDF5Config + + def _info(self): + if ( + self.config.columns is not None + and self.config.features is not None + and set(self.config.columns) != set(self.config.features) + ): + raise ValueError( + "The columns and features argument must contain the same columns, but got ", + f"{self.config.columns} and {self.config.features}", + ) + 
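[Editor's note] The registration above maps the ".h5"/".hdf5" extensions to the new packaged "hdf5" module, so the builder should be reachable through the usual load_dataset entry point. A minimal end-to-end sketch, assuming this branch of datasets is installed; the file name "toy.h5" and the printed features are illustrative expectations, not verified output:

import h5py
import numpy as np
from datasets import load_dataset

# Write a tiny HDF5 file: one scalar column and one fixed-length vector column.
with h5py.File("toy.h5", "w") as f:
    f.create_dataset("x", data=np.arange(10, dtype=np.int64))
    f.create_dataset("y", data=np.random.rand(10, 3).astype(np.float32))

# "hdf5" is selected explicitly here; the _EXTENSION_TO_MODULE entries above are
# what let .h5 files be auto-detected when loading a dataset folder or repo.
ds = load_dataset("hdf5", data_files={"train": "toy.h5"}, split="train")
print(ds.features)  # expected: x -> Value("int64"), y -> Sequence(Value("float32"), length=3)
print(ds[0])        # expected: {"x": 0, "y": [three float32 values]}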
return datasets.DatasetInfo(features=self.config.features) + + def _split_generators(self, dl_manager): + if not self.config.data_files: + raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") + dl_manager.download_config.extract_on_the_fly = True + data_files = dl_manager.download_and_extract(self.config.data_files) + splits = [] + for split_name, files in data_files.items(): + if isinstance(files, str): + files = [files] + + files = [dl_manager.iter_files(file) for file in files] + # Infer features from first file + if self.info.features is None: + for first_file in itertools.chain.from_iterable(files): + with h5py.File(first_file, "r") as h5: + dataset_map = _traverse_datasets(h5) + features_dict = {} + for path, dset in dataset_map.items(): + feat = _infer_feature_from_dataset(dset) + features_dict[path] = feat + self.info.features = datasets.Features(features_dict) + break + splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) + if self.config.columns is not None and set(self.config.columns) != set(self.info.features): + self.info.features = datasets.Features( + {col: feat for col, feat in self.info.features.items() if col in self.config.columns} + ) + return splits + + def _cast_table(self, pa_table: pa.Table) -> pa.Table: + if self.info.features is not None: + pa_table = table_cast(pa_table, self.info.features.arrow_schema) + return pa_table + + def _generate_tables(self, files): + batch_size_cfg = self.config.batch_size + for file_idx, file in enumerate(itertools.chain.from_iterable(files)): + with h5py.File(file, "r") as h5: + dataset_map = _traverse_datasets(h5) + if not dataset_map: + logger.warning(f"File '{file}' contains no datasets, skipping…") + continue + first_dset = next(iter(dataset_map.values())) + num_rows = first_dset.shape[0] + # Sanity-check lengths + for path, dset in dataset_map.items(): + if dset.shape[0] != num_rows: + raise ValueError( + f"Dataset '{path}' length {dset.shape[0]} differs from {num_rows} in file '{file}'" + ) + effective_batch = batch_size_cfg or self._writer_batch_size or num_rows + for start in range(0, num_rows, effective_batch): + end = min(start + effective_batch, num_rows) + batch_dict = {} + for path, dset in dataset_map.items(): + if self.config.columns is not None and path not in self.config.columns: + continue + arr = dset[start:end] + pa_arr = datasets.features.features.numpy_to_pyarrow_listarray(arr) + batch_dict[path] = pa_arr + pa_table = pa.Table.from_pydict(batch_dict) + yield f"{file_idx}_{start}", self._cast_table(pa_table) + + +def _traverse_datasets(h5_obj, prefix: str = "") -> Dict[str, h5py.Dataset]: + mapping: Dict[str, h5py.Dataset] = {} + for key in h5_obj: + item = h5_obj[key] + sub_path = f"{prefix}{key}" + if isinstance(item, h5py.Dataset): + mapping[sub_path] = item + elif isinstance(item, h5py.Group): + mapping.update(_traverse_datasets(item, prefix=f"{sub_path}/")) + return mapping + + +_DTYPE_TO_DATASETS: Dict[np.dtype, str] = { # FIXME: necessary/check if util exists? 
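[Editor's note] Because _traverse_datasets above walks groups recursively, every HDF5 dataset becomes a top-level column keyed by its full path, with nested groups giving slash-separated names. A small sketch of the resulting column names under the same assumptions as the previous example (output shown as expected, not verified):

import h5py
import numpy as np
from datasets import load_dataset

with h5py.File("nested.h5", "w") as f:
    f.create_dataset("root", data=np.arange(3, dtype=np.int32))
    grp = f.create_group("grp")
    grp.create_dataset("inner", data=np.arange(3, dtype=np.float32))

ds = load_dataset("hdf5", data_files="nested.h5", split="train")
print(ds.column_names)  # expected: ['root', 'grp/inner']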
+ np.dtype("bool").newbyteorder("="): "bool", + np.dtype("int8").newbyteorder("="): "int8", + np.dtype("int16").newbyteorder("="): "int16", + np.dtype("int32").newbyteorder("="): "int32", + np.dtype("int64").newbyteorder("="): "int64", + np.dtype("uint8").newbyteorder("="): "uint8", + np.dtype("uint16").newbyteorder("="): "uint16", + np.dtype("uint32").newbyteorder("="): "uint32", + np.dtype("uint64").newbyteorder("="): "uint64", + np.dtype("float16").newbyteorder("="): "float16", + np.dtype("float32").newbyteorder("="): "float32", + np.dtype("float64").newbyteorder("="): "float64", + # np.dtype("complex64").newbyteorder("="): "complex64", + # np.dtype("complex128").newbyteorder("="): "complex128", +} + + +def _dtype_to_dataset_dtype(dtype: np.dtype) -> str: + """Map NumPy dtype to datasets.Value dtype string, falls back to "binary" for unknown or unsupported dtypes.""" + + # FIXME: endian fix necessary/correct? + base_dtype = dtype.newbyteorder("=") + if base_dtype in _DTYPE_TO_DATASETS: + return _DTYPE_TO_DATASETS[base_dtype] + + if base_dtype.kind in {"S", "a"}: + return "binary" + + # FIXME: seems h5 converts unicode back to bytes? + if base_dtype.kind == "U": + return "binary" + + if base_dtype.kind == "O": + return "binary" + + # FIXME: support varlen? + + return "binary" + + +def _infer_feature_from_dataset(dset: h5py.Dataset): + """Infer a ``datasets.Features`` entry for one HDF5 dataset.""" + + import datasets as hfd + + dtype_str = _dtype_to_dataset_dtype(dset.dtype) + value_shape = dset.shape[1:] + + # Reject ragged datasets (variable-length or zero/None dims) + if dset.dtype.kind == "O" or any(s is None or s == 0 for s in value_shape): + raise ValueError(f"Ragged dataset {dset.name} with shape {value_shape} and dtype {dset.dtype} not supported") + + if dset.dtype.kind not in {"b", "i", "u", "f", "S", "a"}: + raise ValueError(f"Unsupported dtype {dset.dtype} for dataset {dset.name}") + + rank = len(value_shape) + if 2 <= rank <= 5: + from datasets.features import Array2D, Array3D, Array4D, Array5D + + array_cls = [None, None, Array2D, Array3D, Array4D, Array5D][rank] + return array_cls(shape=value_shape, dtype=dtype_str) + + # Fallback to nested Sequence + def _build_feature(shape: tuple[int, ...]): + if len(shape) == 0: + return hfd.Value(dtype_str) + return hfd.Sequence(length=shape[0], feature=_build_feature(shape[1:])) + + return _build_feature(value_shape) From d3ebf931341c1a6426b0bea6948f9719993b8f87 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Fri, 18 Jul 2025 19:16:42 -0400 Subject: [PATCH 02/16] handle zero dims --- src/datasets/packaged_modules/hdf5/hdf5.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 29b37442bf1..7f6e53f28b3 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -7,6 +7,7 @@ import pyarrow as pa import datasets +from datasets.features.features import LargeList, Sequence, _ArrayXD from datasets.table import table_cast @@ -75,7 +76,9 @@ def _split_generators(self, dl_manager): def _cast_table(self, pa_table: pa.Table) -> pa.Table: if self.info.features is not None: - pa_table = table_cast(pa_table, self.info.features.arrow_schema) + has_zero_dims = any(has_zero_dimensions(feature) for feature in self.info.features.values()) + if not has_zero_dims: + pa_table = table_cast(pa_table, self.info.features.arrow_schema) return pa_table def _generate_tables(self, 
files): @@ -169,8 +172,8 @@ def _infer_feature_from_dataset(dset: h5py.Dataset): dtype_str = _dtype_to_dataset_dtype(dset.dtype) value_shape = dset.shape[1:] - # Reject ragged datasets (variable-length or zero/None dims) - if dset.dtype.kind == "O" or any(s is None or s == 0 for s in value_shape): + # Reject ragged datasets (variable-length or None dims) + if dset.dtype.kind == "O" or any(s is None for s in value_shape): raise ValueError(f"Ragged dataset {dset.name} with shape {value_shape} and dtype {dset.dtype} not supported") if dset.dtype.kind not in {"b", "i", "u", "f", "S", "a"}: @@ -190,3 +193,12 @@ def _build_feature(shape: tuple[int, ...]): return hfd.Sequence(length=shape[0], feature=_build_feature(shape[1:])) return _build_feature(value_shape) + + +def has_zero_dimensions(feature: _ArrayXD | Sequence | LargeList): + if isinstance(feature, _ArrayXD): + return any(dim == 0 for dim in feature.shape) + elif isinstance(feature, (Sequence, LargeList)): + return feature.length == 0 or has_zero_dimensions(feature.feature) + else: + return False From be4adb42236aee95b553d0bffb05eef55edf2303 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Fri, 18 Jul 2025 19:17:17 -0400 Subject: [PATCH 03/16] add tests --- tests/packaged_modules/test_hdf5.py | 409 ++++++++++++++++++++++++++++ 1 file changed, 409 insertions(+) create mode 100644 tests/packaged_modules/test_hdf5.py diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py new file mode 100644 index 00000000000..c98d38a3130 --- /dev/null +++ b/tests/packaged_modules/test_hdf5.py @@ -0,0 +1,409 @@ +import h5py +import numpy as np +import pytest + +from datasets import Array2D, Array3D, Array4D, Features, Sequence, Value +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesDict, DataFilesList +from datasets.download.streaming_download_manager import StreamingDownloadManager +from datasets.packaged_modules.hdf5.hdf5 import HDF5, HDF5Config + + +@pytest.fixture +def hdf5_file(tmp_path): + """Create a basic HDF5 file with numeric datasets.""" + filename = tmp_path / "basic.h5" + n_rows = 5 + + with h5py.File(filename, "w") as f: + f.create_dataset("int32", data=np.arange(n_rows, dtype=np.int32)) + f.create_dataset("float32", data=np.arange(n_rows, dtype=np.float32) / 10.0) + f.create_dataset("bool", data=np.array([True, False, True, False, True])) + + return str(filename) + + +@pytest.fixture +def hdf5_file_with_groups(tmp_path): + """Create an HDF5 file with nested groups.""" + filename = tmp_path / "nested.h5" + n_rows = 3 + + with h5py.File(filename, "w") as f: + f.create_dataset("root_data", data=np.arange(n_rows, dtype=np.int32)) + grp = f.create_group("group1") + grp.create_dataset("group_data", data=np.arange(n_rows, dtype=np.float32)) + subgrp = grp.create_group("subgroup") + subgrp.create_dataset("sub_data", data=np.arange(n_rows, dtype=np.int64)) + + return str(filename) + + +@pytest.fixture +def hdf5_file_with_arrays(tmp_path): + """Create an HDF5 file with multi-dimensional arrays.""" + filename = tmp_path / "arrays.h5" + n_rows = 4 + + with h5py.File(filename, "w") as f: + # 2D array (should become Array2D) + f.create_dataset("matrix_2d", data=np.random.randn(n_rows, 3, 4).astype(np.float32)) + # 3D array (should become Array3D) + f.create_dataset("tensor_3d", data=np.random.randn(n_rows, 2, 3, 4).astype(np.float64)) + # 4D array (should become Array4D) + f.create_dataset("tensor_4d", data=np.random.randn(n_rows, 2, 3, 4, 5).astype(np.float32)) + # 5D array 
(should become Array5D) + f.create_dataset("tensor_5d", data=np.random.randn(n_rows, 2, 3, 4, 5, 6).astype(np.float64)) + # 1D array (should become Value) + f.create_dataset("vector_1d", data=np.random.randn(n_rows, 10).astype(np.float32)) + + return str(filename) + + +@pytest.fixture +def hdf5_file_with_different_dtypes(tmp_path): + """Create an HDF5 file with various numeric dtypes.""" + filename = tmp_path / "dtypes.h5" + n_rows = 3 + + with h5py.File(filename, "w") as f: + f.create_dataset("int8", data=np.arange(n_rows, dtype=np.int8)) + f.create_dataset("int16", data=np.arange(n_rows, dtype=np.int16)) + f.create_dataset("int64", data=np.arange(n_rows, dtype=np.int64)) + f.create_dataset("uint8", data=np.arange(n_rows, dtype=np.uint8)) + f.create_dataset("uint16", data=np.arange(n_rows, dtype=np.uint16)) + f.create_dataset("uint32", data=np.arange(n_rows, dtype=np.uint32)) + f.create_dataset("uint64", data=np.arange(n_rows, dtype=np.uint64)) + f.create_dataset("float16", data=np.arange(n_rows, dtype=np.float16) / 10.0) + f.create_dataset("float64", data=np.arange(n_rows, dtype=np.float64) / 10.0) + f.create_dataset("bytes", data=np.array([b"row_%d" % i for i in range(n_rows)], dtype="S10")) + + return str(filename) + + +@pytest.fixture +def hdf5_file_with_mismatched_lengths(tmp_path): + """Create an HDF5 file with datasets of different lengths (should raise error).""" + filename = tmp_path / "mismatched.h5" + + with h5py.File(filename, "w") as f: + f.create_dataset("data1", data=np.arange(5, dtype=np.int32)) + f.create_dataset("data2", data=np.arange(3, dtype=np.int32)) # Different length + + return str(filename) + + +@pytest.fixture +def hdf5_file_with_zero_dimensions(tmp_path): + """Create an HDF5 file with zero dimensions (should be handled gracefully).""" + filename = tmp_path / "zero_dims.h5" + + with h5py.File(filename, "w") as f: + # Create a dataset with a zero dimension + f.create_dataset("zero_dim", data=np.zeros((3, 0, 2), dtype=np.float32)) + # Create a dataset with zero in the middle dimension + f.create_dataset("zero_middle", data=np.zeros((3, 0), dtype=np.int32)) + # Create a dataset with zero in the last dimension + f.create_dataset("zero_last", data=np.zeros((3, 2, 0), dtype=np.float64)) + + return str(filename) + + +@pytest.fixture +def hdf5_file_with_unsupported_dtypes(tmp_path): + """Create an HDF5 file with unsupported dtypes (complex).""" + filename = tmp_path / "unsupported.h5" + + with h5py.File(filename, "w") as f: + # Complex dtype (should be rejected) + complex_data = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64) + f.create_dataset("complex_data", data=complex_data) + + return str(filename) + + +@pytest.fixture +def empty_hdf5_file(tmp_path): + """Create an HDF5 file with no datasets (should warn and skip).""" + filename = tmp_path / "empty.h5" + + with h5py.File(filename, "w") as f: + # Create only groups, no datasets + f.create_group("empty_group") + grp = f.create_group("another_group") + grp.create_group("subgroup") + + return str(filename) + + +def test_config_raises_when_invalid_name(): + """Test that invalid config names raise an error.""" + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = HDF5Config(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) +def test_config_raises_when_invalid_data_files(data_files): + """Test that invalid data_files parameter raises an error.""" + with pytest.raises(ValueError, match="Expected a 
DataFilesDict"): + _ = HDF5Config(name="name", data_files=data_files) + + +def test_hdf5_basic_functionality(hdf5_file): + """Test basic HDF5 loading with simple numeric datasets.""" + hdf5 = HDF5() + generator = hdf5._generate_tables([[hdf5_file]]) + + tables = list(generator) + assert len(tables) == 1 + + _, table = tables[0] + assert "int32" in table.column_names + assert "float32" in table.column_names + assert "bool" in table.column_names + + # Check data + int32_data = table["int32"].to_pylist() + assert int32_data == [0, 1, 2, 3, 4] + + float32_data = table["float32"].to_pylist() + expected_float32 = [0.0, 0.1, 0.2, 0.3, 0.4] + np.testing.assert_allclose(float32_data, expected_float32, rtol=1e-6) + + +def test_hdf5_nested_groups(hdf5_file_with_groups): + """Test HDF5 loading with nested groups.""" + hdf5 = HDF5() + generator = hdf5._generate_tables([[hdf5_file_with_groups]]) + + tables = list(generator) + assert len(tables) == 1 + + _, table = tables[0] + expected_columns = {"root_data", "group1/group_data", "group1/subgroup/sub_data"} + assert set(table.column_names) == expected_columns + + # Check data + root_data = table["root_data"].to_pylist() + assert root_data == [0, 1, 2] + + group_data = table["group1/group_data"].to_pylist() + expected_group_data = [0.0, 1.0, 2.0] + np.testing.assert_allclose(group_data, expected_group_data, rtol=1e-6) + + +def test_hdf5_multi_dimensional_arrays(hdf5_file_with_arrays): + """Test HDF5 loading with multi-dimensional arrays.""" + hdf5 = HDF5() + generator = hdf5._generate_tables([[hdf5_file_with_arrays]]) + + tables = list(generator) + assert len(tables) == 1 + + _, table = tables[0] + expected_columns = {"matrix_2d", "tensor_3d", "tensor_4d", "tensor_5d", "vector_1d"} + assert set(table.column_names) == expected_columns + + # Check shapes + matrix_2d = table["matrix_2d"].to_pylist() + assert len(matrix_2d) == 4 # 4 rows + assert len(matrix_2d[0]) == 3 # 3 rows in each matrix + assert len(matrix_2d[0][0]) == 4 # 4 columns in each matrix + + +def test_hdf5_different_dtypes(hdf5_file_with_different_dtypes): + """Test HDF5 loading with various numeric dtypes.""" + hdf5 = HDF5() + generator = hdf5._generate_tables([[hdf5_file_with_different_dtypes]]) + + tables = list(generator) + assert len(tables) == 1 + + _, table = tables[0] + expected_columns = {"int8", "int16", "int64", "uint8", "uint16", "uint32", "uint64", "float16", "float64", "bytes"} + assert set(table.column_names) == expected_columns + + # Check specific dtypes + int8_data = table["int8"].to_pylist() + assert int8_data == [0, 1, 2] + + bytes_data = table["bytes"].to_pylist() + assert bytes_data == [b"row_0", b"row_1", b"row_2"] + + +def test_hdf5_batch_processing(hdf5_file): + """Test HDF5 loading with custom batch size.""" + config = HDF5Config(batch_size=2) + hdf5 = HDF5() + hdf5.config = config + generator = hdf5._generate_tables([[hdf5_file]]) + + tables = list(generator) + # Should have 3 batches: [0,1], [2,3], [4] + assert len(tables) == 3 + + # Check first batch + _, first_batch = tables[0] + assert len(first_batch) == 2 + + # Check last batch + _, last_batch = tables[2] + assert len(last_batch) == 1 + + +def test_hdf5_column_filtering(hdf5_file_with_groups): + """Test HDF5 loading with column filtering.""" + config = HDF5Config(columns=["root_data", "group1/group_data"]) + hdf5 = HDF5() + hdf5.config = config + generator = hdf5._generate_tables([[hdf5_file_with_groups]]) + + tables = list(generator) + assert len(tables) == 1 + + _, table = tables[0] + expected_columns = 
{"root_data", "group1/group_data"} + assert set(table.column_names) == expected_columns + assert "group1/subgroup/sub_data" not in table.column_names + + +def test_hdf5_feature_specification(hdf5_file): + """Test HDF5 loading with explicit feature specification.""" + features = Features({"int32": Value("int32"), "float32": Value("float32"), "bool": Value("bool")}) + + config = HDF5Config(features=features) + hdf5 = HDF5() + hdf5.config = config + generator = hdf5._generate_tables([[hdf5_file]]) + + tables = list(generator) + assert len(tables) == 1 + + _, table = tables[0] + # Check that features are properly cast + assert table.schema.field("int32").type == features["int32"].pa_type + assert table.schema.field("float32").type == features["float32"].pa_type + assert table.schema.field("bool").type == features["bool"].pa_type + + +def test_hdf5_mismatched_lengths_error(hdf5_file_with_mismatched_lengths): + """Test that mismatched dataset lengths raise an error.""" + hdf5 = HDF5() + generator = hdf5._generate_tables([[hdf5_file_with_mismatched_lengths]]) + + with pytest.raises(ValueError, match="length.*differs from"): + for _ in generator: + pass + + +def test_hdf5_zero_dimensions_handling(hdf5_file_with_zero_dimensions, caplog): + """Test that zero dimensions are handled gracefully.""" + # Trigger feature inference + data_files = DataFilesDict({"train": [hdf5_file_with_zero_dimensions]}) + config = HDF5Config(data_files=data_files) + hdf5 = HDF5() + hdf5.config = config + + # Trigger feature inference + dl_manager = StreamingDownloadManager() + hdf5._split_generators(dl_manager) + + # Check that features were inferred + assert hdf5.info.features is not None + + # Test that the data can be loaded + generator = hdf5._generate_tables([[hdf5_file_with_zero_dimensions]]) + tables = list(generator) + assert len(tables) == 1 + + _, table = tables[0] + expected_columns = {"zero_dim", "zero_middle", "zero_last"} + assert set(table.column_names) == expected_columns + + # Check that the data is loaded (should be empty arrays) + zero_dim_data = table["zero_dim"].to_pylist() + assert len(zero_dim_data) == 3 # 3 rows + assert all(len(row) == 0 for row in zero_dim_data) # Each row is empty + + +def test_hdf5_unsupported_dtypes_error(hdf5_file_with_unsupported_dtypes): + """Test that unsupported dtypes raise an error.""" + hdf5 = HDF5() + generator = hdf5._generate_tables([[hdf5_file_with_unsupported_dtypes]]) + + # Complex dtypes cause ArrowNotImplementedError during conversion + with pytest.raises(Exception): # Either ValueError or ArrowNotImplementedError + for _ in generator: + pass + + +def test_hdf5_empty_file_warning(empty_hdf5_file, caplog): + """Test that empty files (no datasets) are skipped with a warning.""" + hdf5 = HDF5() + generator = hdf5._generate_tables([[empty_hdf5_file]]) + + tables = list(generator) + assert len(tables) == 0 # No tables should be generated + + # Check that warning was logged + assert any( + record.levelname == "WARNING" and "contains no datasets, skipping" in record.message + for record in caplog.records + ) + + +def test_hdf5_feature_inference(hdf5_file_with_arrays): + """Test automatic feature inference from HDF5 datasets.""" + data_files = DataFilesDict({"train": [hdf5_file_with_arrays]}) + config = HDF5Config(data_files=data_files) + hdf5 = HDF5() + hdf5.config = config + + # Trigger feature inference + dl_manager = StreamingDownloadManager() + hdf5._split_generators(dl_manager) + + # Check that features were inferred + assert hdf5.info.features is not None + + # 
Check specific feature types + features = hdf5.info.features + # (n_rows, 3, 4) -> Array2D with shape (3, 4) + assert isinstance(features["matrix_2d"], Array2D) + assert features["matrix_2d"].shape == (3, 4) + # (n_rows, 2, 3, 4) -> Array3D with shape (2, 3, 4) + assert isinstance(features["tensor_3d"], Array3D) + assert features["tensor_3d"].shape == (2, 3, 4) + # (n_rows, 2, 3, 4, 5) -> Array4D with shape (2, 3, 4, 5) + assert isinstance(features["tensor_4d"], Array4D) + assert features["tensor_4d"].shape == (2, 3, 4, 5) + # (n_rows, 10) -> Sequence of length 10 + assert isinstance(features["vector_1d"], Sequence) + assert features["vector_1d"].length == 10 + + +def test_hdf5_columns_features_mismatch(): + """Test that mismatched columns and features raise an error.""" + features = Features({"col1": Value("int32"), "col2": Value("float32")}) + + config = HDF5Config( + name="test", + columns=["col1", "col3"], # col3 not in features + features=features, + ) + + hdf5 = HDF5() + hdf5.config = config + + with pytest.raises(ValueError, match="must contain the same columns"): + hdf5._info() + + +def test_hdf5_no_data_files_error(): + """Test that missing data_files raises an error.""" + config = HDF5Config(name="test", data_files=None) + hdf5 = HDF5() + hdf5.config = config + + with pytest.raises(ValueError, match="At least one data file must be specified"): + hdf5._split_generators(None) From 24082425ff70178749091a2cf693b4dde8c51589 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Fri, 18 Jul 2025 20:40:22 -0400 Subject: [PATCH 04/16] refactor type inference --- src/datasets/packaged_modules/hdf5/hdf5.py | 199 ++++++++++++++------- tests/packaged_modules/test_hdf5.py | 167 ++++++++++++++++- 2 files changed, 301 insertions(+), 65 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 7f6e53f28b3..bea1ce52b62 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -2,12 +2,22 @@ from dataclasses import dataclass from typing import Dict, List, Optional -import h5py import numpy as np import pyarrow as pa import datasets -from datasets.features.features import LargeList, Sequence, _ArrayXD +import h5py +from datasets.features.features import ( + Array2D, + Array3D, + Array4D, + Array5D, + LargeList, + Sequence, + Value, + _ArrayXD, + _arrow_to_datasets_dtype, +) from datasets.table import table_cast @@ -76,7 +86,7 @@ def _split_generators(self, dl_manager): def _cast_table(self, pa_table: pa.Table) -> pa.Table: if self.info.features is not None: - has_zero_dims = any(has_zero_dimensions(feature) for feature in self.info.features.values()) + has_zero_dims = any(_has_zero_dimensions(feature) for feature in self.info.features.values()) if not has_zero_dims: pa_table = table_cast(pa_table, self.info.features.arrow_schema) return pa_table @@ -105,7 +115,13 @@ def _generate_tables(self, files): if self.config.columns is not None and path not in self.config.columns: continue arr = dset[start:end] - pa_arr = datasets.features.features.numpy_to_pyarrow_listarray(arr) + if _is_ragged_dataset(dset): + if _is_variable_length_string(dset): + pa_arr = _variable_length_string_to_pyarrow(arr, dset) + else: + pa_arr = _ragged_array_to_pyarrow_largelist(arr, dset) + else: + pa_arr = datasets.features.features.numpy_to_pyarrow_listarray(arr) # NOTE: type=None batch_dict[path] = pa_arr pa_table = pa.Table.from_pydict(batch_dict) yield f"{file_idx}_{start}", self._cast_table(pa_table) @@ -123,82 
+139,137 @@ def _traverse_datasets(h5_obj, prefix: str = "") -> Dict[str, h5py.Dataset]: return mapping -_DTYPE_TO_DATASETS: Dict[np.dtype, str] = { # FIXME: necessary/check if util exists? - np.dtype("bool").newbyteorder("="): "bool", - np.dtype("int8").newbyteorder("="): "int8", - np.dtype("int16").newbyteorder("="): "int16", - np.dtype("int32").newbyteorder("="): "int32", - np.dtype("int64").newbyteorder("="): "int64", - np.dtype("uint8").newbyteorder("="): "uint8", - np.dtype("uint16").newbyteorder("="): "uint16", - np.dtype("uint32").newbyteorder("="): "uint32", - np.dtype("uint64").newbyteorder("="): "uint64", - np.dtype("float16").newbyteorder("="): "float16", - np.dtype("float32").newbyteorder("="): "float32", - np.dtype("float64").newbyteorder("="): "float64", - # np.dtype("complex64").newbyteorder("="): "complex64", - # np.dtype("complex128").newbyteorder("="): "complex128", -} - - -def _dtype_to_dataset_dtype(dtype: np.dtype) -> str: - """Map NumPy dtype to datasets.Value dtype string, falls back to "binary" for unknown or unsupported dtypes.""" - - # FIXME: endian fix necessary/correct? - base_dtype = dtype.newbyteorder("=") - if base_dtype in _DTYPE_TO_DATASETS: - return _DTYPE_TO_DATASETS[base_dtype] - - if base_dtype.kind in {"S", "a"}: - return "binary" - - # FIXME: seems h5 converts unicode back to bytes? - if base_dtype.kind == "U": - return "binary" - - if base_dtype.kind == "O": - return "binary" - - # FIXME: support varlen? - - return "binary" +def _base_dtype(dtype): + if hasattr(dtype, "metadata") and dtype.metadata and "vlen" in dtype.metadata: + return dtype.metadata["vlen"] + if hasattr(dtype, "subdtype") and dtype.subdtype is not None: + return _base_dtype(dtype.subdtype[0]) + return dtype + + +def _ragged_array_to_pyarrow_largelist(arr: np.ndarray, dset: h5py.Dataset) -> pa.Array: + if _is_variable_length_string(dset): + list_of_strings = [] + for item in arr: + if item is None: + list_of_strings.append(None) + else: + if isinstance(item, bytes): + item = item.decode("utf-8") + list_of_strings.append(item) + return datasets.features.features.list_of_pa_arrays_to_pyarrow_listarray( + [pa.array([item]) if item is not None else None for item in list_of_strings] + ) + else: + return _convert_nested_ragged_array_recursive(arr, dset.dtype) + + +def _convert_nested_ragged_array_recursive(arr: np.ndarray, dtype): + if hasattr(dtype, "subdtype") and dtype.subdtype is not None: + inner_dtype = dtype.subdtype[0] + list_of_arrays = [] + for item in arr: + if item is None: + list_of_arrays.append(None) + else: + inner_array = _convert_nested_ragged_array_recursive(item, inner_dtype) + list_of_arrays.append(inner_array) + return datasets.features.features.list_of_pa_arrays_to_pyarrow_listarray( + [pa.array(item) if item is not None else None for item in list_of_arrays] + ) + else: + list_of_arrays = [] + for item in arr: + if item is None: + list_of_arrays.append(None) + else: + if not isinstance(item, np.ndarray): + item = np.array(item, dtype=dtype) + list_of_arrays.append(item) + return datasets.features.features.list_of_pa_arrays_to_pyarrow_listarray( + [pa.array(item) if item is not None else None for item in list_of_arrays] + ) def _infer_feature_from_dataset(dset: h5py.Dataset): - """Infer a ``datasets.Features`` entry for one HDF5 dataset.""" + if _is_variable_length_string(dset): + return Value("string") # FIXME: large_string? 
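[Editor's note] This hunk drops the hand-written _DTYPE_TO_DATASETS lookup table in favour of pa.from_numpy_dtype plus the existing _arrow_to_datasets_dtype helper (see _np_to_pa_to_hf_value below). A quick illustration of that round trip; _arrow_to_datasets_dtype is a private helper of datasets.features.features and is used here only for demonstration:

import numpy as np
import pyarrow as pa
from datasets.features.features import _arrow_to_datasets_dtype

for np_dtype in (np.dtype("bool"), np.dtype("uint16"), np.dtype("float32")):
    hf_dtype = _arrow_to_datasets_dtype(pa.from_numpy_dtype(np_dtype))
    print(np_dtype, "->", hf_dtype)  # expected: bool -> bool, uint16 -> uint16, float32 -> float32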
- import datasets as hfd + if _is_ragged_dataset(dset): + return _infer_nested_feature_recursive(dset.dtype, dset) - dtype_str = _dtype_to_dataset_dtype(dset.dtype) + value_feature = _np_to_pa_to_hf_value(dset.dtype) + dtype_str = value_feature.dtype value_shape = dset.shape[1:] - # Reject ragged datasets (variable-length or None dims) - if dset.dtype.kind == "O" or any(s is None for s in value_shape): - raise ValueError(f"Ragged dataset {dset.name} with shape {value_shape} and dtype {dset.dtype} not supported") - if dset.dtype.kind not in {"b", "i", "u", "f", "S", "a"}: - raise ValueError(f"Unsupported dtype {dset.dtype} for dataset {dset.name}") + raise TypeError(f"Unsupported dtype {dset.dtype} for dataset {dset.name}") rank = len(value_shape) - if 2 <= rank <= 5: - from datasets.features import Array2D, Array3D, Array4D, Array5D - - array_cls = [None, None, Array2D, Array3D, Array4D, Array5D][rank] - return array_cls(shape=value_shape, dtype=dtype_str) + if rank == 0: + return value_feature + elif rank == 1: + return Sequence(value_feature, length=value_shape[0]) + elif 2 <= rank <= 5: + return _sized_arrayxd(rank)(shape=value_shape, dtype=dtype_str) + else: + raise TypeError(f"Array{rank}D not supported. Only up to 5D arrays are supported.") - # Fallback to nested Sequence - def _build_feature(shape: tuple[int, ...]): - if len(shape) == 0: - return hfd.Value(dtype_str) - return hfd.Sequence(length=shape[0], feature=_build_feature(shape[1:])) - return _build_feature(value_shape) +def _infer_nested_feature_recursive(dtype, dset: h5py.Dataset): + if hasattr(dtype, "subdtype") and dtype.subdtype is not None: + inner_dtype = dtype.subdtype[0] + inner_feature = _infer_nested_feature_recursive(inner_dtype, dset) + return Sequence(inner_feature) + else: + if hasattr(dtype, "kind") and dtype.kind == "O": + if _is_variable_length_string(dset): + base_dtype = np.dtype("S1") + else: + base_dtype = _base_dtype(dset.dtype) + return Sequence(_np_to_pa_to_hf_value(base_dtype)) + else: + return _np_to_pa_to_hf_value(dtype) -def has_zero_dimensions(feature: _ArrayXD | Sequence | LargeList): +def _has_zero_dimensions(feature): if isinstance(feature, _ArrayXD): return any(dim == 0 for dim in feature.shape) elif isinstance(feature, (Sequence, LargeList)): - return feature.length == 0 or has_zero_dimensions(feature.feature) + return feature.length == 0 or _has_zero_dimensions(feature.feature) else: return False + + +def _sized_arrayxd(rank: int): + return {2: Array2D, 3: Array3D, 4: Array4D, 5: Array5D}[rank] + + +def _np_to_pa_to_hf_value(numpy_dtype: np.dtype) -> Value: + return Value(dtype=_arrow_to_datasets_dtype(pa.from_numpy_dtype(numpy_dtype))) + + +def _is_ragged_dataset(dset: h5py.Dataset) -> bool: + return dset.dtype.kind == "O" and hasattr(dset.dtype, "subdtype") + + +def _is_variable_length_string(dset: h5py.Dataset) -> bool: + if not _is_ragged_dataset(dset) or dset.shape[0] == 0: + return False + num_samples = min(3, dset.shape[0]) + for i in range(num_samples): + try: + if isinstance(dset[i], (str, bytes)): + return True + except (IndexError, TypeError): + continue + return False + + +def _variable_length_string_to_pyarrow(arr: np.ndarray, dset: h5py.Dataset) -> pa.Array: + list_of_strings = [] + for item in arr: + if isinstance(item, bytes): + item = item.decode("utf-8") + list_of_strings.append(item) + return pa.array(list_of_strings) diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index c98d38a3130..a5600c83af0 100644 --- 
a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -1,7 +1,7 @@ -import h5py import numpy as np import pytest +import h5py from datasets import Array2D, Array3D, Array4D, Features, Sequence, Value from datasets.builder import InvalidConfigName from datasets.data_files import DataFilesDict, DataFilesList @@ -81,6 +81,61 @@ def hdf5_file_with_different_dtypes(tmp_path): return str(filename) +@pytest.fixture +def hdf5_file_with_ragged_arrays(tmp_path): + """Create an HDF5 file with ragged arrays using HDF5's vlen_dtype.""" + filename = tmp_path / "ragged.h5" + n_rows = 4 + + with h5py.File(filename, "w") as f: + # Variable-length arrays of different sizes using vlen_dtype + ragged_arrays = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]] + # Create variable-length int dataset using vlen_dtype + dt = h5py.vlen_dtype(np.dtype("int32")) + dset = f.create_dataset("ragged_ints", (n_rows,), dtype=dt) + for i, arr in enumerate(ragged_arrays): + dset[i] = arr + + # Mixed types (some empty arrays) - use variable-length with empty arrays + mixed_data = [ + [1, 2, 3], + [], # Empty array + [4, 5], + [6], + ] + dt_mixed = h5py.vlen_dtype(np.dtype("int32")) + dset_mixed = f.create_dataset("mixed_data", (n_rows,), dtype=dt_mixed) + for i, arr in enumerate(mixed_data): + dset_mixed[i] = arr + + return str(filename) + + +@pytest.fixture +def hdf5_file_with_variable_length_strings(tmp_path): + """Create an HDF5 file with variable-length string datasets.""" + filename = tmp_path / "var_strings.h5" + n_rows = 4 + + with h5py.File(filename, "w") as f: + # Variable-length string dataset + var_strings = ["short", "medium length string", "very long string with many characters", "tiny"] + # Create variable-length string dataset using vlen_dtype + dt = h5py.vlen_dtype(str) + dset = f.create_dataset("var_strings", (n_rows,), dtype=dt) + for i, s in enumerate(var_strings): + dset[i] = s + + # Variable-length bytes dataset + var_bytes = [b"short", b"medium length bytes", b"very long bytes with many characters", b"tiny"] + dt_bytes = h5py.vlen_dtype(bytes) + dset_bytes = f.create_dataset("var_bytes", (n_rows,), dtype=dt_bytes) + for i, b in enumerate(var_bytes): + dset_bytes[i] = b + + return str(filename) + + @pytest.fixture def hdf5_file_with_mismatched_lengths(tmp_path): """Create an HDF5 file with datasets of different lengths (should raise error).""" @@ -211,6 +266,64 @@ def test_hdf5_multi_dimensional_arrays(hdf5_file_with_arrays): assert len(matrix_2d[0][0]) == 4 # 4 columns in each matrix +def test_hdf5_ragged_arrays(hdf5_file_with_ragged_arrays): + """Test HDF5 loading with ragged arrays (object dtype).""" + hdf5 = HDF5() + generator = hdf5._generate_tables([[hdf5_file_with_ragged_arrays]]) + + tables = list(generator) + assert len(tables) == 1 + + _, table = tables[0] + expected_columns = {"ragged_ints", "mixed_data"} + assert set(table.column_names) == expected_columns + + # Check ragged_ints data + ragged_ints = table["ragged_ints"].to_pylist() + assert len(ragged_ints) == 4 + assert ragged_ints[0] == [1, 2, 3] + assert ragged_ints[1] == [4, 5] + assert ragged_ints[2] == [6, 7, 8, 9] + assert ragged_ints[3] == [10] + + # Check mixed_data (with None values) + mixed_data = table["mixed_data"].to_pylist() + assert len(mixed_data) == 4 + assert mixed_data[0] == [1, 2, 3] + assert mixed_data[1] == [] # Empty array instead of None + assert mixed_data[2] == [4, 5] + assert mixed_data[3] == [6] + + +def test_hdf5_variable_length_strings(hdf5_file_with_variable_length_strings): + """Test HDF5 
loading with variable-length string datasets.""" + hdf5 = HDF5() + generator = hdf5._generate_tables([[hdf5_file_with_variable_length_strings]]) + + tables = list(generator) + assert len(tables) == 1 + + _, table = tables[0] + expected_columns = {"var_strings", "var_bytes"} + assert set(table.column_names) == expected_columns + + # Check variable-length strings (converted to strings for usability) + var_strings = table["var_strings"].to_pylist() + assert len(var_strings) == 4 + assert var_strings[0] == "short" + assert var_strings[1] == "medium length string" + assert var_strings[2] == "very long string with many characters" + assert var_strings[3] == "tiny" + + # Check variable-length bytes (converted to strings for usability) + var_bytes = table["var_bytes"].to_pylist() + assert len(var_bytes) == 4 + assert var_bytes[0] == "short" + assert var_bytes[1] == "medium length bytes" + assert var_bytes[2] == "very long bytes with many characters" + assert var_bytes[3] == "tiny" + + def test_hdf5_different_dtypes(hdf5_file_with_different_dtypes): """Test HDF5 loading with various numeric dtypes.""" hdf5 = HDF5() @@ -382,6 +495,58 @@ def test_hdf5_feature_inference(hdf5_file_with_arrays): assert features["vector_1d"].length == 10 +def test_hdf5_ragged_feature_inference(hdf5_file_with_ragged_arrays): + """Test automatic feature inference from ragged HDF5 datasets.""" + data_files = DataFilesDict({"train": [hdf5_file_with_ragged_arrays]}) + config = HDF5Config(data_files=data_files) + hdf5 = HDF5() + hdf5.config = config + + # Trigger feature inference + dl_manager = StreamingDownloadManager() + hdf5._split_generators(dl_manager) + + # Check that features were inferred + assert hdf5.info.features is not None + + # Check specific feature types for ragged arrays + features = hdf5.info.features + # Ragged arrays should become Sequence features by default (for small datasets) + assert isinstance(features["ragged_ints"], Sequence) + assert isinstance(features["mixed_data"], Sequence) + + # Check that the inner feature types are correct + assert isinstance(features["ragged_ints"].feature, Value) + assert features["ragged_ints"].feature.dtype == "int32" + assert isinstance(features["mixed_data"].feature, Value) + assert features["mixed_data"].feature.dtype == "int32" + + +def test_hdf5_variable_string_feature_inference(hdf5_file_with_variable_length_strings): + """Test automatic feature inference from variable-length string datasets.""" + data_files = DataFilesDict({"train": [hdf5_file_with_variable_length_strings]}) + config = HDF5Config(data_files=data_files) + hdf5 = HDF5() + hdf5.config = config + + # Trigger feature inference + dl_manager = StreamingDownloadManager() + hdf5._split_generators(dl_manager) + + # Check that features were inferred + assert hdf5.info.features is not None + + # Check specific feature types for variable-length strings + features = hdf5.info.features + # Variable-length strings should become Value("string") features + assert isinstance(features["var_strings"], Value) + assert isinstance(features["var_bytes"], Value) + + # Check that the feature types are correct + assert features["var_strings"].dtype == "string" + assert features["var_bytes"].dtype == "string" + + def test_hdf5_columns_features_mismatch(): """Test that mismatched columns and features raise an error.""" features = Features({"col1": Value("int32"), "col2": Value("float32")}) From 1c6b1f955e6cbaf8acc81b69f4023b80e112bfa9 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Fri, 18 Jul 2025 23:49:49 -0400 
Subject: [PATCH 05/16] refactor vlen, drop ragged, add complex/compound --- src/datasets/packaged_modules/hdf5/hdf5.py | 364 +++++++++++++-------- tests/packaged_modules/test_hdf5.py | 362 +++++++++++++++++--- 2 files changed, 533 insertions(+), 193 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index bea1ce52b62..3cf5bc1b7c5 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -1,6 +1,6 @@ import itertools from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import numpy as np import pyarrow as pa @@ -72,9 +72,32 @@ def _split_generators(self, dl_manager): with h5py.File(first_file, "r") as h5: dataset_map = _traverse_datasets(h5) features_dict = {} + + def _check_column_collisions(new_columns, source_dataset_path): + """Check for column name collisions and raise informative errors.""" + for new_col in new_columns: + if new_col in features_dict: + raise ValueError( + f"Column name collision detected: '{new_col}' from dataset '{source_dataset_path}' " + f"conflicts with existing column. Consider renaming datasets in the HDF5 file." + ) + for path, dset in dataset_map.items(): - feat = _infer_feature_from_dataset(dset) - features_dict[path] = feat + if _is_complex_dtype(dset.dtype): + complex_features = _create_complex_features(path, dset) + _check_column_collisions(complex_features.keys(), path) + features_dict.update(complex_features) + elif _is_compound_dtype(dset.dtype): + compound_features = _create_compound_features(path, dset) + _check_column_collisions(compound_features.keys(), path) + features_dict.update(compound_features) + elif _is_vlen_string_dtype(dset.dtype): + _check_column_collisions([path], path) + features_dict[path] = Value("string") + else: + _check_column_collisions([path], path) + feat = _infer_feature_from_dataset(dset) + features_dict[path] = feat self.info.features = datasets.Features(features_dict) break splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) @@ -86,7 +109,11 @@ def _split_generators(self, dl_manager): def _cast_table(self, pa_table: pa.Table) -> pa.Table: if self.info.features is not None: - has_zero_dims = any(_has_zero_dimensions(feature) for feature in self.info.features.values()) + relevant_features = { + col: self.info.features[col] for col in pa_table.column_names if col in self.info.features + } + has_zero_dims = any(_has_zero_dimensions(feature) for feature in relevant_features.values()) + # FIXME: pyarrow.lib.ArrowInvalid: list_size needs to be a strict positive integer if not has_zero_dims: pa_table = table_cast(pa_table, self.info.features.arrow_schema) return pa_table @@ -94,149 +121,228 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table: def _generate_tables(self, files): batch_size_cfg = self.config.batch_size for file_idx, file in enumerate(itertools.chain.from_iterable(files)): - with h5py.File(file, "r") as h5: - dataset_map = _traverse_datasets(h5) - if not dataset_map: - logger.warning(f"File '{file}' contains no datasets, skipping…") - continue - first_dset = next(iter(dataset_map.values())) - num_rows = first_dset.shape[0] - # Sanity-check lengths - for path, dset in dataset_map.items(): - if dset.shape[0] != num_rows: - raise ValueError( - f"Dataset '{path}' length {dset.shape[0]} differs from {num_rows} in file '{file}'" - ) - effective_batch = batch_size_cfg or self._writer_batch_size or num_rows - for 
start in range(0, num_rows, effective_batch): - end = min(start + effective_batch, num_rows) - batch_dict = {} + try: + with h5py.File(file, "r") as h5: + dataset_map = _traverse_datasets(h5) + if not dataset_map: + logger.warning(f"File '{file}' contains no data, skipping...") + continue + first_dset = next(iter(dataset_map.values())) + num_rows = first_dset.shape[0] + # Sanity-check lengths for path, dset in dataset_map.items(): - if self.config.columns is not None and path not in self.config.columns: - continue - arr = dset[start:end] - if _is_ragged_dataset(dset): - if _is_variable_length_string(dset): - pa_arr = _variable_length_string_to_pyarrow(arr, dset) + if dset.shape[0] != num_rows: + raise ValueError( + f"Dataset '{path}' length {dset.shape[0]} differs from {num_rows} in file '{file}'" + ) + effective_batch = batch_size_cfg or self._writer_batch_size or num_rows + for start in range(0, num_rows, effective_batch): + end = min(start + effective_batch, num_rows) + batch_dict = {} + for path, dset in dataset_map.items(): + if self.config.columns is not None and path not in self.config.columns: + continue + arr = dset[start:end] + + # Handle variable-length arrays + if _is_vlen_string_dtype(dset.dtype): + logger.debug( + f"Converting variable-length string data for '{path}' (shape: {arr.shape})" + ) + batch_dict[path] = _convert_vlen_string_to_array(arr) + elif ( + hasattr(dset.dtype, "metadata") + and dset.dtype.metadata + and "vlen" in dset.dtype.metadata + ): + # Handle other variable-length types (non-strings) + pa_arr = datasets.features.features.numpy_to_pyarrow_listarray(arr) + batch_dict[path] = pa_arr + elif _is_complex_dtype(dset.dtype): + batch_dict.update(_convert_complex_to_separate_columns(path, arr, dset)) + elif _is_compound_dtype(dset.dtype): + batch_dict.update(_convert_compound_to_separate_columns(path, arr, dset)) + elif dset.dtype.kind == "O": + raise ValueError( + f"Object dtype dataset '{path}' is not supported. " + f"For variable-length data, please use h5py.vlen_dtype() " + f"when creating the HDF5 file. 
" + f"See: https://docs.h5py.org/en/stable/special.html#variable-length-strings" + ) else: - pa_arr = _ragged_array_to_pyarrow_largelist(arr, dset) - else: - pa_arr = datasets.features.features.numpy_to_pyarrow_listarray(arr) # NOTE: type=None - batch_dict[path] = pa_arr - pa_table = pa.Table.from_pydict(batch_dict) - yield f"{file_idx}_{start}", self._cast_table(pa_table) + pa_arr = datasets.features.features.numpy_to_pyarrow_listarray(arr) + batch_dict[path] = pa_arr + pa_table = pa.Table.from_pydict(batch_dict) + yield f"{file_idx}_{start}", self._cast_table(pa_table) + except ValueError as e: + logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") + raise def _traverse_datasets(h5_obj, prefix: str = "") -> Dict[str, h5py.Dataset]: mapping: Dict[str, h5py.Dataset] = {} - for key in h5_obj: - item = h5_obj[key] - sub_path = f"{prefix}{key}" - if isinstance(item, h5py.Dataset): - mapping[sub_path] = item - elif isinstance(item, h5py.Group): - mapping.update(_traverse_datasets(item, prefix=f"{sub_path}/")) + + def collect_datasets(name, obj): + if isinstance(obj, h5py.Dataset): + full_path = f"{prefix}{name}" if prefix else name + mapping[full_path] = obj + + h5_obj.visititems(collect_datasets) return mapping -def _base_dtype(dtype): +# ┌───────────┐ +# │ Complex │ +# └───────────┘ + + +def _is_complex_dtype(dtype: np.dtype) -> bool: + """Check if dtype is a complex number type.""" + return dtype.kind == "c" + + +def _create_complex_features(base_path: str, dset: h5py.Dataset) -> Dict[str, Value]: + """Create separate features for real and imaginary parts of complex data. + + NOTE: Always uses float64 for the real and imaginary parts. + """ + logger.info( + f"Complex dataset '{base_path}' (dtype: {dset.dtype}) split into '{base_path}_real' and '{base_path}_imag'" + ) + return {f"{base_path}_real": Value("float64"), f"{base_path}_imag": Value("float64")} + + +def _convert_complex_to_separate_columns(base_path: str, arr: np.ndarray, dset: h5py.Dataset) -> Dict[str, pa.Array]: + """Convert complex array to separate real and imaginary columns.""" + result = {} + result[f"{base_path}_real"] = datasets.features.features.numpy_to_pyarrow_listarray(arr.real) + result[f"{base_path}_imag"] = datasets.features.features.numpy_to_pyarrow_listarray(arr.imag) + return result + + +# ┌────────────┐ +# │ Compound │ +# └────────────┘ + + +def _is_compound_dtype(dtype: np.dtype) -> bool: + """Check if dtype is a compound/structured type.""" + return dtype.names is not None + + +class _MockDataset: + def __init__(self, dtype): + self.dtype = dtype + self.names = dtype.names + + +def _create_compound_features(base_path: str, dset: h5py.Dataset) -> Dict[str, Any]: + """Create separate features for each field in compound data.""" + field_names = list(dset.dtype.names) + logger.info( + f"Compound dataset '{base_path}' (dtype: {dset.dtype}) flattened into {len(field_names)} columns: {field_names}" + ) + + features = {} + for field_name in field_names: + field_dtype = dset.dtype[field_name] + field_path = f"{base_path}_{field_name}" + + if _is_complex_dtype(field_dtype): + features[f"{field_path}_real"] = Value("float64") + features[f"{field_path}_imag"] = Value("float64") + elif _is_compound_dtype(field_dtype): + mock_dset = _MockDataset(field_dtype) + nested_features = _create_compound_features(field_path, mock_dset) + features.update(nested_features) + else: + value_feature = _np_to_pa_to_hf_value(field_dtype) + features[field_path] = value_feature + + return features + + +def 
_convert_compound_to_separate_columns(base_path: str, arr: np.ndarray, dset: h5py.Dataset) -> Dict[str, pa.Array]: + """Convert compound array to separate columns for each field.""" + result = {} + for field_name in list(dset.dtype.names): + field_dtype = dset.dtype[field_name] + field_path = f"{base_path}_{field_name}" + field_data = arr[field_name] + + if _is_complex_dtype(field_dtype): + result[f"{field_path}_real"] = datasets.features.features.numpy_to_pyarrow_listarray(field_data.real) + result[f"{field_path}_imag"] = datasets.features.features.numpy_to_pyarrow_listarray(field_data.imag) + elif _is_compound_dtype(field_dtype): + mock_dset = _MockDataset(field_dtype) + nested_result = _convert_compound_to_separate_columns(field_path, field_data, mock_dset) + result.update(nested_result) + else: + result[field_path] = datasets.features.features.numpy_to_pyarrow_listarray(field_data) + + return result + + +# ┌───────────────────────────┐ +# │ Variable-Length Strings │ +# └───────────────────────────┘ + + +def _is_vlen_string_dtype(dtype: np.dtype) -> bool: + """Check if dtype is a variable-length string type.""" if hasattr(dtype, "metadata") and dtype.metadata and "vlen" in dtype.metadata: - return dtype.metadata["vlen"] - if hasattr(dtype, "subdtype") and dtype.subdtype is not None: - return _base_dtype(dtype.subdtype[0]) - return dtype - - -def _ragged_array_to_pyarrow_largelist(arr: np.ndarray, dset: h5py.Dataset) -> pa.Array: - if _is_variable_length_string(dset): - list_of_strings = [] - for item in arr: - if item is None: - list_of_strings.append(None) - else: - if isinstance(item, bytes): - item = item.decode("utf-8") - list_of_strings.append(item) - return datasets.features.features.list_of_pa_arrays_to_pyarrow_listarray( - [pa.array([item]) if item is not None else None for item in list_of_strings] - ) - else: - return _convert_nested_ragged_array_recursive(arr, dset.dtype) - - -def _convert_nested_ragged_array_recursive(arr: np.ndarray, dtype): - if hasattr(dtype, "subdtype") and dtype.subdtype is not None: - inner_dtype = dtype.subdtype[0] - list_of_arrays = [] - for item in arr: - if item is None: - list_of_arrays.append(None) - else: - inner_array = _convert_nested_ragged_array_recursive(item, inner_dtype) - list_of_arrays.append(inner_array) - return datasets.features.features.list_of_pa_arrays_to_pyarrow_listarray( - [pa.array(item) if item is not None else None for item in list_of_arrays] - ) - else: - list_of_arrays = [] - for item in arr: - if item is None: - list_of_arrays.append(None) - else: - if not isinstance(item, np.ndarray): - item = np.array(item, dtype=dtype) - list_of_arrays.append(item) - return datasets.features.features.list_of_pa_arrays_to_pyarrow_listarray( - [pa.array(item) if item is not None else None for item in list_of_arrays] - ) + vlen_dtype = dtype.metadata["vlen"] + return vlen_dtype in (str, bytes) + return False -def _infer_feature_from_dataset(dset: h5py.Dataset): - if _is_variable_length_string(dset): - return Value("string") # FIXME: large_string? 
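[Editor's note] Based on the _create_complex_features / _create_compound_features helpers above, complex and compound datasets are flattened into suffixed scalar columns rather than rejected. A rough sketch of what a user should see; the file name and the printed column list are illustrative expectations:

import h5py
import numpy as np
from datasets import load_dataset

with h5py.File("mixed.h5", "w") as f:
    # complex column "z" -> "z_real" / "z_imag"; compound column "pt" -> "pt_x" / "pt_y"
    f.create_dataset("z", data=np.array([1 + 2j, 3 + 4j], dtype=np.complex64))
    point = np.dtype([("x", "i4"), ("y", "f8")])
    f.create_dataset("pt", data=np.array([(1, 2.5), (3, 4.5)], dtype=point))

ds = load_dataset("hdf5", data_files="mixed.h5", split="train")
print(sorted(ds.column_names))  # expected: ['pt_x', 'pt_y', 'z_imag', 'z_real']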
+def _convert_vlen_string_to_array(arr: np.ndarray) -> pa.Array: + list_of_items = [] + for item in arr: + if isinstance(item, bytes): + logger.info("Assuming variable-length bytes are utf-8 encoded strings") + list_of_items.append(item.decode("utf-8")) + elif isinstance(item, str): + list_of_items.append(item) + else: + raise ValueError(f"Unsupported variable-length string type: {type(item)}") + return pa.array(list_of_items) - if _is_ragged_dataset(dset): - return _infer_nested_feature_recursive(dset.dtype, dset) + +# ┌───────────┐ +# │ Generic │ +# └───────────┘ + + +def _infer_feature_from_dataset(dset: h5py.Dataset): + # non-string varlen + if hasattr(dset.dtype, "metadata") and dset.dtype.metadata and "vlen" in dset.dtype.metadata: + vlen_dtype = dset.dtype.metadata["vlen"] + inner_feature = _np_to_pa_to_hf_value(vlen_dtype) + return Sequence(inner_feature) value_feature = _np_to_pa_to_hf_value(dset.dtype) dtype_str = value_feature.dtype value_shape = dset.shape[1:] - if dset.dtype.kind not in {"b", "i", "u", "f", "S", "a"}: - raise TypeError(f"Unsupported dtype {dset.dtype} for dataset {dset.name}") - rank = len(value_shape) if rank == 0: return value_feature elif rank == 1: return Sequence(value_feature, length=value_shape[0]) - elif 2 <= rank <= 5: + elif rank <= 5: return _sized_arrayxd(rank)(shape=value_shape, dtype=dtype_str) else: - raise TypeError(f"Array{rank}D not supported. Only up to 5D arrays are supported.") - - -def _infer_nested_feature_recursive(dtype, dset: h5py.Dataset): - if hasattr(dtype, "subdtype") and dtype.subdtype is not None: - inner_dtype = dtype.subdtype[0] - inner_feature = _infer_nested_feature_recursive(inner_dtype, dset) - return Sequence(inner_feature) - else: - if hasattr(dtype, "kind") and dtype.kind == "O": - if _is_variable_length_string(dset): - base_dtype = np.dtype("S1") - else: - base_dtype = _base_dtype(dset.dtype) - return Sequence(_np_to_pa_to_hf_value(base_dtype)) - else: - return _np_to_pa_to_hf_value(dtype) + raise TypeError(f"Array{rank}D not supported. 
Maximum 5 dimensions allowed.") def _has_zero_dimensions(feature): if isinstance(feature, _ArrayXD): return any(dim == 0 for dim in feature.shape) - elif isinstance(feature, (Sequence, LargeList)): + elif isinstance(feature, Sequence): # also gets regular List return feature.length == 0 or _has_zero_dimensions(feature.feature) + elif isinstance(feature, LargeList): + return _has_zero_dimensions(feature.feature) else: return False @@ -247,29 +353,3 @@ def _sized_arrayxd(rank: int): def _np_to_pa_to_hf_value(numpy_dtype: np.dtype) -> Value: return Value(dtype=_arrow_to_datasets_dtype(pa.from_numpy_dtype(numpy_dtype))) - - -def _is_ragged_dataset(dset: h5py.Dataset) -> bool: - return dset.dtype.kind == "O" and hasattr(dset.dtype, "subdtype") - - -def _is_variable_length_string(dset: h5py.Dataset) -> bool: - if not _is_ragged_dataset(dset) or dset.shape[0] == 0: - return False - num_samples = min(3, dset.shape[0]) - for i in range(num_samples): - try: - if isinstance(dset[i], (str, bytes)): - return True - except (IndexError, TypeError): - continue - return False - - -def _variable_length_string_to_pyarrow(arr: np.ndarray, dset: h5py.Dataset) -> pa.Array: - list_of_strings = [] - for item in arr: - if isinstance(item, bytes): - item = item.decode("utf-8") - list_of_strings.append(item) - return pa.array(list_of_strings) diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index a5600c83af0..18620722ae7 100644 --- a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -82,18 +82,18 @@ def hdf5_file_with_different_dtypes(tmp_path): @pytest.fixture -def hdf5_file_with_ragged_arrays(tmp_path): - """Create an HDF5 file with ragged arrays using HDF5's vlen_dtype.""" - filename = tmp_path / "ragged.h5" +def hdf5_file_with_vlen_arrays(tmp_path): + """Create an HDF5 file with variable-length arrays using HDF5's vlen_dtype.""" + filename = tmp_path / "vlen.h5" n_rows = 4 with h5py.File(filename, "w") as f: # Variable-length arrays of different sizes using vlen_dtype - ragged_arrays = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]] + vlen_arrays = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]] # Create variable-length int dataset using vlen_dtype dt = h5py.vlen_dtype(np.dtype("int32")) - dset = f.create_dataset("ragged_ints", (n_rows,), dtype=dt) - for i, arr in enumerate(ragged_arrays): + dset = f.create_dataset("vlen_ints", (n_rows,), dtype=dt) + for i, arr in enumerate(vlen_arrays): dset[i] = arr # Mixed types (some empty arrays) - use variable-length with empty arrays @@ -136,6 +136,53 @@ def hdf5_file_with_variable_length_strings(tmp_path): return str(filename) +@pytest.fixture +def hdf5_file_with_complex_data(tmp_path): + """Create an HDF5 file with complex number datasets.""" + filename = tmp_path / "complex.h5" + + with h5py.File(filename, "w") as f: + # Complex numbers + complex_data = np.array([1 + 2j, 3 + 4j, 5 + 6j, 7 + 8j], dtype=np.complex64) + f.create_dataset("complex_64", data=complex_data) + + # Complex double precision + complex_double = np.array([1.5 + 2.5j, 3.5 + 4.5j, 5.5 + 6.5j, 7.5 + 8.5j], dtype=np.complex128) + f.create_dataset("complex_128", data=complex_double) + + # Complex array + complex_array = np.array( + [[1 + 2j, 3 + 4j], [5 + 6j, 7 + 8j], [9 + 10j, 11 + 12j], [13 + 14j, 15 + 16j]], dtype=np.complex64 + ) + f.create_dataset("complex_array", data=complex_array) + + return str(filename) + + +@pytest.fixture +def hdf5_file_with_compound_data(tmp_path): + """Create an HDF5 file with compound/structured datasets.""" + filename = 
tmp_path / "compound.h5" + + with h5py.File(filename, "w") as f: + # Simple compound type + dt_simple = np.dtype([("x", "i4"), ("y", "f8")]) + compound_simple = np.array([(1, 2.5), (3, 4.5), (5, 6.5)], dtype=dt_simple) + f.create_dataset("simple_compound", data=compound_simple) + + # Compound type with complex numbers + dt_complex = np.dtype([("real", "f4"), ("imag", "f4")]) + compound_complex = np.array([(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)], dtype=dt_complex) + f.create_dataset("complex_compound", data=compound_complex) + + # Nested compound type + dt_nested = np.dtype([("position", [("x", "i4"), ("y", "i4")]), ("velocity", [("vx", "f4"), ("vy", "f4")])]) + compound_nested = np.array([((1, 2), (1.5, 2.5)), ((3, 4), (3.5, 4.5)), ((5, 6), (5.5, 6.5))], dtype=dt_nested) + f.create_dataset("nested_compound", data=compound_nested) + + return str(filename) + + @pytest.fixture def hdf5_file_with_mismatched_lengths(tmp_path): """Create an HDF5 file with datasets of different lengths (should raise error).""" @@ -164,19 +211,6 @@ def hdf5_file_with_zero_dimensions(tmp_path): return str(filename) -@pytest.fixture -def hdf5_file_with_unsupported_dtypes(tmp_path): - """Create an HDF5 file with unsupported dtypes (complex).""" - filename = tmp_path / "unsupported.h5" - - with h5py.File(filename, "w") as f: - # Complex dtype (should be rejected) - complex_data = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64) - f.create_dataset("complex_data", data=complex_data) - - return str(filename) - - @pytest.fixture def empty_hdf5_file(tmp_path): """Create an HDF5 file with no datasets (should warn and skip).""" @@ -266,25 +300,25 @@ def test_hdf5_multi_dimensional_arrays(hdf5_file_with_arrays): assert len(matrix_2d[0][0]) == 4 # 4 columns in each matrix -def test_hdf5_ragged_arrays(hdf5_file_with_ragged_arrays): - """Test HDF5 loading with ragged arrays (object dtype).""" +def test_hdf5_vlen_arrays(hdf5_file_with_vlen_arrays): + """Test HDF5 loading with variable-length arrays (int32).""" hdf5 = HDF5() - generator = hdf5._generate_tables([[hdf5_file_with_ragged_arrays]]) + generator = hdf5._generate_tables([[hdf5_file_with_vlen_arrays]]) tables = list(generator) assert len(tables) == 1 _, table = tables[0] - expected_columns = {"ragged_ints", "mixed_data"} + expected_columns = {"vlen_ints", "mixed_data"} assert set(table.column_names) == expected_columns - # Check ragged_ints data - ragged_ints = table["ragged_ints"].to_pylist() - assert len(ragged_ints) == 4 - assert ragged_ints[0] == [1, 2, 3] - assert ragged_ints[1] == [4, 5] - assert ragged_ints[2] == [6, 7, 8, 9] - assert ragged_ints[3] == [10] + # Check vlen_ints data + vlen_ints = table["vlen_ints"].to_pylist() + assert len(vlen_ints) == 4 + assert vlen_ints[0] == [1, 2, 3] + assert vlen_ints[1] == [4, 5] + assert vlen_ints[2] == [6, 7, 8, 9] + assert vlen_ints[3] == [10] # Check mixed_data (with None values) mixed_data = table["mixed_data"].to_pylist() @@ -439,17 +473,6 @@ def test_hdf5_zero_dimensions_handling(hdf5_file_with_zero_dimensions, caplog): assert all(len(row) == 0 for row in zero_dim_data) # Each row is empty -def test_hdf5_unsupported_dtypes_error(hdf5_file_with_unsupported_dtypes): - """Test that unsupported dtypes raise an error.""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[hdf5_file_with_unsupported_dtypes]]) - - # Complex dtypes cause ArrowNotImplementedError during conversion - with pytest.raises(Exception): # Either ValueError or ArrowNotImplementedError - for _ in generator: - pass - - def 
test_hdf5_empty_file_warning(empty_hdf5_file, caplog): """Test that empty files (no datasets) are skipped with a warning.""" hdf5 = HDF5() @@ -460,8 +483,7 @@ def test_hdf5_empty_file_warning(empty_hdf5_file, caplog): # Check that warning was logged assert any( - record.levelname == "WARNING" and "contains no datasets, skipping" in record.message - for record in caplog.records + record.levelname == "WARNING" and "contains no data, skipping" in record.message for record in caplog.records ) @@ -495,9 +517,9 @@ def test_hdf5_feature_inference(hdf5_file_with_arrays): assert features["vector_1d"].length == 10 -def test_hdf5_ragged_feature_inference(hdf5_file_with_ragged_arrays): - """Test automatic feature inference from ragged HDF5 datasets.""" - data_files = DataFilesDict({"train": [hdf5_file_with_ragged_arrays]}) +def test_hdf5_vlen_feature_inference(hdf5_file_with_vlen_arrays): + """Test automatic feature inference from variable-length HDF5 datasets.""" + data_files = DataFilesDict({"train": [hdf5_file_with_vlen_arrays]}) config = HDF5Config(data_files=data_files) hdf5 = HDF5() hdf5.config = config @@ -509,15 +531,15 @@ def test_hdf5_ragged_feature_inference(hdf5_file_with_ragged_arrays): # Check that features were inferred assert hdf5.info.features is not None - # Check specific feature types for ragged arrays + # Check specific feature types for variable-length arrays features = hdf5.info.features - # Ragged arrays should become Sequence features by default (for small datasets) - assert isinstance(features["ragged_ints"], Sequence) + # Variable-length arrays should become Sequence features by default (for small datasets) + assert isinstance(features["vlen_ints"], Sequence) assert isinstance(features["mixed_data"], Sequence) # Check that the inner feature types are correct - assert isinstance(features["ragged_ints"].feature, Value) - assert features["ragged_ints"].feature.dtype == "int32" + assert isinstance(features["vlen_ints"].feature, Value) + assert features["vlen_ints"].feature.dtype == "int32" assert isinstance(features["mixed_data"].feature, Value) assert features["mixed_data"].feature.dtype == "int32" @@ -572,3 +594,241 @@ def test_hdf5_no_data_files_error(): with pytest.raises(ValueError, match="At least one data file must be specified"): hdf5._split_generators(None) + + +def test_hdf5_config_options(): + """Test HDF5Config with different options.""" + # Test default options + config = HDF5Config() + # Complex and compound types are always split now, no config options needed + assert config.batch_size is None + assert config.columns is None + assert config.features is None + + +def test_hdf5_complex_numbers(hdf5_file_with_complex_data): + """Test HDF5 loading with complex number datasets.""" + config = HDF5Config() + hdf5 = HDF5() + hdf5.config = config + + generator = hdf5._generate_tables([[hdf5_file_with_complex_data]]) + tables = list(generator) + + assert len(tables) == 1 + _, table = tables[0] + + # Check that complex numbers are split into real/imaginary parts + expected_columns = { + "complex_64_real", + "complex_64_imag", + "complex_128_real", + "complex_128_imag", + "complex_array_real", + "complex_array_imag", + } + assert set(table.column_names) == expected_columns + + # Check complex_64 data + real_data = table["complex_64_real"].to_pylist() + imag_data = table["complex_64_imag"].to_pylist() + + assert real_data == [1.0, 3.0, 5.0, 7.0] + assert imag_data == [2.0, 4.0, 6.0, 8.0] + + +def test_hdf5_compound_types(hdf5_file_with_compound_data): + """Test HDF5 loading 
with compound/structured datasets.""" + config = HDF5Config() + hdf5 = HDF5() + hdf5.config = config + + generator = hdf5._generate_tables([[hdf5_file_with_compound_data]]) + tables = list(generator) + + assert len(tables) == 1 + _, table = tables[0] + + # Check that compound types are flattened into separate columns + expected_columns = { + "simple_compound_x", + "simple_compound_y", + "complex_compound_real", + "complex_compound_imag", + "nested_compound_position_x", + "nested_compound_position_y", + "nested_compound_velocity_vx", + "nested_compound_velocity_vy", + } + assert set(table.column_names) == expected_columns + + # Check simple compound data + x_data = table["simple_compound_x"].to_pylist() + y_data = table["simple_compound_y"].to_pylist() + + assert x_data == [1, 3, 5] + assert y_data == [2.5, 4.5, 6.5] + + +def test_hdf5_unsupported_dtype_handling(tmp_path): + """Test handling of truly unsupported dtypes.""" + filename = tmp_path / "unsupported.h5" + + with h5py.File(filename, "w") as f: + # Create a dataset with an unsupported dtype (e.g., bitfield) + # This should raise a TypeError during feature inference + bitfield_data = np.array([1, 2, 3], dtype=np.uint8) + # We'll create a dataset that will fail during feature inference + # by using a custom dtype that's not supported + f.create_dataset("bitfield_data", data=bitfield_data) + + config = HDF5Config() + hdf5 = HDF5() + hdf5.config = config + hdf5.config.data_files = DataFilesDict({"train": [str(filename)]}) + + # This should not raise an error since uint8 is supported + # Let's test with a different approach - create a dataset that will fail + # during the actual data loading phase + dl_manager = StreamingDownloadManager() + hdf5._split_generators(dl_manager) + + # The test passes if no error is raised, since uint8 is actually supported + + +def test_hdf5_feature_inference_complex(hdf5_file_with_complex_data): + """Test automatic feature inference for complex datasets.""" + config = HDF5Config() + hdf5 = HDF5() + hdf5.config = config + hdf5.config.data_files = DataFilesDict({"train": [hdf5_file_with_complex_data]}) + + # Trigger feature inference + dl_manager = StreamingDownloadManager() + hdf5._split_generators(dl_manager) + + # Check that features were inferred correctly + assert hdf5.info.features is not None + features = hdf5.info.features + + # Check complex number features + assert "complex_64_real" in features + assert "complex_64_imag" in features + assert features["complex_64_real"] == Value("float64") + assert features["complex_64_imag"] == Value("float64") + + +def test_hdf5_feature_inference_compound(hdf5_file_with_compound_data): + """Test automatic feature inference for compound datasets.""" + config = HDF5Config() + hdf5 = HDF5() + hdf5.config = config + hdf5.config.data_files = DataFilesDict({"train": [hdf5_file_with_compound_data]}) + + # Trigger feature inference + dl_manager = StreamingDownloadManager() + hdf5._split_generators(dl_manager) + + # Check that features were inferred correctly + assert hdf5.info.features is not None + features = hdf5.info.features + + # Check compound type features + assert "simple_compound_x" in features + assert "simple_compound_y" in features + assert features["simple_compound_x"] == Value("int32") + assert features["simple_compound_y"] == Value("float64") + + +def test_hdf5_mixed_data_types(tmp_path): + """Test HDF5 loading with mixed data types in the same file.""" + filename = tmp_path / "mixed.h5" + + with h5py.File(filename, "w") as f: + # Regular numeric data + 
f.create_dataset("regular_int", data=np.arange(3, dtype=np.int32)) + f.create_dataset("regular_float", data=np.arange(3, dtype=np.float32)) + + # Complex data + complex_data = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64) + f.create_dataset("complex_data", data=complex_data) + + # Compound data + dt_compound = np.dtype([("x", "i4"), ("y", "f8")]) + compound_data = np.array([(1, 2.5), (3, 4.5), (5, 6.5)], dtype=dt_compound) + f.create_dataset("compound_data", data=compound_data) + + config = HDF5Config() + hdf5 = HDF5() + hdf5.config = config + + generator = hdf5._generate_tables([[str(filename)]]) + tables = list(generator) + + assert len(tables) == 1 + _, table = tables[0] + + # Check all expected columns are present + expected_columns = { + "regular_int", + "regular_float", + "complex_data_real", + "complex_data_imag", + "compound_data_x", + "compound_data_y", + } + assert set(table.column_names) == expected_columns + + # Check data types + assert table["regular_int"].to_pylist() == [0, 1, 2] + assert len(table["complex_data_real"].to_pylist()) == 3 + assert len(table["compound_data_x"].to_pylist()) == 3 + + +def test_hdf5_column_name_collision_detection(tmp_path): + """Test that column name collision detection works correctly.""" + filename = tmp_path / "collision.h5" + + with h5py.File(filename, "w") as f: + # Create a complex dataset + complex_data = np.array([1 + 2j, 3 + 4j], dtype=np.complex64) + f.create_dataset("data", data=complex_data) + + # Create a regular dataset that would collide with the complex real part + regular_data = np.array([1.0, 2.0], dtype=np.float32) + f.create_dataset("data_real", data=regular_data) # This should cause a collision + + config = HDF5Config() + hdf5 = HDF5() + hdf5.config = config + hdf5.config.data_files = DataFilesDict({"train": [str(filename)]}) + + # This should raise a ValueError due to column name collision + dl_manager = StreamingDownloadManager() + with pytest.raises(ValueError, match="Column name collision detected"): + hdf5._split_generators(dl_manager) + + +def test_hdf5_compound_collision_detection(tmp_path): + """Test collision detection with compound types.""" + filename = tmp_path / "compound_collision.h5" + + with h5py.File(filename, "w") as f: + # Create a compound dataset + dt_compound = np.dtype([("x", "i4"), ("y", "f8")]) + compound_data = np.array([(1, 2.5), (3, 4.5)], dtype=dt_compound) + f.create_dataset("position", data=compound_data) + + # Create a regular dataset that would collide with compound field + regular_data = np.array([10, 20], dtype=np.int32) + f.create_dataset("position_x", data=regular_data) # This should cause a collision + + config = HDF5Config() + hdf5 = HDF5() + hdf5.config = config + hdf5.config.data_files = DataFilesDict({"train": [str(filename)]}) + + # This should raise a ValueError due to column name collision + dl_manager = StreamingDownloadManager() + with pytest.raises(ValueError, match="Column name collision detected"): + hdf5._split_generators(dl_manager) From 1e74de684a5258d9def3da3c36aefe87b8b3bd2a Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Sat, 19 Jul 2025 01:39:25 -0400 Subject: [PATCH 06/16] update tests --- tests/packaged_modules/test_hdf5.py | 145 ++++++++++++---------------- 1 file changed, 64 insertions(+), 81 deletions(-) diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index 18620722ae7..602616a7e4a 100644 --- a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -225,6 +225,64 @@ def 
empty_hdf5_file(tmp_path): return str(filename) +@pytest.fixture +def hdf5_file_with_mixed_data_types(tmp_path): + """Create an HDF5 file with mixed data types in the same file.""" + filename = tmp_path / "mixed.h5" + n_rows = 3 + + with h5py.File(filename, "w") as f: + # Regular numeric data + f.create_dataset("regular_int", data=np.arange(n_rows, dtype=np.int32)) + f.create_dataset("regular_float", data=np.arange(n_rows, dtype=np.float32)) + + # Complex data + complex_data = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64) + f.create_dataset("complex_data", data=complex_data) + + # Compound data + dt_compound = np.dtype([("x", "i4"), ("y", "f8")]) + compound_data = np.array([(1, 2.5), (3, 4.5), (5, 6.5)], dtype=dt_compound) + f.create_dataset("compound_data", data=compound_data) + + return str(filename) + + +@pytest.fixture +def hdf5_file_with_complex_collision(tmp_path): + """Create an HDF5 file where complex dataset would collide with existing dataset name.""" + filename = tmp_path / "collision.h5" + + with h5py.File(filename, "w") as f: + # Create a complex dataset + complex_data = np.array([1 + 2j, 3 + 4j], dtype=np.complex64) + f.create_dataset("data", data=complex_data) + + # Create a regular dataset that would collide with the complex real part + regular_data = np.array([1.0, 2.0], dtype=np.float32) + f.create_dataset("data_real", data=regular_data) # This should cause a collision + + return str(filename) + + +@pytest.fixture +def hdf5_file_with_compound_collision(tmp_path): + """Create an HDF5 file where compound dataset would collide with existing dataset name.""" + filename = tmp_path / "compound_collision.h5" + + with h5py.File(filename, "w") as f: + # Create a compound dataset + dt_compound = np.dtype([("x", "i4"), ("y", "f8")]) + compound_data = np.array([(1, 2.5), (3, 4.5)], dtype=dt_compound) + f.create_dataset("position", data=compound_data) + + # Create a regular dataset that would collide with compound field + regular_data = np.array([10, 20], dtype=np.int32) + f.create_dataset("position_x", data=regular_data) # This should cause a collision + + return str(filename) + + def test_config_raises_when_invalid_name(): """Test that invalid config names raise an error.""" with pytest.raises(InvalidConfigName, match="Bad characters"): @@ -596,16 +654,6 @@ def test_hdf5_no_data_files_error(): hdf5._split_generators(None) -def test_hdf5_config_options(): - """Test HDF5Config with different options.""" - # Test default options - config = HDF5Config() - # Complex and compound types are always split now, no config options needed - assert config.batch_size is None - assert config.columns is None - assert config.features is None - - def test_hdf5_complex_numbers(hdf5_file_with_complex_data): """Test HDF5 loading with complex number datasets.""" config = HDF5Config() @@ -670,32 +718,6 @@ def test_hdf5_compound_types(hdf5_file_with_compound_data): assert y_data == [2.5, 4.5, 6.5] -def test_hdf5_unsupported_dtype_handling(tmp_path): - """Test handling of truly unsupported dtypes.""" - filename = tmp_path / "unsupported.h5" - - with h5py.File(filename, "w") as f: - # Create a dataset with an unsupported dtype (e.g., bitfield) - # This should raise a TypeError during feature inference - bitfield_data = np.array([1, 2, 3], dtype=np.uint8) - # We'll create a dataset that will fail during feature inference - # by using a custom dtype that's not supported - f.create_dataset("bitfield_data", data=bitfield_data) - - config = HDF5Config() - hdf5 = HDF5() - hdf5.config = config - 
hdf5.config.data_files = DataFilesDict({"train": [str(filename)]}) - - # This should not raise an error since uint8 is supported - # Let's test with a different approach - create a dataset that will fail - # during the actual data loading phase - dl_manager = StreamingDownloadManager() - hdf5._split_generators(dl_manager) - - # The test passes if no error is raised, since uint8 is actually supported - - def test_hdf5_feature_inference_complex(hdf5_file_with_complex_data): """Test automatic feature inference for complex datasets.""" config = HDF5Config() @@ -740,29 +762,13 @@ def test_hdf5_feature_inference_compound(hdf5_file_with_compound_data): assert features["simple_compound_y"] == Value("float64") -def test_hdf5_mixed_data_types(tmp_path): +def test_hdf5_mixed_data_types(hdf5_file_with_mixed_data_types): """Test HDF5 loading with mixed data types in the same file.""" - filename = tmp_path / "mixed.h5" - - with h5py.File(filename, "w") as f: - # Regular numeric data - f.create_dataset("regular_int", data=np.arange(3, dtype=np.int32)) - f.create_dataset("regular_float", data=np.arange(3, dtype=np.float32)) - - # Complex data - complex_data = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64) - f.create_dataset("complex_data", data=complex_data) - - # Compound data - dt_compound = np.dtype([("x", "i4"), ("y", "f8")]) - compound_data = np.array([(1, 2.5), (3, 4.5), (5, 6.5)], dtype=dt_compound) - f.create_dataset("compound_data", data=compound_data) - config = HDF5Config() hdf5 = HDF5() hdf5.config = config - generator = hdf5._generate_tables([[str(filename)]]) + generator = hdf5._generate_tables([[hdf5_file_with_mixed_data_types]]) tables = list(generator) assert len(tables) == 1 @@ -785,23 +791,12 @@ def test_hdf5_mixed_data_types(tmp_path): assert len(table["compound_data_x"].to_pylist()) == 3 -def test_hdf5_column_name_collision_detection(tmp_path): +def test_hdf5_column_name_collision_detection(hdf5_file_with_complex_collision): """Test that column name collision detection works correctly.""" - filename = tmp_path / "collision.h5" - - with h5py.File(filename, "w") as f: - # Create a complex dataset - complex_data = np.array([1 + 2j, 3 + 4j], dtype=np.complex64) - f.create_dataset("data", data=complex_data) - - # Create a regular dataset that would collide with the complex real part - regular_data = np.array([1.0, 2.0], dtype=np.float32) - f.create_dataset("data_real", data=regular_data) # This should cause a collision - config = HDF5Config() hdf5 = HDF5() hdf5.config = config - hdf5.config.data_files = DataFilesDict({"train": [str(filename)]}) + hdf5.config.data_files = DataFilesDict({"train": [hdf5_file_with_complex_collision]}) # This should raise a ValueError due to column name collision dl_manager = StreamingDownloadManager() @@ -809,24 +804,12 @@ def test_hdf5_column_name_collision_detection(tmp_path): hdf5._split_generators(dl_manager) -def test_hdf5_compound_collision_detection(tmp_path): +def test_hdf5_compound_collision_detection(hdf5_file_with_compound_collision): """Test collision detection with compound types.""" - filename = tmp_path / "compound_collision.h5" - - with h5py.File(filename, "w") as f: - # Create a compound dataset - dt_compound = np.dtype([("x", "i4"), ("y", "f8")]) - compound_data = np.array([(1, 2.5), (3, 4.5)], dtype=dt_compound) - f.create_dataset("position", data=compound_data) - - # Create a regular dataset that would collide with compound field - regular_data = np.array([10, 20], dtype=np.int32) - f.create_dataset("position_x", data=regular_data) 
# This should cause a collision - config = HDF5Config() hdf5 = HDF5() hdf5.config = config - hdf5.config.data_files = DataFilesDict({"train": [str(filename)]}) + hdf5.config.data_files = DataFilesDict({"train": [hdf5_file_with_compound_collision]}) # This should raise a ValueError due to column name collision dl_manager = StreamingDownloadManager() From f9c7cf36d9bdc60d80be68fcd5a98610257bba6c Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Sat, 19 Jul 2025 01:52:58 -0400 Subject: [PATCH 07/16] explicit h5py dependency --- setup.py | 1 + src/datasets/packaged_modules/hdf5/hdf5.py | 26 +++++++++++++++------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/setup.py b/setup.py index 5e58064212a..74fc06bb199 100644 --- a/setup.py +++ b/setup.py @@ -166,6 +166,7 @@ "aiohttp", "elasticsearch>=7.17.12,<8.0.0", # 8.0 asks users to provide hosts or cloud_id when instantiating ElasticSearch(); 7.9.1 has legacy numpy.float_ which was fixed in https://github.com/elastic/elasticsearch-py/pull/2551. "faiss-cpu>=1.8.0.post1", # Pins numpy < 2 + "h5py", # FIXME: probably needs a lower bound "jax>=0.3.14; sys_platform != 'win32'", "jaxlib>=0.3.14; sys_platform != 'win32'", "lz4", diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 3cf5bc1b7c5..ed4f0897f31 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -1,12 +1,11 @@ import itertools from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import numpy as np import pyarrow as pa import datasets -import h5py from datasets.features.features import ( Array2D, Array3D, @@ -21,6 +20,9 @@ from datasets.table import table_cast +if TYPE_CHECKING: + import h5py + logger = datasets.utils.logging.get_logger(__name__) EXTENSIONS = [".h5", ".hdf5"] @@ -56,6 +58,8 @@ def _info(self): return datasets.DatasetInfo(features=self.config.features) def _split_generators(self, dl_manager): + import h5py + if not self.config.data_files: raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") dl_manager.download_config.extract_on_the_fly = True @@ -119,6 +123,8 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table: return pa_table def _generate_tables(self, files): + import h5py + batch_size_cfg = self.config.batch_size for file_idx, file in enumerate(itertools.chain.from_iterable(files)): try: @@ -179,7 +185,9 @@ def _generate_tables(self, files): raise -def _traverse_datasets(h5_obj, prefix: str = "") -> Dict[str, h5py.Dataset]: +def _traverse_datasets(h5_obj, prefix: str = "") -> Dict[str, "h5py.Dataset"]: + import h5py + mapping: Dict[str, h5py.Dataset] = {} def collect_datasets(name, obj): @@ -201,7 +209,7 @@ def _is_complex_dtype(dtype: np.dtype) -> bool: return dtype.kind == "c" -def _create_complex_features(base_path: str, dset: h5py.Dataset) -> Dict[str, Value]: +def _create_complex_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Value]: """Create separate features for real and imaginary parts of complex data. NOTE: Always uses float64 for the real and imaginary parts. 
@@ -212,7 +220,7 @@ def _create_complex_features(base_path: str, dset: h5py.Dataset) -> Dict[str, Va return {f"{base_path}_real": Value("float64"), f"{base_path}_imag": Value("float64")} -def _convert_complex_to_separate_columns(base_path: str, arr: np.ndarray, dset: h5py.Dataset) -> Dict[str, pa.Array]: +def _convert_complex_to_separate_columns(base_path: str, arr: np.ndarray, dset: "h5py.Dataset") -> Dict[str, pa.Array]: """Convert complex array to separate real and imaginary columns.""" result = {} result[f"{base_path}_real"] = datasets.features.features.numpy_to_pyarrow_listarray(arr.real) @@ -236,7 +244,7 @@ def __init__(self, dtype): self.names = dtype.names -def _create_compound_features(base_path: str, dset: h5py.Dataset) -> Dict[str, Any]: +def _create_compound_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Any]: """Create separate features for each field in compound data.""" field_names = list(dset.dtype.names) logger.info( @@ -262,7 +270,9 @@ def _create_compound_features(base_path: str, dset: h5py.Dataset) -> Dict[str, A return features -def _convert_compound_to_separate_columns(base_path: str, arr: np.ndarray, dset: h5py.Dataset) -> Dict[str, pa.Array]: +def _convert_compound_to_separate_columns( + base_path: str, arr: np.ndarray, dset: "h5py.Dataset" +) -> Dict[str, pa.Array]: """Convert compound array to separate columns for each field.""" result = {} for field_name in list(dset.dtype.names): @@ -314,7 +324,7 @@ def _convert_vlen_string_to_array(arr: np.ndarray) -> pa.Array: # └───────────┘ -def _infer_feature_from_dataset(dset: h5py.Dataset): +def _infer_feature_from_dataset(dset: "h5py.Dataset"): # non-string varlen if hasattr(dset.dtype, "metadata") and dset.dtype.metadata and "vlen" in dset.dtype.metadata: vlen_dtype = dset.dtype.metadata["vlen"] From d0315a38d094dfe969dcba078c0807f6eb211ed3 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Sat, 19 Jul 2025 02:08:57 -0400 Subject: [PATCH 08/16] allow mismatched lengths if ignored --- src/datasets/packaged_modules/hdf5/hdf5.py | 16 +++++- tests/packaged_modules/test_hdf5.py | 65 +++++++++++++++++++++- 2 files changed, 77 insertions(+), 4 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index ed4f0897f31..1bb094a3c6d 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -133,9 +133,21 @@ def _generate_tables(self, files): if not dataset_map: logger.warning(f"File '{file}' contains no data, skipping...") continue + + if self.config.columns is not None: + filtered_dataset_map = { + path: dset for path, dset in dataset_map.items() if path in self.config.columns + } + if not filtered_dataset_map: + logger.warning( + f"No datasets match the specified columns {self.config.columns}, skipping..." 
+ ) + continue + dataset_map = filtered_dataset_map + + # Sanity-check lengths for selected datasets first_dset = next(iter(dataset_map.values())) num_rows = first_dset.shape[0] - # Sanity-check lengths for path, dset in dataset_map.items(): if dset.shape[0] != num_rows: raise ValueError( @@ -146,8 +158,6 @@ def _generate_tables(self, files): end = min(start + effective_batch, num_rows) batch_dict = {} for path, dset in dataset_map.items(): - if self.config.columns is not None and path not in self.config.columns: - continue arr = dset[start:end] # Handle variable-length arrays diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index 602616a7e4a..30a8ab298d3 100644 --- a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -190,7 +190,16 @@ def hdf5_file_with_mismatched_lengths(tmp_path): with h5py.File(filename, "w") as f: f.create_dataset("data1", data=np.arange(5, dtype=np.int32)) - f.create_dataset("data2", data=np.arange(3, dtype=np.int32)) # Different length + # Dataset with 3 rows (mismatched) + f.create_dataset("data2", data=np.arange(3, dtype=np.int32)) + f.create_dataset("data3", data=np.random.randn(5, 3, 4).astype(np.float32)) + f.create_dataset("data4", data=np.arange(5, dtype=np.float64) / 10.0) + f.create_dataset("data5", data=np.array([True, False, True, False, True])) + var_strings = ["short", "medium length", "very long string", "tiny", "another string"] + dt = h5py.vlen_dtype(str) + dset = f.create_dataset("data6", (5,), dtype=dt) + for i, s in enumerate(var_strings): + dset[i] = s return str(filename) @@ -815,3 +824,57 @@ def test_hdf5_compound_collision_detection(hdf5_file_with_compound_collision): dl_manager = StreamingDownloadManager() with pytest.raises(ValueError, match="Column name collision detected"): hdf5._split_generators(dl_manager) + + +def test_hdf5_mismatched_lengths_with_column_filtering(hdf5_file_with_mismatched_lengths): + """Test that mismatched dataset lengths are ignored when the mismatched dataset is excluded via columns config.""" + config = HDF5Config(columns=["data1"]) + hdf5 = HDF5() + hdf5.config = config + + generator = hdf5._generate_tables([[hdf5_file_with_mismatched_lengths]]) + tables = list(generator) + + # Should work without error since we're only including the first dataset + assert len(tables) == 1 + _, table = tables[0] + + # Check that only the specified column is present + expected_columns = {"data1"} + assert set(table.column_names) == expected_columns + assert "data2" not in table.column_names + + # Check the data + data1_values = table["data1"].to_pylist() + assert data1_values == [0, 1, 2, 3, 4] + + # Test 2: Include multiple compatible datasets (all with 5 rows) + config2 = HDF5Config(columns=["data1", "data3", "data4", "data5", "data6"]) + hdf5.config = config2 + + generator2 = hdf5._generate_tables([[hdf5_file_with_mismatched_lengths]]) + tables2 = list(generator2) + + # Should work without error since we're excluding the mismatched dataset + assert len(tables2) == 1 + _, table2 = tables2[0] + + # Check that all specified columns are present + expected_columns2 = {"data1", "data3", "data4", "data5", "data6"} + assert set(table2.column_names) == expected_columns2 + assert "data2" not in table2.column_names + + # Check data types and values + assert table2["data1"].to_pylist() == [0, 1, 2, 3, 4] # int32 + assert len(table2["data3"].to_pylist()) == 5 # Array2D + assert len(table2["data3"].to_pylist()[0]) == 3 # 3 rows in each 2D array + assert 
len(table2["data3"].to_pylist()[0][0]) == 4 # 4 columns in each 2D array + np.testing.assert_allclose(table2["data4"].to_pylist(), [0.0, 0.1, 0.2, 0.3, 0.4], rtol=1e-6) # float64 + assert table2["data5"].to_pylist() == [True, False, True, False, True] # boolean + assert table2["data6"].to_pylist() == [ + "short", + "medium length", + "very long string", + "tiny", + "another string", + ] # vlen string From c650c6fd3087e7b0d28b4beb3c66d74da8924db2 Mon Sep 17 00:00:00 2001 From: Michael Klamkin Date: Thu, 24 Jul 2025 16:31:36 -0400 Subject: [PATCH 09/16] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 74fc06bb199..58e0b155666 100644 --- a/setup.py +++ b/setup.py @@ -166,7 +166,7 @@ "aiohttp", "elasticsearch>=7.17.12,<8.0.0", # 8.0 asks users to provide hosts or cloud_id when instantiating ElasticSearch(); 7.9.1 has legacy numpy.float_ which was fixed in https://github.com/elastic/elasticsearch-py/pull/2551. "faiss-cpu>=1.8.0.post1", # Pins numpy < 2 - "h5py", # FIXME: probably needs a lower bound + "h5py", "jax>=0.3.14; sys_platform != 'win32'", "jaxlib>=0.3.14; sys_platform != 'win32'", "lz4", From c3c567d98a1135ade6f396287998d41cd4edf5dd Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Mon, 11 Aug 2025 13:27:55 -0400 Subject: [PATCH 10/16] Sequence -> List --- src/datasets/packaged_modules/hdf5/hdf5.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 1bb094a3c6d..95485018a8c 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -12,7 +12,7 @@ Array4D, Array5D, LargeList, - Sequence, + List, Value, _ArrayXD, _arrow_to_datasets_dtype, @@ -339,7 +339,7 @@ def _infer_feature_from_dataset(dset: "h5py.Dataset"): if hasattr(dset.dtype, "metadata") and dset.dtype.metadata and "vlen" in dset.dtype.metadata: vlen_dtype = dset.dtype.metadata["vlen"] inner_feature = _np_to_pa_to_hf_value(vlen_dtype) - return Sequence(inner_feature) + return List(inner_feature) value_feature = _np_to_pa_to_hf_value(dset.dtype) dtype_str = value_feature.dtype @@ -349,7 +349,7 @@ def _infer_feature_from_dataset(dset: "h5py.Dataset"): if rank == 0: return value_feature elif rank == 1: - return Sequence(value_feature, length=value_shape[0]) + return List(value_feature, length=value_shape[0]) elif rank <= 5: return _sized_arrayxd(rank)(shape=value_shape, dtype=dtype_str) else: @@ -359,7 +359,7 @@ def _infer_feature_from_dataset(dset: "h5py.Dataset"): def _has_zero_dimensions(feature): if isinstance(feature, _ArrayXD): return any(dim == 0 for dim in feature.shape) - elif isinstance(feature, Sequence): # also gets regular List + elif isinstance(feature, List): # also gets regular List return feature.length == 0 or _has_zero_dimensions(feature.feature) elif isinstance(feature, LargeList): return _has_zero_dimensions(feature.feature) From babd919f68e4e46dbc811c40f188d672bdd1a34f Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Mon, 11 Aug 2025 13:28:53 -0400 Subject: [PATCH 11/16] Sequence -> List cont. 
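
Companion to the previous commit's library-side rename: the test suite now imports `List` from `datasets` and asserts against it instead of `Sequence`. Illustrative snippet taken from the updated assertions below:

    assert isinstance(features["vector_1d"], List)
    assert features["vector_1d"].length == 10

Docstrings and comments that still mention Sequence are updated to match.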
--- tests/packaged_modules/test_hdf5.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index 30a8ab298d3..87070dd599a 100644 --- a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -2,7 +2,7 @@ import pytest import h5py -from datasets import Array2D, Array3D, Array4D, Features, Sequence, Value +from datasets import Array2D, Array3D, Array4D, Features, List, Value from datasets.builder import InvalidConfigName from datasets.data_files import DataFilesDict, DataFilesList from datasets.download.streaming_download_manager import StreamingDownloadManager @@ -579,8 +579,8 @@ def test_hdf5_feature_inference(hdf5_file_with_arrays): # (n_rows, 2, 3, 4, 5) -> Array4D with shape (2, 3, 4, 5) assert isinstance(features["tensor_4d"], Array4D) assert features["tensor_4d"].shape == (2, 3, 4, 5) - # (n_rows, 10) -> Sequence of length 10 - assert isinstance(features["vector_1d"], Sequence) + # (n_rows, 10) -> List of length 10 + assert isinstance(features["vector_1d"], List) assert features["vector_1d"].length == 10 @@ -600,9 +600,9 @@ def test_hdf5_vlen_feature_inference(hdf5_file_with_vlen_arrays): # Check specific feature types for variable-length arrays features = hdf5.info.features - # Variable-length arrays should become Sequence features by default (for small datasets) - assert isinstance(features["vlen_ints"], Sequence) - assert isinstance(features["mixed_data"], Sequence) + # Variable-length arrays should become List features by default (for small datasets) + assert isinstance(features["vlen_ints"], List) + assert isinstance(features["mixed_data"], List) # Check that the inner feature types are correct assert isinstance(features["vlen_ints"].feature, Value) From f709dae53ff99d95295da11f52a083491753c9a3 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Mon, 11 Aug 2025 13:37:58 -0400 Subject: [PATCH 12/16] Fix features.List and typing.List conflict --- src/datasets/packaged_modules/hdf5/hdf5.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 95485018a8c..cd4f54e2e1c 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -1,6 +1,6 @@ import itertools from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List as ListT, Optional import numpy as np import pyarrow as pa @@ -33,7 +33,7 @@ class HDF5Config(datasets.BuilderConfig): """BuilderConfig for HDF5.""" batch_size: Optional[int] = None - columns: Optional[List[str]] = None + columns: Optional[ListT[str]] = None features: Optional[datasets.Features] = None def __post_init__(self): From db2e76a8df165d10b2ee0373537a48e9995b31de Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Mon, 11 Aug 2025 13:51:21 -0400 Subject: [PATCH 13/16] Use `Features(dict)` for complex and compound --- src/datasets/packaged_modules/hdf5/hdf5.py | 114 ++++++++++-------- tests/packaged_modules/test_hdf5.py | 134 +++++---------------- 2 files changed, 93 insertions(+), 155 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index cd4f54e2e1c..792ef6c3422 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -1,6 +1,7 @@ import itertools from dataclasses import 
dataclass -from typing import TYPE_CHECKING, Any, Dict, List as ListT, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import List as ListT import numpy as np import pyarrow as pa @@ -11,6 +12,7 @@ Array3D, Array4D, Array5D, + Features, LargeList, List, Value, @@ -77,29 +79,16 @@ def _split_generators(self, dl_manager): dataset_map = _traverse_datasets(h5) features_dict = {} - def _check_column_collisions(new_columns, source_dataset_path): - """Check for column name collisions and raise informative errors.""" - for new_col in new_columns: - if new_col in features_dict: - raise ValueError( - f"Column name collision detected: '{new_col}' from dataset '{source_dataset_path}' " - f"conflicts with existing column. Consider renaming datasets in the HDF5 file." - ) - for path, dset in dataset_map.items(): if _is_complex_dtype(dset.dtype): complex_features = _create_complex_features(path, dset) - _check_column_collisions(complex_features.keys(), path) features_dict.update(complex_features) elif _is_compound_dtype(dset.dtype): compound_features = _create_compound_features(path, dset) - _check_column_collisions(compound_features.keys(), path) features_dict.update(compound_features) elif _is_vlen_string_dtype(dset.dtype): - _check_column_collisions([path], path) features_dict[path] = Value("string") else: - _check_column_collisions([path], path) feat = _infer_feature_from_dataset(dset) features_dict[path] = feat self.info.features = datasets.Features(features_dict) @@ -175,9 +164,9 @@ def _generate_tables(self, files): pa_arr = datasets.features.features.numpy_to_pyarrow_listarray(arr) batch_dict[path] = pa_arr elif _is_complex_dtype(dset.dtype): - batch_dict.update(_convert_complex_to_separate_columns(path, arr, dset)) + batch_dict.update(_convert_complex_to_nested(path, arr, dset)) elif _is_compound_dtype(dset.dtype): - batch_dict.update(_convert_compound_to_separate_columns(path, arr, dset)) + batch_dict.update(_convert_compound_to_nested(path, arr, dset)) elif dset.dtype.kind == "O": raise ValueError( f"Object dtype dataset '{path}' is not supported. " @@ -219,22 +208,36 @@ def _is_complex_dtype(dtype: np.dtype) -> bool: return dtype.kind == "c" -def _create_complex_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Value]: - """Create separate features for real and imaginary parts of complex data. +def _create_complex_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Any]: + """Create Features for complex data with real and imaginary parts `real` and `imag`. NOTE: Always uses float64 for the real and imaginary parts. 
""" logger.info( - f"Complex dataset '{base_path}' (dtype: {dset.dtype}) split into '{base_path}_real' and '{base_path}_imag'" + f"Complex dataset '{base_path}' (dtype: {dset.dtype}) represented as nested structure with 'real' and 'imag' fields" + ) + nested_features = Features( + { + "real": Value("float64"), + "imag": Value("float64"), + } ) - return {f"{base_path}_real": Value("float64"), f"{base_path}_imag": Value("float64")} + return {base_path: nested_features} -def _convert_complex_to_separate_columns(base_path: str, arr: np.ndarray, dset: "h5py.Dataset") -> Dict[str, pa.Array]: - """Convert complex array to separate real and imaginary columns.""" +def _convert_complex_to_nested(base_path: str, arr: np.ndarray, dset: "h5py.Dataset") -> Dict[str, pa.Array]: + """Convert complex to Features with real and imaginary parts `real` and `imag`.""" result = {} - result[f"{base_path}_real"] = datasets.features.features.numpy_to_pyarrow_listarray(arr.real) - result[f"{base_path}_imag"] = datasets.features.features.numpy_to_pyarrow_listarray(arr.imag) + + def _convert_complex_scalar(complex_val): + """Convert a complex scalar to a dictionary.""" + if complex_val.size == 1: + return {"real": float(complex_val.item().real), "imag": float(complex_val.item().imag)} + else: + # For multi-dimensional arrays, convert to list + return {"real": complex_val.real.tolist(), "imag": complex_val.imag.tolist()} + + result[base_path] = pa.array([_convert_complex_scalar(complex_val) for complex_val in arr]) return result @@ -255,51 +258,56 @@ def __init__(self, dtype): def _create_compound_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Any]: - """Create separate features for each field in compound data.""" + """Create nested features for compound data with field names as keys.""" field_names = list(dset.dtype.names) logger.info( - f"Compound dataset '{base_path}' (dtype: {dset.dtype}) flattened into {len(field_names)} columns: {field_names}" + f"Compound dataset '{base_path}' (dtype: {dset.dtype}) represented as nested Features with fields: {field_names}" ) - features = {} + nested_features_dict = {} for field_name in field_names: field_dtype = dset.dtype[field_name] - field_path = f"{base_path}_{field_name}" if _is_complex_dtype(field_dtype): - features[f"{field_path}_real"] = Value("float64") - features[f"{field_path}_imag"] = Value("float64") + nested_features_dict[field_name] = Features( + { + "real": Value("float64"), + "imag": Value("float64"), + } + ) elif _is_compound_dtype(field_dtype): mock_dset = _MockDataset(field_dtype) - nested_features = _create_compound_features(field_path, mock_dset) - features.update(nested_features) + nested_features_dict[field_name] = _create_compound_features(field_name, mock_dset)[field_name] else: - value_feature = _np_to_pa_to_hf_value(field_dtype) - features[field_path] = value_feature + nested_features_dict[field_name] = _np_to_pa_to_hf_value(field_dtype) - return features + nested_features = Features(nested_features_dict) + return {base_path: nested_features} -def _convert_compound_to_separate_columns( - base_path: str, arr: np.ndarray, dset: "h5py.Dataset" -) -> Dict[str, pa.Array]: - """Convert compound array to separate columns for each field.""" +def _convert_compound_to_nested(base_path: str, arr: np.ndarray, dset: "h5py.Dataset") -> Dict[str, pa.Array]: + """Convert compound array to nested structure with field names as keys.""" result = {} - for field_name in list(dset.dtype.names): - field_dtype = dset.dtype[field_name] - field_path = 
f"{base_path}_{field_name}" - field_data = arr[field_name] - - if _is_complex_dtype(field_dtype): - result[f"{field_path}_real"] = datasets.features.features.numpy_to_pyarrow_listarray(field_data.real) - result[f"{field_path}_imag"] = datasets.features.features.numpy_to_pyarrow_listarray(field_data.imag) - elif _is_compound_dtype(field_dtype): - mock_dset = _MockDataset(field_dtype) - nested_result = _convert_compound_to_separate_columns(field_path, field_data, mock_dset) - result.update(nested_result) - else: - result[field_path] = datasets.features.features.numpy_to_pyarrow_listarray(field_data) + def _convert_compound_recursive(compound_arr, compound_dtype): + """Recursively convert compound array to nested structure.""" + nested_data = [] + for row in compound_arr: + row_dict = {} + for field_name in compound_dtype.names: + field_dtype = compound_dtype[field_name] + field_data = row[field_name] + + if _is_complex_dtype(field_dtype): + row_dict[field_name] = {"real": float(field_data.real), "imag": float(field_data.imag)} + elif _is_compound_dtype(field_dtype): + row_dict[field_name] = _convert_compound_recursive([field_data], field_dtype)[0] + else: + row_dict[field_name] = field_data.item() if field_data.size == 1 else field_data.tolist() + nested_data.append(row_dict) + return nested_data + + result[base_path] = pa.array(_convert_compound_recursive(arr, dset.dtype)) return result diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index 87070dd599a..0997302838b 100644 --- a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -1,7 +1,7 @@ +import h5py import numpy as np import pytest -import h5py from datasets import Array2D, Array3D, Array4D, Features, List, Value from datasets.builder import InvalidConfigName from datasets.data_files import DataFilesDict, DataFilesList @@ -257,41 +257,6 @@ def hdf5_file_with_mixed_data_types(tmp_path): return str(filename) -@pytest.fixture -def hdf5_file_with_complex_collision(tmp_path): - """Create an HDF5 file where complex dataset would collide with existing dataset name.""" - filename = tmp_path / "collision.h5" - - with h5py.File(filename, "w") as f: - # Create a complex dataset - complex_data = np.array([1 + 2j, 3 + 4j], dtype=np.complex64) - f.create_dataset("data", data=complex_data) - - # Create a regular dataset that would collide with the complex real part - regular_data = np.array([1.0, 2.0], dtype=np.float32) - f.create_dataset("data_real", data=regular_data) # This should cause a collision - - return str(filename) - - -@pytest.fixture -def hdf5_file_with_compound_collision(tmp_path): - """Create an HDF5 file where compound dataset would collide with existing dataset name.""" - filename = tmp_path / "compound_collision.h5" - - with h5py.File(filename, "w") as f: - # Create a compound dataset - dt_compound = np.dtype([("x", "i4"), ("y", "f8")]) - compound_data = np.array([(1, 2.5), (3, 4.5)], dtype=dt_compound) - f.create_dataset("position", data=compound_data) - - # Create a regular dataset that would collide with compound field - regular_data = np.array([10, 20], dtype=np.int32) - f.create_dataset("position_x", data=regular_data) # This should cause a collision - - return str(filename) - - def test_config_raises_when_invalid_name(): """Test that invalid config names raise an error.""" with pytest.raises(InvalidConfigName, match="Bad characters"): @@ -675,23 +640,21 @@ def test_hdf5_complex_numbers(hdf5_file_with_complex_data): assert len(tables) == 1 _, table = tables[0] - 
# Check that complex numbers are split into real/imaginary parts + # Check that complex numbers are represented as nested Features expected_columns = { - "complex_64_real", - "complex_64_imag", - "complex_128_real", - "complex_128_imag", - "complex_array_real", - "complex_array_imag", + "complex_64", + "complex_128", + "complex_array", } assert set(table.column_names) == expected_columns # Check complex_64 data - real_data = table["complex_64_real"].to_pylist() - imag_data = table["complex_64_imag"].to_pylist() - - assert real_data == [1.0, 3.0, 5.0, 7.0] - assert imag_data == [2.0, 4.0, 6.0, 8.0] + complex_64_data = table["complex_64"].to_pylist() + assert len(complex_64_data) == 4 + assert complex_64_data[0] == {"real": 1.0, "imag": 2.0} + assert complex_64_data[1] == {"real": 3.0, "imag": 4.0} + assert complex_64_data[2] == {"real": 5.0, "imag": 6.0} + assert complex_64_data[3] == {"real": 7.0, "imag": 8.0} def test_hdf5_compound_types(hdf5_file_with_compound_data): @@ -706,25 +669,20 @@ def test_hdf5_compound_types(hdf5_file_with_compound_data): assert len(tables) == 1 _, table = tables[0] - # Check that compound types are flattened into separate columns + # Check that compound types are represented as nested structures expected_columns = { - "simple_compound_x", - "simple_compound_y", - "complex_compound_real", - "complex_compound_imag", - "nested_compound_position_x", - "nested_compound_position_y", - "nested_compound_velocity_vx", - "nested_compound_velocity_vy", + "simple_compound", + "complex_compound", + "nested_compound", } assert set(table.column_names) == expected_columns # Check simple compound data - x_data = table["simple_compound_x"].to_pylist() - y_data = table["simple_compound_y"].to_pylist() - - assert x_data == [1, 3, 5] - assert y_data == [2.5, 4.5, 6.5] + simple_compound_data = table["simple_compound"].to_pylist() + assert len(simple_compound_data) == 3 + assert simple_compound_data[0] == {"x": 1, "y": 2.5} + assert simple_compound_data[1] == {"x": 3, "y": 4.5} + assert simple_compound_data[2] == {"x": 5, "y": 6.5} def test_hdf5_feature_inference_complex(hdf5_file_with_complex_data): @@ -743,10 +701,10 @@ def test_hdf5_feature_inference_complex(hdf5_file_with_complex_data): features = hdf5.info.features # Check complex number features - assert "complex_64_real" in features - assert "complex_64_imag" in features - assert features["complex_64_real"] == Value("float64") - assert features["complex_64_imag"] == Value("float64") + assert "complex_64" in features + assert isinstance(features["complex_64"], Features) + assert features["complex_64"]["real"] == Value("float64") + assert features["complex_64"]["imag"] == Value("float64") def test_hdf5_feature_inference_compound(hdf5_file_with_compound_data): @@ -765,10 +723,10 @@ def test_hdf5_feature_inference_compound(hdf5_file_with_compound_data): features = hdf5.info.features # Check compound type features - assert "simple_compound_x" in features - assert "simple_compound_y" in features - assert features["simple_compound_x"] == Value("int32") - assert features["simple_compound_y"] == Value("float64") + assert "simple_compound" in features + assert isinstance(features["simple_compound"], Features) + assert features["simple_compound"]["x"] == Value("int32") + assert features["simple_compound"]["y"] == Value("float64") def test_hdf5_mixed_data_types(hdf5_file_with_mixed_data_types): @@ -787,43 +745,15 @@ def test_hdf5_mixed_data_types(hdf5_file_with_mixed_data_types): expected_columns = { "regular_int", "regular_float", - 
"complex_data_real", - "complex_data_imag", - "compound_data_x", - "compound_data_y", + "complex_data", + "compound_data", } assert set(table.column_names) == expected_columns # Check data types assert table["regular_int"].to_pylist() == [0, 1, 2] - assert len(table["complex_data_real"].to_pylist()) == 3 - assert len(table["compound_data_x"].to_pylist()) == 3 - - -def test_hdf5_column_name_collision_detection(hdf5_file_with_complex_collision): - """Test that column name collision detection works correctly.""" - config = HDF5Config() - hdf5 = HDF5() - hdf5.config = config - hdf5.config.data_files = DataFilesDict({"train": [hdf5_file_with_complex_collision]}) - - # This should raise a ValueError due to column name collision - dl_manager = StreamingDownloadManager() - with pytest.raises(ValueError, match="Column name collision detected"): - hdf5._split_generators(dl_manager) - - -def test_hdf5_compound_collision_detection(hdf5_file_with_compound_collision): - """Test collision detection with compound types.""" - config = HDF5Config() - hdf5 = HDF5() - hdf5.config = config - hdf5.config.data_files = DataFilesDict({"train": [hdf5_file_with_compound_collision]}) - - # This should raise a ValueError due to column name collision - dl_manager = StreamingDownloadManager() - with pytest.raises(ValueError, match="Column name collision detected"): - hdf5._split_generators(dl_manager) + assert len(table["complex_data"].to_pylist()) == 3 + assert len(table["compound_data"].to_pylist()) == 3 def test_hdf5_mismatched_lengths_with_column_filtering(hdf5_file_with_mismatched_lengths): From b2e1894b7eeec83061cbee55d078a757453179ee Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Mon, 11 Aug 2025 13:56:45 -0400 Subject: [PATCH 14/16] Hardcode .hdf5 and .h5 extensions to point to hdf5 --- src/datasets/packaged_modules/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py index 863d7d31d81..515ff147b29 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -78,6 +78,8 @@ def _hash_python_lines(lines: list[str]) -> str: ".txt": ("text", {}), ".tar": ("webdataset", {}), ".xml": ("xml", {}), + ".hdf5": ("hdf5", {}), + ".h5": ("hdf5", {}), } _EXTENSION_TO_MODULE.update({ext: ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS}) @@ -87,8 +89,6 @@ def _hash_python_lines(lines: list[str]) -> str: _EXTENSION_TO_MODULE.update({ext.upper(): ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext: ("pdffolder", {}) for ext in pdffolder.PdfFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext.upper(): ("pdffolder", {}) for ext in pdffolder.PdfFolder.EXTENSIONS}) -_EXTENSION_TO_MODULE.update({ext: ("hdf5", {}) for ext in hdf5.EXTENSIONS}) -_EXTENSION_TO_MODULE.update({ext.upper(): ("hdf5", {}) for ext in hdf5.EXTENSIONS}) # Used to filter data files based on extensions given a module name _MODULE_TO_EXTENSIONS: dict[str, list[str]] = {} From 9b1550ad28ea1db8244689531c776dcec0ddc39b Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Mon, 11 Aug 2025 14:02:24 -0400 Subject: [PATCH 15/16] Update type hints from Any to Features --- src/datasets/packaged_modules/hdf5/hdf5.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py 
b/src/datasets/packaged_modules/hdf5/hdf5.py index 792ef6c3422..f240316eab1 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -1,6 +1,6 @@ import itertools from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Dict, Optional from typing import List as ListT import numpy as np @@ -208,7 +208,7 @@ def _is_complex_dtype(dtype: np.dtype) -> bool: return dtype.kind == "c" -def _create_complex_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Any]: +def _create_complex_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Features]: """Create Features for complex data with real and imaginary parts `real` and `imag`. NOTE: Always uses float64 for the real and imaginary parts. @@ -257,7 +257,7 @@ def __init__(self, dtype): self.names = dtype.names -def _create_compound_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Any]: +def _create_compound_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Features]: """Create nested features for compound data with field names as keys.""" field_names = list(dset.dtype.names) logger.info( From 2c4bfba70d525b0f9336b8e36b299d73d4a2f3e4 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Tue, 12 Aug 2025 15:46:23 -0400 Subject: [PATCH 16/16] Unsized List in zero dim case --- src/datasets/packaged_modules/hdf5/hdf5.py | 30 +++++++++++++--------- tests/packaged_modules/test_hdf5.py | 19 +++++++++++++- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index f240316eab1..36858c7dab3 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -102,13 +102,7 @@ def _split_generators(self, dl_manager): def _cast_table(self, pa_table: pa.Table) -> pa.Table: if self.info.features is not None: - relevant_features = { - col: self.info.features[col] for col in pa_table.column_names if col in self.info.features - } - has_zero_dims = any(_has_zero_dimensions(feature) for feature in relevant_features.values()) - # FIXME: pyarrow.lib.ArrowInvalid: list_size needs to be a strict positive integer - if not has_zero_dims: - pa_table = table_cast(pa_table, self.info.features.arrow_schema) + pa_table = table_cast(pa_table, self.info.features.arrow_schema) return pa_table def _generate_tables(self, files): @@ -175,7 +169,13 @@ def _generate_tables(self, files): f"See: https://docs.h5py.org/en/stable/special.html#variable-length-strings" ) else: - pa_arr = datasets.features.features.numpy_to_pyarrow_listarray(arr) + # If any non-batch dimension is zero, emit an unsized pa.list_ + # to avoid creating FixedSizeListArray with list_size=0. 
+ if any(dim == 0 for dim in dset.shape[1:]): + inner_type = pa.from_numpy_dtype(dset.dtype) + pa_arr = pa.array([[] for _ in arr], type=pa.list_(inner_type)) + else: + pa_arr = datasets.features.features.numpy_to_pyarrow_listarray(arr) batch_dict[path] = pa_arr pa_table = pa.Table.from_pydict(batch_dict) yield f"{file_idx}_{start}", self._cast_table(pa_table) @@ -351,15 +351,21 @@ def _infer_feature_from_dataset(dset: "h5py.Dataset"): value_feature = _np_to_pa_to_hf_value(dset.dtype) dtype_str = value_feature.dtype - value_shape = dset.shape[1:] - rank = len(value_shape) + dset_shape = dset.shape[1:] + if any(dim == 0 for dim in dset_shape): + logger.warning( + f"HDF5 to Arrow: Found a dataset named '{dset.name}' with shape {dset_shape} and dtype {dtype_str} that has a dimension with size 0. Shape information will be lost in the conversion to List({value_feature})." + ) + return List(value_feature) + + rank = len(dset_shape) if rank == 0: return value_feature elif rank == 1: - return List(value_feature, length=value_shape[0]) + return List(value_feature, length=dset_shape[0]) elif rank <= 5: - return _sized_arrayxd(rank)(shape=value_shape, dtype=dtype_str) + return _sized_arrayxd(rank)(shape=dset_shape, dtype=dtype_str) else: raise TypeError(f"Array{rank}D not supported. Maximum 5 dimensions allowed.") diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index 0997302838b..06329a9c430 100644 --- a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from datasets import Array2D, Array3D, Array4D, Features, List, Value +from datasets import Array2D, Array3D, Array4D, Features, List, Value, load_dataset from datasets.builder import InvalidConfigName from datasets.data_files import DataFilesDict, DataFilesList from datasets.download.streaming_download_manager import StreamingDownloadManager @@ -504,6 +504,23 @@ def test_hdf5_zero_dimensions_handling(hdf5_file_with_zero_dimensions, caplog): assert len(zero_dim_data) == 3 # 3 rows assert all(len(row) == 0 for row in zero_dim_data) # Each row is empty + # Check that shape info is lost + caplog.clear() + ds = load_dataset("hdf5", data_files=[hdf5_file_with_zero_dimensions], split="train") + assert all(isinstance(col, List) and col.length == -1 for col in ds.features.values()) + + # Check for the warnings + assert ( + len( + [ + record.message + for record in caplog.records + if record.levelname == "WARNING" and "dimension with size 0" in record.message + ] + ) + == 3 + ) + def test_hdf5_empty_file_warning(empty_hdf5_file, caplog): """Test that empty files (no datasets) are skipped with a warning."""