Skip to content

Commit b12fa5e

Browse files
committed
update tests
1 parent b406154 commit b12fa5e

File tree

1 file changed

+64
-81
lines changed

1 file changed

+64
-81
lines changed

tests/packaged_modules/test_hdf5.py

Lines changed: 64 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,64 @@ def empty_hdf5_file(tmp_path):
225225
return str(filename)
226226

227227

228+
@pytest.fixture
229+
def hdf5_file_with_mixed_data_types(tmp_path):
230+
"""Create an HDF5 file with mixed data types in the same file."""
231+
filename = tmp_path / "mixed.h5"
232+
n_rows = 3
233+
234+
with h5py.File(filename, "w") as f:
235+
# Regular numeric data
236+
f.create_dataset("regular_int", data=np.arange(n_rows, dtype=np.int32))
237+
f.create_dataset("regular_float", data=np.arange(n_rows, dtype=np.float32))
238+
239+
# Complex data
240+
complex_data = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64)
241+
f.create_dataset("complex_data", data=complex_data)
242+
243+
# Compound data
244+
dt_compound = np.dtype([("x", "i4"), ("y", "f8")])
245+
compound_data = np.array([(1, 2.5), (3, 4.5), (5, 6.5)], dtype=dt_compound)
246+
f.create_dataset("compound_data", data=compound_data)
247+
248+
return str(filename)
249+
250+
251+
@pytest.fixture
252+
def hdf5_file_with_complex_collision(tmp_path):
253+
"""Create an HDF5 file where complex dataset would collide with existing dataset name."""
254+
filename = tmp_path / "collision.h5"
255+
256+
with h5py.File(filename, "w") as f:
257+
# Create a complex dataset
258+
complex_data = np.array([1 + 2j, 3 + 4j], dtype=np.complex64)
259+
f.create_dataset("data", data=complex_data)
260+
261+
# Create a regular dataset that would collide with the complex real part
262+
regular_data = np.array([1.0, 2.0], dtype=np.float32)
263+
f.create_dataset("data_real", data=regular_data) # This should cause a collision
264+
265+
return str(filename)
266+
267+
268+
@pytest.fixture
269+
def hdf5_file_with_compound_collision(tmp_path):
270+
"""Create an HDF5 file where compound dataset would collide with existing dataset name."""
271+
filename = tmp_path / "compound_collision.h5"
272+
273+
with h5py.File(filename, "w") as f:
274+
# Create a compound dataset
275+
dt_compound = np.dtype([("x", "i4"), ("y", "f8")])
276+
compound_data = np.array([(1, 2.5), (3, 4.5)], dtype=dt_compound)
277+
f.create_dataset("position", data=compound_data)
278+
279+
# Create a regular dataset that would collide with compound field
280+
regular_data = np.array([10, 20], dtype=np.int32)
281+
f.create_dataset("position_x", data=regular_data) # This should cause a collision
282+
283+
return str(filename)
284+
285+
228286
def test_config_raises_when_invalid_name():
229287
"""Test that invalid config names raise an error."""
230288
with pytest.raises(InvalidConfigName, match="Bad characters"):
@@ -596,16 +654,6 @@ def test_hdf5_no_data_files_error():
596654
hdf5._split_generators(None)
597655

598656

599-
def test_hdf5_config_options():
600-
"""Test HDF5Config with different options."""
601-
# Test default options
602-
config = HDF5Config()
603-
# Complex and compound types are always split now, no config options needed
604-
assert config.batch_size is None
605-
assert config.columns is None
606-
assert config.features is None
607-
608-
609657
def test_hdf5_complex_numbers(hdf5_file_with_complex_data):
610658
"""Test HDF5 loading with complex number datasets."""
611659
config = HDF5Config()
@@ -670,32 +718,6 @@ def test_hdf5_compound_types(hdf5_file_with_compound_data):
670718
assert y_data == [2.5, 4.5, 6.5]
671719

672720

673-
def test_hdf5_unsupported_dtype_handling(tmp_path):
674-
"""Test handling of truly unsupported dtypes."""
675-
filename = tmp_path / "unsupported.h5"
676-
677-
with h5py.File(filename, "w") as f:
678-
# Create a dataset with an unsupported dtype (e.g., bitfield)
679-
# This should raise a TypeError during feature inference
680-
bitfield_data = np.array([1, 2, 3], dtype=np.uint8)
681-
# We'll create a dataset that will fail during feature inference
682-
# by using a custom dtype that's not supported
683-
f.create_dataset("bitfield_data", data=bitfield_data)
684-
685-
config = HDF5Config()
686-
hdf5 = HDF5()
687-
hdf5.config = config
688-
hdf5.config.data_files = DataFilesDict({"train": [str(filename)]})
689-
690-
# This should not raise an error since uint8 is supported
691-
# Let's test with a different approach - create a dataset that will fail
692-
# during the actual data loading phase
693-
dl_manager = StreamingDownloadManager()
694-
hdf5._split_generators(dl_manager)
695-
696-
# The test passes if no error is raised, since uint8 is actually supported
697-
698-
699721
def test_hdf5_feature_inference_complex(hdf5_file_with_complex_data):
700722
"""Test automatic feature inference for complex datasets."""
701723
config = HDF5Config()
@@ -740,29 +762,13 @@ def test_hdf5_feature_inference_compound(hdf5_file_with_compound_data):
740762
assert features["simple_compound_y"] == Value("float64")
741763

