2
2
from dataclasses import dataclass
3
3
from typing import Dict , List , Optional
4
4
5
- import h5py
6
5
import numpy as np
7
6
import pyarrow as pa
8
7
9
8
import datasets
10
- from datasets .features .features import LargeList , Sequence , _ArrayXD
9
+ import h5py
10
+ from datasets .features .features import (
11
+ Array2D ,
12
+ Array3D ,
13
+ Array4D ,
14
+ Array5D ,
15
+ LargeList ,
16
+ Sequence ,
17
+ Value ,
18
+ _ArrayXD ,
19
+ _arrow_to_datasets_dtype ,
20
+ )
11
21
from datasets .table import table_cast
12
22
13
23
@@ -76,7 +86,7 @@ def _split_generators(self, dl_manager):
76
86
77
87
def _cast_table(self, pa_table: pa.Table) -> pa.Table:
    """Cast *pa_table* to the configured feature schema when it is safe to do so.

    Casting is skipped when no features are configured, or when any feature
    contains a zero-length dimension (which ``table_cast`` cannot handle).
    """
    features = self.info.features
    if features is None:
        return pa_table
    if any(_has_zero_dimensions(feature) for feature in features.values()):
        return pa_table
    return table_cast(pa_table, features.arrow_schema)
@@ -105,7 +115,13 @@ def _generate_tables(self, files):
105
115
if self .config .columns is not None and path not in self .config .columns :
106
116
continue
107
117
arr = dset [start :end ]
108
- pa_arr = datasets .features .features .numpy_to_pyarrow_listarray (arr )
118
+ if _is_ragged_dataset (dset ):
119
+ if _is_variable_length_string (dset ):
120
+ pa_arr = _variable_length_string_to_pyarrow (arr , dset )
121
+ else :
122
+ pa_arr = _ragged_array_to_pyarrow_largelist (arr , dset )
123
+ else :
124
+ pa_arr = datasets .features .features .numpy_to_pyarrow_listarray (arr ) # NOTE: type=None
109
125
batch_dict [path ] = pa_arr
110
126
pa_table = pa .Table .from_pydict (batch_dict )
111
127
yield f"{ file_idx } _{ start } " , self ._cast_table (pa_table )
@@ -123,82 +139,137 @@ def _traverse_datasets(h5_obj, prefix: str = "") -> Dict[str, h5py.Dataset]:
123
139
return mapping
124
140
125
141
126
- _DTYPE_TO_DATASETS : Dict [np .dtype , str ] = { # FIXME: necessary/check if util exists?
127
- np .dtype ("bool" ).newbyteorder ("=" ): "bool" ,
128
- np .dtype ("int8" ).newbyteorder ("=" ): "int8" ,
129
- np .dtype ("int16" ).newbyteorder ("=" ): "int16" ,
130
- np .dtype ("int32" ).newbyteorder ("=" ): "int32" ,
131
- np .dtype ("int64" ).newbyteorder ("=" ): "int64" ,
132
- np .dtype ("uint8" ).newbyteorder ("=" ): "uint8" ,
133
- np .dtype ("uint16" ).newbyteorder ("=" ): "uint16" ,
134
- np .dtype ("uint32" ).newbyteorder ("=" ): "uint32" ,
135
- np .dtype ("uint64" ).newbyteorder ("=" ): "uint64" ,
136
- np .dtype ("float16" ).newbyteorder ("=" ): "float16" ,
137
- np .dtype ("float32" ).newbyteorder ("=" ): "float32" ,
138
- np .dtype ("float64" ).newbyteorder ("=" ): "float64" ,
139
- # np.dtype("complex64").newbyteorder("="): "complex64",
140
- # np.dtype("complex128").newbyteorder("="): "complex128",
141
- }
142
-
143
-
144
- def _dtype_to_dataset_dtype (dtype : np .dtype ) -> str :
145
- """Map NumPy dtype to datasets.Value dtype string, falls back to "binary" for unknown or unsupported dtypes."""
146
-
147
- # FIXME: endian fix necessary/correct?
148
- base_dtype = dtype .newbyteorder ("=" )
149
- if base_dtype in _DTYPE_TO_DATASETS :
150
- return _DTYPE_TO_DATASETS [base_dtype ]
151
-
152
- if base_dtype .kind in {"S" , "a" }:
153
- return "binary"
154
-
155
- # FIXME: seems h5 converts unicode back to bytes?
156
- if base_dtype .kind == "U" :
157
- return "binary"
158
-
159
- if base_dtype .kind == "O" :
160
- return "binary"
161
-
162
- # FIXME: support varlen?
163
-
164
- return "binary"
142
+ def _base_dtype (dtype ):
143
+ if hasattr (dtype , "metadata" ) and dtype .metadata and "vlen" in dtype .metadata :
144
+ return dtype .metadata ["vlen" ]
145
+ if hasattr (dtype , "subdtype" ) and dtype .subdtype is not None :
146
+ return _base_dtype (dtype .subdtype [0 ])
147
+ return dtype
148
+
149
+
150
def _ragged_array_to_pyarrow_largelist(arr: np.ndarray, dset: h5py.Dataset) -> pa.Array:
    """Convert a batch read from a ragged HDF5 dataset into a pyarrow list array.

    Variable-length string rows are decoded (bytes -> utf-8) and each
    non-null string is wrapped in a single-element pyarrow array; other
    ragged data is handled by the recursive nested converter.
    """
    if not _is_variable_length_string(dset):
        return _convert_nested_ragged_array_recursive(arr, dset.dtype)

    decoded = [
        row.decode("utf-8") if isinstance(row, bytes) else row
        for row in arr
    ]
    chunks = [None if s is None else pa.array([s]) for s in decoded]
    return datasets.features.features.list_of_pa_arrays_to_pyarrow_listarray(chunks)
