@@ -1,12 +1,11 @@
 import itertools
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

 import numpy as np
 import pyarrow as pa

 import datasets
-import h5py
 from datasets.features.features import (
     Array2D,
     Array3D,
@@ -21,6 +20,9 @@
 from datasets.table import table_cast


+if TYPE_CHECKING:
+    import h5py
+
 logger = datasets.utils.logging.get_logger(__name__)

 EXTENSIONS = [".h5", ".hdf5"]
@@ -56,6 +58,8 @@ def _info(self):
         return datasets.DatasetInfo(features=self.config.features)

     def _split_generators(self, dl_manager):
+        import h5py
+
         if not self.config.data_files:
             raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
         dl_manager.download_config.extract_on_the_fly = True
@@ -119,6 +123,8 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
         return pa_table

     def _generate_tables(self, files):
+        import h5py
+
         batch_size_cfg = self.config.batch_size
         for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
             try:
@@ -179,7 +185,9 @@ def _generate_tables(self, files):
                 raise


-def _traverse_datasets(h5_obj, prefix: str = "") -> Dict[str, h5py.Dataset]:
+def _traverse_datasets(h5_obj, prefix: str = "") -> Dict[str, "h5py.Dataset"]:
+    import h5py
+
     mapping: Dict[str, h5py.Dataset] = {}

     def collect_datasets(name, obj):
@@ -201,7 +209,7 @@ def _is_complex_dtype(dtype: np.dtype) -> bool:
     return dtype.kind == "c"


-def _create_complex_features(base_path: str, dset: h5py.Dataset) -> Dict[str, Value]:
+def _create_complex_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Value]:
     """Create separate features for real and imaginary parts of complex data.

     NOTE: Always uses float64 for the real and imaginary parts.
@@ -212,7 +220,7 @@ def _create_complex_features(base_path: str, dset: h5py.Dataset) -> Dict[str, Va
     return {f"{base_path}_real": Value("float64"), f"{base_path}_imag": Value("float64")}


-def _convert_complex_to_separate_columns(base_path: str, arr: np.ndarray, dset: h5py.Dataset) -> Dict[str, pa.Array]:
+def _convert_complex_to_separate_columns(base_path: str, arr: np.ndarray, dset: "h5py.Dataset") -> Dict[str, pa.Array]:
     """Convert complex array to separate real and imaginary columns."""
     result = {}
     result[f"{base_path}_real"] = datasets.features.features.numpy_to_pyarrow_listarray(arr.real)
@@ -236,7 +244,7 @@ def __init__(self, dtype):
         self.names = dtype.names


-def _create_compound_features(base_path: str, dset: h5py.Dataset) -> Dict[str, Any]:
+def _create_compound_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Any]:
     """Create separate features for each field in compound data."""
     field_names = list(dset.dtype.names)
     logger.info(
@@ -262,7 +270,9 @@ def _create_compound_features(base_path: str, dset: h5py.Dataset) -> Dict[str, A
     return features


-def _convert_compound_to_separate_columns(base_path: str, arr: np.ndarray, dset: h5py.Dataset) -> Dict[str, pa.Array]:
+def _convert_compound_to_separate_columns(
+    base_path: str, arr: np.ndarray, dset: "h5py.Dataset"
+) -> Dict[str, pa.Array]:
     """Convert compound array to separate columns for each field."""
     result = {}
     for field_name in list(dset.dtype.names):
@@ -314,7 +324,7 @@ def _convert_vlen_string_to_array(arr: np.ndarray) -> pa.Array:
     # └───────────┘


-def _infer_feature_from_dataset(dset: h5py.Dataset):
+def _infer_feature_from_dataset(dset: "h5py.Dataset"):
     # non-string varlen
     if hasattr(dset.dtype, "metadata") and dset.dtype.metadata and "vlen" in dset.dtype.metadata:
         vlen_dtype = dset.dtype.metadata["vlen"]
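For reference, a minimal standalone sketch of the lazy-import pattern this diff applies: a TYPE_CHECKING guard plus string annotations keep h5py out of module import time, while each function that actually touches HDF5 imports it on first call. The read_shapes and _describe helpers below are illustrative only, not part of the loader.

from typing import TYPE_CHECKING, Dict, Tuple

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime,
    # so importing this module does not require h5py to be installed.
    import h5py


def _describe(dset: "h5py.Dataset") -> str:
    # String annotation: not evaluated at runtime, safe without h5py.
    return f"{dset.name}: shape={dset.shape}, dtype={dset.dtype}"


def read_shapes(path: str) -> Dict[str, Tuple[int, ...]]:
    # Deferred import: h5py is resolved only when this function runs,
    # keeping it an optional dependency of the module.
    import h5py

    with h5py.File(path, "r") as h5:
        return {name: obj.shape for name, obj in h5.items() if isinstance(obj, h5py.Dataset)}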