
Commit 94146be

refactor type inference
1 parent fdadd1a commit 94146be

2 files changed: +301 -65 lines

src/datasets/packaged_modules/hdf5/hdf5.py

Lines changed: 135 additions & 64 deletions
@@ -2,12 +2,22 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional
 
-import h5py
 import numpy as np
 import pyarrow as pa
 
 import datasets
-from datasets.features.features import LargeList, Sequence, _ArrayXD
+import h5py
+from datasets.features.features import (
+    Array2D,
+    Array3D,
+    Array4D,
+    Array5D,
+    LargeList,
+    Sequence,
+    Value,
+    _ArrayXD,
+    _arrow_to_datasets_dtype,
+)
 from datasets.table import table_cast
 
 
@@ -76,7 +86,7 @@ def _split_generators(self, dl_manager):
 
     def _cast_table(self, pa_table: pa.Table) -> pa.Table:
         if self.info.features is not None:
-            has_zero_dims = any(has_zero_dimensions(feature) for feature in self.info.features.values())
+            has_zero_dims = any(_has_zero_dimensions(feature) for feature in self.info.features.values())
             if not has_zero_dims:
                 pa_table = table_cast(pa_table, self.info.features.arrow_schema)
         return pa_table
@@ -105,7 +115,13 @@ def _generate_tables(self, files):
                 if self.config.columns is not None and path not in self.config.columns:
                     continue
                 arr = dset[start:end]
-                pa_arr = datasets.features.features.numpy_to_pyarrow_listarray(arr)
+                if _is_ragged_dataset(dset):
+                    if _is_variable_length_string(dset):
+                        pa_arr = _variable_length_string_to_pyarrow(arr, dset)
+                    else:
+                        pa_arr = _ragged_array_to_pyarrow_largelist(arr, dset)
+                else:
+                    pa_arr = datasets.features.features.numpy_to_pyarrow_listarray(arr)  # NOTE: type=None
                 batch_dict[path] = pa_arr
             pa_table = pa.Table.from_pydict(batch_dict)
             yield f"{file_idx}_{start}", self._cast_table(pa_table)
@@ -123,82 +139,137 @@ def _traverse_datasets(h5_obj, prefix: str = "") -> Dict[str, h5py.Dataset]:
     return mapping
 
 
-_DTYPE_TO_DATASETS: Dict[np.dtype, str] = {  # FIXME: necessary/check if util exists?
-    np.dtype("bool").newbyteorder("="): "bool",
-    np.dtype("int8").newbyteorder("="): "int8",
-    np.dtype("int16").newbyteorder("="): "int16",
-    np.dtype("int32").newbyteorder("="): "int32",
-    np.dtype("int64").newbyteorder("="): "int64",
-    np.dtype("uint8").newbyteorder("="): "uint8",
-    np.dtype("uint16").newbyteorder("="): "uint16",
-    np.dtype("uint32").newbyteorder("="): "uint32",
-    np.dtype("uint64").newbyteorder("="): "uint64",
-    np.dtype("float16").newbyteorder("="): "float16",
-    np.dtype("float32").newbyteorder("="): "float32",
-    np.dtype("float64").newbyteorder("="): "float64",
-    # np.dtype("complex64").newbyteorder("="): "complex64",
-    # np.dtype("complex128").newbyteorder("="): "complex128",
-}
-
-
-def _dtype_to_dataset_dtype(dtype: np.dtype) -> str:
-    """Map NumPy dtype to datasets.Value dtype string, falls back to "binary" for unknown or unsupported dtypes."""
-
-    # FIXME: endian fix necessary/correct?
-    base_dtype = dtype.newbyteorder("=")
-    if base_dtype in _DTYPE_TO_DATASETS:
-        return _DTYPE_TO_DATASETS[base_dtype]
-
-    if base_dtype.kind in {"S", "a"}:
-        return "binary"
-
-    # FIXME: seems h5 converts unicode back to bytes?
-    if base_dtype.kind == "U":
-        return "binary"
-
-    if base_dtype.kind == "O":
-        return "binary"
-
-    # FIXME: support varlen?
-
-    return "binary"
+def _base_dtype(dtype):
+    if hasattr(dtype, "metadata") and dtype.metadata and "vlen" in dtype.metadata:
+        return dtype.metadata["vlen"]
+    if hasattr(dtype, "subdtype") and dtype.subdtype is not None:
+        return _base_dtype(dtype.subdtype[0])
+    return dtype
+
+
+def _ragged_array_to_pyarrow_largelist(arr: np.ndarray, dset: h5py.Dataset) -> pa.Array:
+    if _is_variable_length_string(dset):
+        list_of_strings = []
+        for item in arr:
+            if item is None:
+                list_of_strings.append(None)
+            else:
+                if isinstance(item, bytes):
+                    item = item.decode("utf-8")
+                list_of_strings.append(item)
+        return datasets.features.features.list_of_pa_arrays_to_pyarrow_listarray(
+            [pa.array([item]) if item is not None else None for item in list_of_strings]
+        )
+    else:
+        return _convert_nested_ragged_array_recursive(arr, dset.dtype)
+
+
+def _convert_nested_ragged_array_recursive(arr: np.ndarray, dtype):
+    if hasattr(dtype, "subdtype") and dtype.subdtype is not None:
+        inner_dtype = dtype.subdtype[0]
+        list_of_arrays = []
+        for item in arr:
+            if item is None:
+                list_of_arrays.append(None)
+            else:
+                inner_array = _convert_nested_ragged_array_recursive(item, inner_dtype)
+                list_of_arrays.append(inner_array)
+        return datasets.features.features.list_of_pa_arrays_to_pyarrow_listarray(
+            [pa.array(item) if item is not None else None for item in list_of_arrays]
+        )
+    else:
+        list_of_arrays = []
+        for item in arr:
+            if item is None:
+                list_of_arrays.append(None)
+            else:
+                if not isinstance(item, np.ndarray):
+                    item = np.array(item, dtype=dtype)
+                list_of_arrays.append(item)
+        return datasets.features.features.list_of_pa_arrays_to_pyarrow_listarray(
+            [pa.array(item) if item is not None else None for item in list_of_arrays]
+        )
 
 
 def _infer_feature_from_dataset(dset: h5py.Dataset):
