Skip to content

Commit 5573448

Browse files
committed
Implement workaround for pytables
1 parent 53adaf2 commit 5573448

File tree

3 files changed

+64
-20
lines changed

3 files changed

+64
-20
lines changed

cebra/data/load.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,11 +275,11 @@ def _is_dlc_df(h5_file: IO[bytes], df_keys: List[str]) -> bool:
275275
"""
276276
try:
277277
if ["_i_table", "table"] in df_keys:
278-
df = pd.read_hdf(h5_file, key="table")
278+
df = read_hdf(h5_file, key="table")
279279
else:
280-
df = pd.read_hdf(h5_file, key=df_keys[0])
280+
df = read_hdf(h5_file, key=df_keys[0])
281281
except KeyError:
282-
df = pd.read_hdf(h5_file)
282+
df = read_hdf(h5_file)
283283
return all(value in df.columns.names
284284
for value in ["scorer", "bodyparts", "coords"])
285285

@@ -348,7 +348,7 @@ def load_from_h5(file: Union[pathlib.Path, str], key: str,
348348
Returns:
349349
A :py:func:`numpy.array` containing the data of interest extracted from the :py:class:`pandas.DataFrame`.
350350
"""
351-
df = pd.read_hdf(file, key=key)
351+
df = read_hdf(file, key=key)
352352
if columns is None:
353353
loaded_array = df.values
354354
elif isinstance(columns, list) and df.columns.nlevels == 1:
@@ -716,3 +716,50 @@ def _get_loader(file_ending: str) -> _BaseLoader:
716716
if file_ending not in __loaders.keys() or file_ending == "":
717717
raise OSError(f"File ending {file_ending} not supported.")
718718
return __loaders[file_ending]
719+
720+
721+
def read_hdf(filename, key=None):
    """Read an HDF5 file using pandas, falling back to raw h5py access.

    Workaround for HDF5 files (e.g. written by older pytables versions)
    that ``pandas.read_hdf`` can no longer open directly.

    Args:
        filename: Path to the HDF5 file.
        key: Optional key to read from the HDF5 file. If ``None``, pandas
            selects the single stored object; the h5py fallback uses the
            first available key.

    Returns:
        pandas.DataFrame: The loaded data.

    Raises:
        RuntimeError: If both pandas and h5py fail to load the file.
    """
    try:
        if key is not None:
            return pd.read_hdf(filename, key)
        return pd.read_hdf(filename)
    except Exception as e:
        # pandas failed (e.g. pytables incompatibility); read the raw
        # dataset with h5py and rebuild the DataFrame manually.
        with h5py.File(filename, "r") as f:
            try:
                if key is not None and key in f:
                    hdf_key = key
                else:
                    hdf_key = list(f.keys())[0]

                data = f[hdf_key][()]
                column_names = f[hdf_key].attrs.get('column_names', None)

                df = pd.DataFrame(data)
                if column_names is not None:
                    df.columns = column_names
                    # DLC-style files flatten the multi-index columns into
                    # "scorer/bodyparts/coords" strings; rebuild the
                    # MultiIndex from them. Only valid when column names
                    # were stored — default integer columns have no
                    # ``split`` method, so this must stay inside the guard.
                    df.columns = pd.MultiIndex.from_tuples(
                        [tuple(col.split('/')) for col in df.columns],
                        names=['scorer', 'bodyparts', 'coords'])

                return df

            except Exception as inner_e:
                # Chain explicitly so both failure causes stay visible.
                raise RuntimeError(
                    f"Failed to load HDF5 file with both pandas and h5py: {str(e)} -> {str(inner_e)}"
                ) from inner_e

tests/test_dlc.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import cebra.integrations.deeplabcut as cebra_dlc
3030
from cebra import CEBRA
3131
from cebra import load_data
32+
from cebra.data.load import read_hdf
3233

3334
# NOTE(stes): The original data URL is
3435
# https://github.com/DeepLabCut/DeepLabCut/blob/main/examples
@@ -54,11 +55,7 @@ def test_imports():
5455

5556

5657
def _load_dlc_dataframe(filename):
    """Load a DLC annotation dataframe via the robust HDF reader."""
    frame = read_hdf(filename)
    return frame
6259

6360

6461
def _get_annotated_data(url, keypoints):

