
Commit df52454

allow mismatched lengths if ignored
1 parent cdb7f73 commit df52454

File tree

src/datasets/packaged_modules/hdf5/hdf5.py
tests/packaged_modules/test_hdf5.py

2 files changed: +77 -4 lines changed

src/datasets/packaged_modules/hdf5/hdf5.py

Lines changed: 13 additions & 3 deletions

@@ -133,9 +133,21 @@ def _generate_tables(self, files):
             if not dataset_map:
                 logger.warning(f"File '{file}' contains no data, skipping...")
                 continue
+
+            if self.config.columns is not None:
+                filtered_dataset_map = {
+                    path: dset for path, dset in dataset_map.items() if path in self.config.columns
+                }
+                if not filtered_dataset_map:
+                    logger.warning(
+                        f"No datasets match the specified columns {self.config.columns}, skipping..."
+                    )
+                    continue
+                dataset_map = filtered_dataset_map
+
+            # Sanity-check lengths for selected datasets
             first_dset = next(iter(dataset_map.values()))
             num_rows = first_dset.shape[0]
-            # Sanity-check lengths
             for path, dset in dataset_map.items():
                 if dset.shape[0] != num_rows:
                     raise ValueError(
@@ -146,8 +158,6 @@ def _generate_tables(self, files):
                 end = min(start + effective_batch, num_rows)
                 batch_dict = {}
                 for path, dset in dataset_map.items():
-                    if self.config.columns is not None and path not in self.config.columns:
-                        continue
                     arr = dset[start:end]

                     # Handle variable-length arrays
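In practice, this change means a file whose HDF5 datasets have different lengths can still be loaded, as long as the mismatched datasets are left out of the columns selection. Below is a minimal sketch of that end-to-end usage, not part of the commit, assuming the packaged hdf5 builder forwards the columns keyword from load_dataset to HDF5Config the way other packaged modules forward their config kwargs; the file name and dataset names are illustrative only.

    import h5py
    import numpy as np
    from datasets import load_dataset

    # Build a toy HDF5 file with datasets of different lengths.
    with h5py.File("example.h5", "w") as f:
        f.create_dataset("data1", data=np.arange(5, dtype=np.int32))  # 5 rows
        f.create_dataset("data2", data=np.arange(3, dtype=np.int32))  # 3 rows (mismatched)

    # Selecting only "data1" filters out "data2" before the length sanity check,
    # so loading succeeds instead of raising ValueError.
    ds = load_dataset("hdf5", data_files="example.h5", columns=["data1"], split="train")
    print(ds.column_names)  # expected: ["data1"]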

tests/packaged_modules/test_hdf5.py

Lines changed: 64 additions & 1 deletion

@@ -190,7 +190,16 @@ def hdf5_file_with_mismatched_lengths(tmp_path):

     with h5py.File(filename, "w") as f:
         f.create_dataset("data1", data=np.arange(5, dtype=np.int32))
-        f.create_dataset("data2", data=np.arange(3, dtype=np.int32))  # Different length
+        # Dataset with 3 rows (mismatched)
+        f.create_dataset("data2", data=np.arange(3, dtype=np.int32))
+        f.create_dataset("data3", data=np.random.randn(5, 3, 4).astype(np.float32))
+        f.create_dataset("data4", data=np.arange(5, dtype=np.float64) / 10.0)
+        f.create_dataset("data5", data=np.array([True, False, True, False, True]))
+        var_strings = ["short", "medium length", "very long string", "tiny", "another string"]
+        dt = h5py.vlen_dtype(str)
+        dset = f.create_dataset("data6", (5,), dtype=dt)
+        for i, s in enumerate(var_strings):
+            dset[i] = s

     return str(filename)

@@ -815,3 +824,57 @@ def test_hdf5_compound_collision_detection(hdf5_file_with_compound_collision):
     dl_manager = StreamingDownloadManager()
     with pytest.raises(ValueError, match="Column name collision detected"):
         hdf5._split_generators(dl_manager)
+
+
+def test_hdf5_mismatched_lengths_with_column_filtering(hdf5_file_with_mismatched_lengths):
+    """Test that mismatched dataset lengths are ignored when the mismatched dataset is excluded via columns config."""
+    config = HDF5Config(columns=["data1"])
+    hdf5 = HDF5()
+    hdf5.config = config
+
+    generator = hdf5._generate_tables([[hdf5_file_with_mismatched_lengths]])
+    tables = list(generator)
+
+    # Should work without error since we're only including the first dataset
+    assert len(tables) == 1
+    _, table = tables[0]
+
+    # Check that only the specified column is present
+    expected_columns = {"data1"}
+    assert set(table.column_names) == expected_columns
+    assert "data2" not in table.column_names
+
+    # Check the data
+    data1_values = table["data1"].to_pylist()
+    assert data1_values == [0, 1, 2, 3, 4]
+
+    # Test 2: Include multiple compatible datasets (all with 5 rows)
+    config2 = HDF5Config(columns=["data1", "data3", "data4", "data5", "data6"])
+    hdf5.config = config2
+
+    generator2 = hdf5._generate_tables([[hdf5_file_with_mismatched_lengths]])
+    tables2 = list(generator2)
+
+    # Should work without error since we're excluding the mismatched dataset
+    assert len(tables2) == 1
+    _, table2 = tables2[0]
+
+    # Check that all specified columns are present
+    expected_columns2 = {"data1", "data3", "data4", "data5", "data6"}
+    assert set(table2.column_names) == expected_columns2
+    assert "data2" not in table2.column_names
+
+    # Check data types and values
+    assert table2["data1"].to_pylist() == [0, 1, 2, 3, 4]  # int32
+    assert len(table2["data3"].to_pylist()) == 5  # Array2D
+    assert len(table2["data3"].to_pylist()[0]) == 3  # 3 rows in each 2D array
+    assert len(table2["data3"].to_pylist()[0][0]) == 4  # 4 columns in each 2D array
+    np.testing.assert_allclose(table2["data4"].to_pylist(), [0.0, 0.1, 0.2, 0.3, 0.4], rtol=1e-6)  # float64
+    assert table2["data5"].to_pylist() == [True, False, True, False, True]  # boolean
+    assert table2["data6"].to_pylist() == [
+        "short",
+        "medium length",
+        "very long string",
+        "tiny",
+        "another string",
+    ]  # vlen string
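As a companion note (not part of the commit), the behaviour this test relies on can be summarised as: the same fixture still fails the length sanity check when no columns are specified, and passes once the 3-row dataset is excluded. A rough sketch using the same HDF5/HDF5Config builder API the test exercises, assuming HDF5Config can be constructed with defaults and imported from datasets.packaged_modules.hdf5.hdf5; the column selection below is illustrative.

    import pytest
    from datasets.packaged_modules.hdf5.hdf5 import HDF5, HDF5Config

    def check_fixture_behaviour(h5_path):
        hdf5 = HDF5()

        # Without column filtering, the mismatched "data2" still triggers the
        # length sanity check and generation raises ValueError.
        hdf5.config = HDF5Config()
        with pytest.raises(ValueError):
            list(hdf5._generate_tables([[h5_path]]))

        # Excluding "data2" via `columns` lets the same file load cleanly.
        hdf5.config = HDF5Config(columns=["data1", "data3"])
        tables = list(hdf5._generate_tables([[h5_path]]))
        assert len(tables) == 1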
