
Commit db2e76a

Use Features(dict) for complex and compound
1 parent f709dae commit db2e76a
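
In short: a complex or compound HDF5 dataset now becomes a single nested column instead of several flattened "<path>_<suffix>" columns. A minimal before/after sketch of the schema change (the column name "signal" is a made-up example; Features and Value are the datasets feature types used in the diff below):

from datasets import Features, Value

# Before this commit: a complex dataset "signal" was flattened into two columns.
before = Features({
    "signal_real": Value("float64"),
    "signal_imag": Value("float64"),
})

# After this commit: one "signal" column holding a {"real", "imag"} struct.
after = Features({
    "signal": {"real": Value("float64"), "imag": Value("float64")},
})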

2 files changed: +93 -155 lines changed

src/datasets/packaged_modules/hdf5/hdf5.py

Lines changed: 61 additions & 53 deletions
@@ -1,6 +1,7 @@
 import itertools
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List as ListT, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import List as ListT

 import numpy as np
 import pyarrow as pa
@@ -11,6 +12,7 @@
     Array3D,
     Array4D,
     Array5D,
+    Features,
     LargeList,
     List,
     Value,
@@ -77,29 +79,16 @@ def _split_generators(self, dl_manager):
             dataset_map = _traverse_datasets(h5)
             features_dict = {}

-            def _check_column_collisions(new_columns, source_dataset_path):
-                """Check for column name collisions and raise informative errors."""
-                for new_col in new_columns:
-                    if new_col in features_dict:
-                        raise ValueError(
-                            f"Column name collision detected: '{new_col}' from dataset '{source_dataset_path}' "
-                            f"conflicts with existing column. Consider renaming datasets in the HDF5 file."
-                        )
-
             for path, dset in dataset_map.items():
                 if _is_complex_dtype(dset.dtype):
                     complex_features = _create_complex_features(path, dset)
-                    _check_column_collisions(complex_features.keys(), path)
                     features_dict.update(complex_features)
                 elif _is_compound_dtype(dset.dtype):
                     compound_features = _create_compound_features(path, dset)
-                    _check_column_collisions(compound_features.keys(), path)
                     features_dict.update(compound_features)
                 elif _is_vlen_string_dtype(dset.dtype):
-                    _check_column_collisions([path], path)
                    features_dict[path] = Value("string")
                 else:
-                    _check_column_collisions([path], path)
                     feat = _infer_feature_from_dataset(dset)
                     features_dict[path] = feat
             self.info.features = datasets.Features(features_dict)
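
Note that the _check_column_collisions helper is gone entirely: with the nested representation, a sibling HDF5 dataset named "data_real" can no longer clash with columns derived from a complex dataset "data", so the check (and its tests, removed further down) has nothing left to guard against. A small sketch of a schema the old flattening would have rejected, using the names from the removed collision fixture:

from datasets import Features, Value

# "data" (complex) and "data_real" (float32) now coexist as distinct top-level
# columns; under the old "data_real"/"data_imag" flattening these collided.
features = Features({
    "data": {"real": Value("float64"), "imag": Value("float64")},
    "data_real": Value("float32"),
})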
@@ -175,9 +164,9 @@ def _generate_tables(self, files):
                         pa_arr = datasets.features.features.numpy_to_pyarrow_listarray(arr)
                         batch_dict[path] = pa_arr
                     elif _is_complex_dtype(dset.dtype):
-                        batch_dict.update(_convert_complex_to_separate_columns(path, arr, dset))
+                        batch_dict.update(_convert_complex_to_nested(path, arr, dset))
                     elif _is_compound_dtype(dset.dtype):
-                        batch_dict.update(_convert_compound_to_separate_columns(path, arr, dset))
+                        batch_dict.update(_convert_compound_to_nested(path, arr, dset))
                     elif dset.dtype.kind == "O":
                         raise ValueError(
                             f"Object dtype dataset '{path}' is not supported. "
@@ -219,22 +208,36 @@ def _is_complex_dtype(dtype: np.dtype) -> bool:
     return dtype.kind == "c"


-def _create_complex_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Value]:
-    """Create separate features for real and imaginary parts of complex data.
+def _create_complex_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Any]:
+    """Create Features for complex data with real and imaginary parts `real` and `imag`.

     NOTE: Always uses float64 for the real and imaginary parts.
     """
     logger.info(
-        f"Complex dataset '{base_path}' (dtype: {dset.dtype}) split into '{base_path}_real' and '{base_path}_imag'"
+        f"Complex dataset '{base_path}' (dtype: {dset.dtype}) represented as nested structure with 'real' and 'imag' fields"
+    )
+    nested_features = Features(
+        {
+            "real": Value("float64"),
+            "imag": Value("float64"),
+        }
     )
-    return {f"{base_path}_real": Value("float64"), f"{base_path}_imag": Value("float64")}
+    return {base_path: nested_features}


-def _convert_complex_to_separate_columns(base_path: str, arr: np.ndarray, dset: "h5py.Dataset") -> Dict[str, pa.Array]:
-    """Convert complex array to separate real and imaginary columns."""
+def _convert_complex_to_nested(base_path: str, arr: np.ndarray, dset: "h5py.Dataset") -> Dict[str, pa.Array]:
+    """Convert complex to Features with real and imaginary parts `real` and `imag`."""
     result = {}
-    result[f"{base_path}_real"] = datasets.features.features.numpy_to_pyarrow_listarray(arr.real)
-    result[f"{base_path}_imag"] = datasets.features.features.numpy_to_pyarrow_listarray(arr.imag)
+
+    def _convert_complex_scalar(complex_val):
+        """Convert a complex scalar to a dictionary."""
+        if complex_val.size == 1:
+            return {"real": float(complex_val.item().real), "imag": float(complex_val.item().imag)}
+        else:
+            # For multi-dimensional arrays, convert to list
+            return {"real": complex_val.real.tolist(), "imag": complex_val.imag.tolist()}
+
+    result[base_path] = pa.array([_convert_complex_scalar(complex_val) for complex_val in arr])
     return result

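As a rough illustration of the conversion above, a standalone sketch that mirrors _convert_complex_scalar for scalar elements (the input array is made up):

import numpy as np
import pyarrow as pa

arr = np.array([1 + 2j, 3 + 4j], dtype=np.complex64)

# Each complex value becomes a {"real": ..., "imag": ...} dict, and the list
# of dicts becomes a single Arrow struct column.
rows = [{"real": float(v.real), "imag": float(v.imag)} for v in arr]
col = pa.array(rows)
print(col.type)         # struct<real: double, imag: double>
print(col.to_pylist())  # [{'real': 1.0, 'imag': 2.0}, {'real': 3.0, 'imag': 4.0}]
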
@@ -255,51 +258,56 @@ def __init__(self, dtype):


 def _create_compound_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Any]:
-    """Create separate features for each field in compound data."""
+    """Create nested features for compound data with field names as keys."""
     field_names = list(dset.dtype.names)
     logger.info(
-        f"Compound dataset '{base_path}' (dtype: {dset.dtype}) flattened into {len(field_names)} columns: {field_names}"
+        f"Compound dataset '{base_path}' (dtype: {dset.dtype}) represented as nested Features with fields: {field_names}"
     )

-    features = {}
+    nested_features_dict = {}
     for field_name in field_names:
         field_dtype = dset.dtype[field_name]
-        field_path = f"{base_path}_{field_name}"

         if _is_complex_dtype(field_dtype):
-            features[f"{field_path}_real"] = Value("float64")
-            features[f"{field_path}_imag"] = Value("float64")
+            nested_features_dict[field_name] = Features(
+                {
+                    "real": Value("float64"),
+                    "imag": Value("float64"),
+                }
+            )
         elif _is_compound_dtype(field_dtype):
             mock_dset = _MockDataset(field_dtype)
-            nested_features = _create_compound_features(field_path, mock_dset)
-            features.update(nested_features)
+            nested_features_dict[field_name] = _create_compound_features(field_name, mock_dset)[field_name]
         else:
-            value_feature = _np_to_pa_to_hf_value(field_dtype)
-            features[field_path] = value_feature
+            nested_features_dict[field_name] = _np_to_pa_to_hf_value(field_dtype)

-    return features
+    nested_features = Features(nested_features_dict)
+    return {base_path: nested_features}


-def _convert_compound_to_separate_columns(
-    base_path: str, arr: np.ndarray, dset: "h5py.Dataset"
-) -> Dict[str, pa.Array]:
-    """Convert compound array to separate columns for each field."""
+def _convert_compound_to_nested(base_path: str, arr: np.ndarray, dset: "h5py.Dataset") -> Dict[str, pa.Array]:
+    """Convert compound array to nested structure with field names as keys."""
     result = {}
-    for field_name in list(dset.dtype.names):
-        field_dtype = dset.dtype[field_name]
-        field_path = f"{base_path}_{field_name}"
-        field_data = arr[field_name]
-
-        if _is_complex_dtype(field_dtype):
-            result[f"{field_path}_real"] = datasets.features.features.numpy_to_pyarrow_listarray(field_data.real)
-            result[f"{field_path}_imag"] = datasets.features.features.numpy_to_pyarrow_listarray(field_data.imag)
-        elif _is_compound_dtype(field_dtype):
-            mock_dset = _MockDataset(field_dtype)
-            nested_result = _convert_compound_to_separate_columns(field_path, field_data, mock_dset)
-            result.update(nested_result)
-        else:
-            result[field_path] = datasets.features.features.numpy_to_pyarrow_listarray(field_data)

+    def _convert_compound_recursive(compound_arr, compound_dtype):
+        """Recursively convert compound array to nested structure."""
+        nested_data = []
+        for row in compound_arr:
+            row_dict = {}
+            for field_name in compound_dtype.names:
+                field_dtype = compound_dtype[field_name]
+                field_data = row[field_name]
+
+                if _is_complex_dtype(field_dtype):
+                    row_dict[field_name] = {"real": float(field_data.real), "imag": float(field_data.imag)}
+                elif _is_compound_dtype(field_dtype):
+                    row_dict[field_name] = _convert_compound_recursive([field_data], field_dtype)[0]
+                else:
+                    row_dict[field_name] = field_data.item() if field_data.size == 1 else field_data.tolist()
+            nested_data.append(row_dict)
+        return nested_data
+
+    result[base_path] = pa.array(_convert_compound_recursive(arr, dset.dtype))
     return result

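The compound path applies the same idea recursively. A standalone sketch of what _convert_compound_recursive produces for a flat compound dtype (the dtype and values are made up, echoing the test fixtures):

import numpy as np
import pyarrow as pa

dt = np.dtype([("x", "i4"), ("y", "f8")])
arr = np.array([(1, 2.5), (3, 4.5)], dtype=dt)

# One dict per row: field names become struct fields rather than
# "<path>_<field>" columns.
rows = [{name: row[name].item() for name in dt.names} for row in arr]
col = pa.array(rows)
print(col.to_pylist())  # [{'x': 1, 'y': 2.5}, {'x': 3, 'y': 4.5}]
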
tests/packaged_modules/test_hdf5.py

Lines changed: 32 additions & 102 deletions
@@ -1,7 +1,7 @@
+import h5py
 import numpy as np
 import pytest

-import h5py
 from datasets import Array2D, Array3D, Array4D, Features, List, Value
 from datasets.builder import InvalidConfigName
 from datasets.data_files import DataFilesDict, DataFilesList
@@ -257,41 +257,6 @@ def hdf5_file_with_mixed_data_types(tmp_path):
     return str(filename)


-@pytest.fixture
-def hdf5_file_with_complex_collision(tmp_path):
-    """Create an HDF5 file where complex dataset would collide with existing dataset name."""
-    filename = tmp_path / "collision.h5"
-
-    with h5py.File(filename, "w") as f:
-        # Create a complex dataset
-        complex_data = np.array([1 + 2j, 3 + 4j], dtype=np.complex64)
-        f.create_dataset("data", data=complex_data)
-
-        # Create a regular dataset that would collide with the complex real part
-        regular_data = np.array([1.0, 2.0], dtype=np.float32)
-        f.create_dataset("data_real", data=regular_data)  # This should cause a collision
-
-    return str(filename)
-
-
-@pytest.fixture
-def hdf5_file_with_compound_collision(tmp_path):
-    """Create an HDF5 file where compound dataset would collide with existing dataset name."""
-    filename = tmp_path / "compound_collision.h5"
-
-    with h5py.File(filename, "w") as f:
-        # Create a compound dataset
-        dt_compound = np.dtype([("x", "i4"), ("y", "f8")])
-        compound_data = np.array([(1, 2.5), (3, 4.5)], dtype=dt_compound)
-        f.create_dataset("position", data=compound_data)
-
-        # Create a regular dataset that would collide with compound field
-        regular_data = np.array([10, 20], dtype=np.int32)
-        f.create_dataset("position_x", data=regular_data)  # This should cause a collision
-
-    return str(filename)
-
-
 def test_config_raises_when_invalid_name():
     """Test that invalid config names raise an error."""
     with pytest.raises(InvalidConfigName, match="Bad characters"):
@@ -675,23 +640,21 @@ def test_hdf5_complex_numbers(hdf5_file_with_complex_data):
     assert len(tables) == 1
     _, table = tables[0]

-    # Check that complex numbers are split into real/imaginary parts
+    # Check that complex numbers are represented as nested Features
     expected_columns = {
-        "complex_64_real",
-        "complex_64_imag",
-        "complex_128_real",
-        "complex_128_imag",
-        "complex_array_real",
-        "complex_array_imag",
+        "complex_64",
+        "complex_128",
+        "complex_array",
     }
     assert set(table.column_names) == expected_columns

     # Check complex_64 data
-    real_data = table["complex_64_real"].to_pylist()
-    imag_data = table["complex_64_imag"].to_pylist()
-
-    assert real_data == [1.0, 3.0, 5.0, 7.0]
-    assert imag_data == [2.0, 4.0, 6.0, 8.0]
+    complex_64_data = table["complex_64"].to_pylist()
+    assert len(complex_64_data) == 4
+    assert complex_64_data[0] == {"real": 1.0, "imag": 2.0}
+    assert complex_64_data[1] == {"real": 3.0, "imag": 4.0}
+    assert complex_64_data[2] == {"real": 5.0, "imag": 6.0}
+    assert complex_64_data[3] == {"real": 7.0, "imag": 8.0}


 def test_hdf5_compound_types(hdf5_file_with_compound_data):
@@ -706,25 +669,20 @@ def test_hdf5_compound_types(hdf5_file_with_compound_data):
     assert len(tables) == 1
     _, table = tables[0]

-    # Check that compound types are flattened into separate columns
+    # Check that compound types are represented as nested structures
     expected_columns = {
-        "simple_compound_x",
-        "simple_compound_y",
-        "complex_compound_real",
-        "complex_compound_imag",
-        "nested_compound_position_x",
-        "nested_compound_position_y",
-        "nested_compound_velocity_vx",
-        "nested_compound_velocity_vy",
+        "simple_compound",
+        "complex_compound",
+        "nested_compound",
     }
     assert set(table.column_names) == expected_columns

     # Check simple compound data
-    x_data = table["simple_compound_x"].to_pylist()
-    y_data = table["simple_compound_y"].to_pylist()
-
-    assert x_data == [1, 3, 5]
-    assert y_data == [2.5, 4.5, 6.5]
+    simple_compound_data = table["simple_compound"].to_pylist()
+    assert len(simple_compound_data) == 3
+    assert simple_compound_data[0] == {"x": 1, "y": 2.5}
+    assert simple_compound_data[1] == {"x": 3, "y": 4.5}
+    assert simple_compound_data[2] == {"x": 5, "y": 6.5}


 def test_hdf5_feature_inference_complex(hdf5_file_with_complex_data):
@@ -743,10 +701,10 @@ def test_hdf5_feature_inference_complex(hdf5_file_with_complex_data):
     features = hdf5.info.features

     # Check complex number features
-    assert "complex_64_real" in features
-    assert "complex_64_imag" in features
-    assert features["complex_64_real"] == Value("float64")
-    assert features["complex_64_imag"] == Value("float64")
+    assert "complex_64" in features
+    assert isinstance(features["complex_64"], Features)
+    assert features["complex_64"]["real"] == Value("float64")
+    assert features["complex_64"]["imag"] == Value("float64")


 def test_hdf5_feature_inference_compound(hdf5_file_with_compound_data):
@@ -765,10 +723,10 @@ def test_hdf5_feature_inference_compound(hdf5_file_with_compound_data):
     features = hdf5.info.features

     # Check compound type features
-    assert "simple_compound_x" in features
-    assert "simple_compound_y" in features
-    assert features["simple_compound_x"] == Value("int32")
-    assert features["simple_compound_y"] == Value("float64")
+    assert "simple_compound" in features
+    assert isinstance(features["simple_compound"], Features)
+    assert features["simple_compound"]["x"] == Value("int32")
+    assert features["simple_compound"]["y"] == Value("float64")


 def test_hdf5_mixed_data_types(hdf5_file_with_mixed_data_types):
@@ -787,43 +745,15 @@ def test_hdf5_mixed_data_types(hdf5_file_with_mixed_data_types):
     expected_columns = {
         "regular_int",
         "regular_float",
-        "complex_data_real",
-        "complex_data_imag",
-        "compound_data_x",
-        "compound_data_y",
+        "complex_data",
+        "compound_data",
     }
     assert set(table.column_names) == expected_columns

     # Check data types
     assert table["regular_int"].to_pylist() == [0, 1, 2]
-    assert len(table["complex_data_real"].to_pylist()) == 3
-    assert len(table["compound_data_x"].to_pylist()) == 3
-
-
-def test_hdf5_column_name_collision_detection(hdf5_file_with_complex_collision):
-    """Test that column name collision detection works correctly."""
-    config = HDF5Config()
-    hdf5 = HDF5()
-    hdf5.config = config
-    hdf5.config.data_files = DataFilesDict({"train": [hdf5_file_with_complex_collision]})
-
-    # This should raise a ValueError due to column name collision
-    dl_manager = StreamingDownloadManager()
-    with pytest.raises(ValueError, match="Column name collision detected"):
-        hdf5._split_generators(dl_manager)
-
-
-def test_hdf5_compound_collision_detection(hdf5_file_with_compound_collision):
-    """Test collision detection with compound types."""
-    config = HDF5Config()
-    hdf5 = HDF5()
-    hdf5.config = config
-    hdf5.config.data_files = DataFilesDict({"train": [hdf5_file_with_compound_collision]})
-
-    # This should raise a ValueError due to column name collision
-    dl_manager = StreamingDownloadManager()
-    with pytest.raises(ValueError, match="Column name collision detected"):
-        hdf5._split_generators(dl_manager)
+    assert len(table["complex_data"].to_pylist()) == 3
+    assert len(table["compound_data"].to_pylist()) == 3


 def test_hdf5_mismatched_lengths_with_column_filtering(hdf5_file_with_mismatched_lengths):
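
End to end, rows read back should match the dicts asserted in the tests above. A hypothetical usage sketch, assuming the builder is exposed as the packaged "hdf5" loader (the file name example.h5 and dataset name signal are made up):

import h5py
import numpy as np
from datasets import load_dataset

# Write a tiny HDF5 file with one complex dataset.
with h5py.File("example.h5", "w") as f:
    f.create_dataset("signal", data=np.array([1 + 2j, 3 + 4j], dtype=np.complex64))

# Load it through the HDF5 builder; each row should carry a nested struct.
ds = load_dataset("hdf5", data_files="example.h5", split="train")
print(ds[0])  # expected: {'signal': {'real': 1.0, 'imag': 2.0}}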