165
+
166
+
167
def _convert_nested_ragged_array_recursive(arr: np.ndarray, dtype):
    """Recursively convert a ragged nested array into a pyarrow list array.

    Sub-array dtypes recurse one nesting level at a time; at the leaf
    level each row is coerced to a numpy array of *dtype* before being
    turned into a pyarrow array. ``None`` rows become nulls throughout.
    """
    sub = getattr(dtype, "subdtype", None)
    if sub is not None:
        inner_dtype = sub[0]
        converted = [
            None if row is None else _convert_nested_ragged_array_recursive(row, inner_dtype)
            for row in arr
        ]
        # NOTE(review): the non-null entries of `converted` are already
        # pyarrow arrays, yet they are re-wrapped with pa.array() below —
        # confirm pa.array accepts a pa.Array here for the versions in use.
        chunks = [None if item is None else pa.array(item) for item in converted]
        return datasets.features.features.list_of_pa_arrays_to_pyarrow_listarray(chunks)

    rows = []
    for row in arr:
        if row is None:
            rows.append(None)
        elif isinstance(row, np.ndarray):
            rows.append(row)
        else:
            rows.append(np.array(row, dtype=dtype))
    chunks = [None if item is None else pa.array(item) for item in rows]
    return datasets.features.features.list_of_pa_arrays_to_pyarrow_listarray(chunks)
165
192
166
193
167
194
def _infer_feature_from_dataset(dset: h5py.Dataset):
    """Infer a ``datasets`` feature for one HDF5 dataset.

    Dispatch order: variable-length strings -> ``Value("string")``; other
    ragged datasets -> nested ``Sequence`` features; fixed-shape datasets ->
    ``Value`` / ``Sequence`` / ``ArrayXD`` depending on the per-row rank
    (``dset.shape[1:]``, i.e. the shape with the batch axis removed).

    Raises:
        TypeError: for unsupported dtype kinds, or per-row rank > 5.
    """
    if _is_variable_length_string(dset):
        return Value("string")  # FIXME: large_string?

    if _is_ragged_dataset(dset):
        return _infer_nested_feature_recursive(dset.dtype, dset)

    # Validate the dtype kind *before* converting through pyarrow: otherwise
    # pa.from_numpy_dtype raises an opaque pyarrow error (e.g. for complex
    # dtypes) instead of the intended TypeError naming the dataset.
    if dset.dtype.kind not in {"b", "i", "u", "f", "S", "a"}:
        raise TypeError(f"Unsupported dtype {dset.dtype} for dataset {dset.name}")

    value_feature = _np_to_pa_to_hf_value(dset.dtype)
    dtype_str = value_feature.dtype
    value_shape = dset.shape[1:]

    rank = len(value_shape)
    if rank == 0:
        return value_feature
    elif rank == 1:
        return Sequence(value_feature, length=value_shape[0])
    elif 2 <= rank <= 5:
        return _sized_arrayxd(rank)(shape=value_shape, dtype=dtype_str)
    else:
        raise TypeError(f"Array{rank}D not supported. Only up to 5D arrays are supported.")