tests/test_load.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def generate_h5_no_array(filename, dtype):
248248
def generate_h5_dataframe(filename, dtype):
    """Round-trip a keyed dataframe through an HDF5 file.

    Fix: cast the fixture array to ``dtype`` so the parametrized dtype is
    actually exercised, consistent with the sibling generators (e.g.
    ``generate_h5_single_dataframe_no_key``).
    """
    A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
    df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
    df_A.to_hdf(filename, key="df_A")
    loaded_A = cebra_load.load(filename, key="df_A")
    return A, loaded_A
254254

@@ -258,7 +258,7 @@ def generate_h5_dataframe_columns(filename, dtype):
258258
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
259259
A_col = A[:, :2]
260260
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
261-
df_A.to_hdf(filename, "df_A")
261+
df_A.to_hdf(filename, key="df_A")
262262
loaded_A = cebra_load.load(filename, key="df_A", columns=["a", "b"])
263263
return A_col, loaded_A
264264

@@ -269,8 +269,8 @@ def generate_h5_multi_dataframe(filename, dtype):
269269
B = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
270270
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
271271
df_B = pd.DataFrame(np.array(B), columns=["c", "d", "e"])
272-
df_A.to_hdf(filename, "df_A")
273-
df_B.to_hdf(filename, "df_B")
272+
df_A.to_hdf(filename, key="df_A")
273+
df_B.to_hdf(filename, key="df_B")
274274
loaded_A = cebra_load.load(filename, key="df_A")
275275
return A, loaded_A
276276

@@ -279,7 +279,7 @@ def generate_h5_multi_dataframe(filename, dtype):
279279
def generate_h5_single_dataframe_no_key(filename, dtype):
    """Store a single dataframe, then load it back without giving a key."""
    values = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
    frame = pd.DataFrame(np.array(values), columns=["a", "b", "c"])
    frame.to_hdf(filename, key="df_A")
    loaded = cebra_load.load(filename)
    return values, loaded
285285

@@ -290,8 +290,8 @@ def generate_h5_multi_dataframe_no_key(filename, dtype):
290290
B = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
291291
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
292292
df_B = pd.DataFrame(np.array(B), columns=["c", "d", "e"])
293-
df_A.to_hdf(filename, "df_A")
294-
df_B.to_hdf(filename, "df_B")
293+
df_A.to_hdf(filename, key="df_A")
294+
df_B.to_hdf(filename, key="df_B")
295295
_ = cebra_load.load(filename)
296296

297297

@@ -304,7 +304,7 @@ def generate_h5_multicol_dataframe(filename, dtype):
304304
df_A = pd.DataFrame(A,
305305
columns=pd.MultiIndex.from_product([animals,
306306
keypoints]))
307-
df_A.to_hdf(filename, "df_A")
307+
df_A.to_hdf(filename, key="df_A")
308308
loaded_A = cebra_load.load(filename, key="df_A")
309309
return A, loaded_A
310310

@@ -313,15 +313,15 @@ def generate_h5_multicol_dataframe(filename, dtype):
313313
def generate_h5_dataframe_invalid_key(filename, dtype):
    """Save under one key, then trigger a load with a key that is absent."""
    values = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
    frame = pd.DataFrame(np.array(values), columns=["a", "b", "c"])
    frame.to_hdf(filename, key="df_A")
    _ = cebra_load.load(filename, key="df_B")
318318

319319

320320
@register_error("h5", "hdf", "hdf5", "h")
def generate_h5_dataframe_invalid_column(filename, dtype):
    """Save a dataframe, then trigger a load requesting a missing column."""
    values = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
    frame = pd.DataFrame(np.array(values), columns=["a", "b", "c"])
    frame.to_hdf(filename, key="df_A")
    _ = cebra_load.load(filename, key="df_A", columns=["d", "b"])
326326

327327

@@ -334,7 +334,7 @@ def generate_h5_multicol_dataframe_columns(filename, dtype):
334334
df_A = pd.DataFrame(A,
335335
columns=pd.MultiIndex.from_product([animals,
336336
keypoints]))
337-
df_A.to_hdf(filename, "df_A")
337+
df_A.to_hdf(filename, key="df_A")
338338
_ = cebra_load.load(filename, key="df_A", columns=["a", "b"])
339339

340340

0 commit comments

Comments
 (0)