-    """Infer a ``datasets.Features`` entry for one HDF5 dataset."""
+    if _is_variable_length_string(dset):
+        return Value("string")  # FIXME: large_string?
 
-    import datasets as hfd
+    if _is_ragged_dataset(dset):
+        return _infer_nested_feature_recursive(dset.dtype, dset)
 
-    dtype_str = _dtype_to_dataset_dtype(dset.dtype)
+    value_feature = _np_to_pa_to_hf_value(dset.dtype)
+    dtype_str = value_feature.dtype
     value_shape = dset.shape[1:]
 
-    # Reject ragged datasets (variable-length or None dims)
-    if dset.dtype.kind == "O" or any(s is None for s in value_shape):
-        raise ValueError(f"Ragged dataset {dset.name} with shape {value_shape} and dtype {dset.dtype} not supported")
-
     if dset.dtype.kind not in {"b", "i", "u", "f", "S", "a"}:
-        raise ValueError(f"Unsupported dtype {dset.dtype} for dataset {dset.name}")
+        raise TypeError(f"Unsupported dtype {dset.dtype} for dataset {dset.name}")
 
     rank = len(value_shape)
-    if 2 <= rank <= 5:
-        from datasets.features import Array2D, Array3D, Array4D, Array5D
-
-        array_cls = [None, None, Array2D, Array3D, Array4D, Array5D][rank]
-        return array_cls(shape=value_shape, dtype=dtype_str)
+    if rank == 0:
+        return value_feature
+    elif rank == 1:
+        return Sequence(value_feature, length=value_shape[0])
+    elif 2 <= rank <= 5:
+        return _sized_arrayxd(rank)(shape=value_shape, dtype=dtype_str)
+    else:
+        raise TypeError(f"Array{rank}D not supported. Only up to 5D arrays are supported.")
 
-    # Fallback to nested Sequence
-    def _build_feature(shape: tuple[int, ...]):
-        if len(shape) == 0:
-            return hfd.Value(dtype_str)
-        return hfd.Sequence(length=shape[0], feature=_build_feature(shape[1:]))
 
-    return _build_feature(value_shape)
+def _infer_nested_feature_recursive(dtype, dset: h5py.Dataset):
+    if hasattr(dtype, "subdtype") and dtype.subdtype is not None:
+        inner_dtype = dtype.subdtype[0]
+        inner_feature = _infer_nested_feature_recursive(inner_dtype, dset)
+        return Sequence(inner_feature)
+    else:
+        if hasattr(dtype, "kind") and dtype.kind == "O":
+            if _is_variable_length_string(dset):
+                base_dtype = np.dtype("S1")
+            else:
+                base_dtype = _base_dtype(dset.dtype)
+            return Sequence(_np_to_pa_to_hf_value(base_dtype))
+        else:
+            return _np_to_pa_to_hf_value(dtype)
 
 
-def has_zero_dimensions(feature: _ArrayXD | Sequence | LargeList):
+def _has_zero_dimensions(feature):
     if isinstance(feature, _ArrayXD):
         return any(dim == 0 for dim in feature.shape)
     elif isinstance(feature, (Sequence, LargeList)):
-        return feature.length == 0 or has_zero_dimensions(feature.feature)
+        return feature.length == 0 or _has_zero_dimensions(feature.feature)
     else:
         return False
+
+
+def _sized_arrayxd(rank: int):
+    return {2: Array2D, 3: Array3D, 4: Array4D, 5: Array5D}[rank]
+
+
+def _np_to_pa_to_hf_value(numpy_dtype: np.dtype) -> Value:
+    return Value(dtype=_arrow_to_datasets_dtype(pa.from_numpy_dtype(numpy_dtype)))
+
+
+def _is_ragged_dataset(dset: h5py.Dataset) -> bool:
+    return dset.dtype.kind == "O" and hasattr(dset.dtype, "subdtype")
+
+
+def _is_variable_length_string(dset: h5py.Dataset) -> bool:
+    if not _is_ragged_dataset(dset) or dset.shape[0] == 0:
+        return False
+    num_samples = min(3, dset.shape[0])
+    for i in range(num_samples):
+        try:
+            if isinstance(dset[i], (str, bytes)):
+                return True
+        except (IndexError, TypeError):
+            continue
    return False
+
+
+def _variable_length_string_to_pyarrow(arr: np.ndarray, dset: h5py.Dataset) -> pa.Array:
+    list_of_strings = []
+    for item in arr:
+        if isinstance(item, bytes):
+            item = item.decode("utf-8")
+        list_of_strings.append(item)
+    return pa.array(list_of_strings)