219
def _infer_nested_feature_recursive(dtype, dset: h5py.Dataset):
    """Build nested ``Sequence`` features mirroring a ragged dtype's nesting.

    Sub-array dtypes add one ``Sequence`` level per nesting step; an object
    ("O") leaf is resolved through the dataset's base dtype (``S1`` for
    variable-length strings); any other leaf maps straight to a ``Value``.
    """
    sub = getattr(dtype, "subdtype", None)
    if sub is not None:
        return Sequence(_infer_nested_feature_recursive(sub[0], dset))

    if getattr(dtype, "kind", None) == "O":
        if _is_variable_length_string(dset):
            leaf_dtype = np.dtype("S1")
        else:
            leaf_dtype = _base_dtype(dset.dtype)
        return Sequence(_np_to_pa_to_hf_value(leaf_dtype))

    return _np_to_pa_to_hf_value(dtype)
196
233
197
234
198
def _has_zero_dimensions(feature):
    """Return True when *feature* (or any nested feature) has a zero-length dim."""
    if isinstance(feature, _ArrayXD):
        return 0 in feature.shape
    if isinstance(feature, (Sequence, LargeList)):
        return feature.length == 0 or _has_zero_dimensions(feature.feature)
    return False
242
+
243
+
244
def _sized_arrayxd(rank: int):
    """Return the fixed-shape ArrayXD feature class for 2 <= rank <= 5."""
    arrayxd_by_rank = {2: Array2D, 3: Array3D, 4: Array4D, 5: Array5D}
    return arrayxd_by_rank[rank]
246
+
247
+
248
def _np_to_pa_to_hf_value(numpy_dtype: np.dtype) -> Value:
    """Map a numpy dtype to a ``datasets.Value`` via pyarrow's type system."""
    arrow_type = pa.from_numpy_dtype(numpy_dtype)
    return Value(dtype=_arrow_to_datasets_dtype(arrow_type))
250
+
251
+
252
def _is_ragged_dataset(dset: h5py.Dataset) -> bool:
    """Return True when *dset* stores variable-length (ragged) data.

    h5py exposes variable-length data through an object ("O") numpy dtype.
    The previous extra ``hasattr(dset.dtype, "subdtype")`` guard was
    vacuous — every numpy dtype object has a ``subdtype`` attribute (it is
    simply ``None`` for scalar dtypes) — so the check reduces to the dtype
    kind alone.
    """
    return dset.dtype.kind == "O"
254
+
255
+
256
def _is_variable_length_string(dset: h5py.Dataset) -> bool:
    """Heuristically detect a variable-length string dataset.

    Samples up to the first three rows and reports True if any of them is a
    ``str`` or ``bytes`` instance.
    """
    if not _is_ragged_dataset(dset) or dset.shape[0] == 0:
        return False
    # NOTE(review): sampling rows is a heuristic; h5py.check_string_dtype
    # on dset.dtype may be a more robust metadata-based check — confirm.
    for idx in range(min(3, dset.shape[0])):
        try:
            sample = dset[idx]
        except (IndexError, TypeError):
            continue
        if isinstance(sample, (str, bytes)):
            return True
    return False
267
+
268
+
269
def _variable_length_string_to_pyarrow(arr: np.ndarray, dset: h5py.Dataset) -> pa.Array:
    """Decode a batch of variable-length strings into a pyarrow string array.

    Bytes entries are decoded as utf-8; other entries pass through as-is.
    (*dset* is accepted for signature symmetry with the other converters
    but is not consulted.)
    """
    decoded = [
        item.decode("utf-8") if isinstance(item, bytes) else item
        for item in arr
    ]
    return pa.array(decoded)
0 commit comments