742764

743-
def test_hdf5_mixed_data_types(tmp_path):
765+
def test_hdf5_mixed_data_types(hdf5_file_with_mixed_data_types):
744766
"""Test HDF5 loading with mixed data types in the same file."""
745-
filename = tmp_path / "mixed.h5"
746-
747-
with h5py.File(filename, "w") as f:
748-
# Regular numeric data
749-
f.create_dataset("regular_int", data=np.arange(3, dtype=np.int32))
750-
f.create_dataset("regular_float", data=np.arange(3, dtype=np.float32))
751-
752-
# Complex data
753-
complex_data = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64)
754-
f.create_dataset("complex_data", data=complex_data)
755-
756-
# Compound data
757-
dt_compound = np.dtype([("x", "i4"), ("y", "f8")])
758-
compound_data = np.array([(1, 2.5), (3, 4.5), (5, 6.5)], dtype=dt_compound)
759-
f.create_dataset("compound_data", data=compound_data)
760-
761767
config = HDF5Config()
762768
hdf5 = HDF5()
763769
hdf5.config = config
764770

765-
generator = hdf5._generate_tables([[str(filename)]])
771+
generator = hdf5._generate_tables([[hdf5_file_with_mixed_data_types]])
766772
tables = list(generator)
767773

768774
assert len(tables) == 1
@@ -785,48 +791,25 @@ def test_hdf5_mixed_data_types(tmp_path):
785791
assert len(table["compound_data_x"].to_pylist()) == 3
786792

787793

788-
def test_hdf5_column_name_collision_detection(tmp_path):
794+
def test_hdf5_column_name_collision_detection(hdf5_file_with_complex_collision):
789795
"""Test that column name collision detection works correctly."""
790-
filename = tmp_path / "collision.h5"
791-
792-
with h5py.File(filename, "w") as f:
793-
# Create a complex dataset
794-
complex_data = np.array([1 + 2j, 3 + 4j], dtype=np.complex64)
795-
f.create_dataset("data", data=complex_data)
796-
797-
# Create a regular dataset that would collide with the complex real part
798-
regular_data = np.array([1.0, 2.0], dtype=np.float32)
799-
f.create_dataset("data_real", data=regular_data) # This should cause a collision
800-
801796
config = HDF5Config()
802797
hdf5 = HDF5()
803798
hdf5.config = config
804-
hdf5.config.data_files = DataFilesDict({"train": [str(filename)]})
799+
hdf5.config.data_files = DataFilesDict({"train": [hdf5_file_with_complex_collision]})
805800

806801
# This should raise a ValueError due to column name collision
807802
dl_manager = StreamingDownloadManager()
808803
with pytest.raises(ValueError, match="Column name collision detected"):
809804
hdf5._split_generators(dl_manager)
810805

811806

812-
def test_hdf5_compound_collision_detection(tmp_path):
807+
def test_hdf5_compound_collision_detection(hdf5_file_with_compound_collision):
813808
"""Test collision detection with compound types."""
814-
filename = tmp_path / "compound_collision.h5"
815-
816-
with h5py.File(filename, "w") as f:
817-
# Create a compound dataset
818-
dt_compound = np.dtype([("x", "i4"), ("y", "f8")])
819-
compound_data = np.array([(1, 2.5), (3, 4.5)], dtype=dt_compound)
820-
f.create_dataset("position", data=compound_data)
821-
822-
# Create a regular dataset that would collide with compound field
823-
regular_data = np.array([10, 20], dtype=np.int32)
824-
f.create_dataset("position_x", data=regular_data) # This should cause a collision
825-
826809
config = HDF5Config()
827810
hdf5 = HDF5()
828811
hdf5.config = config
829-
hdf5.config.data_files = DataFilesDict({"train": [str(filename)]})
812+
hdf5.config.data_files = DataFilesDict({"train": [hdf5_file_with_compound_collision]})
830813

831814
# This should raise a ValueError due to column name collision
832815
dl_manager = StreamingDownloadManager()

0 commit comments

Comments
 (0)