Skip to content

Commit ed04d43

Browse files
authored
Merge branch 'main' into stes/release-0.5.0
2 parents f0c8087 + 3100730 commit ed04d43

File tree

6 files changed

+77
-30
lines changed

6 files changed

+77
-30
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
# We aim to support the versions on pytorch.org
1919
# as well as selected previous versions on
2020
# https://pytorch.org/get-started/previous-versions/
21-
torch-version: ["2.2.2", "2.4.0"]
21+
torch-version: ["2.4.0", "2.6.0"]
2222
sklearn-version: ["latest"]
2323
include:
2424
- os: windows-latest

cebra/data/load.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,11 +275,11 @@ def _is_dlc_df(h5_file: IO[bytes], df_keys: List[str]) -> bool:
275275
"""
276276
try:
277277
if ["_i_table", "table"] in df_keys:
278-
df = pd.read_hdf(h5_file, key="table")
278+
df = read_hdf(h5_file, key="table")
279279
else:
280-
df = pd.read_hdf(h5_file, key=df_keys[0])
280+
df = read_hdf(h5_file, key=df_keys[0])
281281
except KeyError:
282-
df = pd.read_hdf(h5_file)
282+
df = read_hdf(h5_file)
283283
return all(value in df.columns.names
284284
for value in ["scorer", "bodyparts", "coords"])
285285

@@ -348,7 +348,7 @@ def load_from_h5(file: Union[pathlib.Path, str], key: str,
348348
Returns:
349349
A :py:func:`numpy.array` containing the data of interest extracted from the :py:class:`pandas.DataFrame`.
350350
"""
351-
df = pd.read_hdf(file, key=key)
351+
df = read_hdf(file, key=key)
352352
if columns is None:
353353
loaded_array = df.values
354354
elif isinstance(columns, list) and df.columns.nlevels == 1:
@@ -716,3 +716,21 @@ def _get_loader(file_ending: str) -> _BaseLoader:
716716
if file_ending not in __loaders.keys() or file_ending == "":
717717
raise OSError(f"File ending {file_ending} not supported.")
718718
return __loaders[file_ending]
719+
720+
721+
def read_hdf(filename, key=None):
722+
"""Read HDF5 file using pandas, with fallback to h5py if pandas fails.
723+
724+
Args:
725+
filename: Path to HDF5 file
726+
key: Optional key to read from HDF5 file. If None, tries "df_with_missing"
727+
then falls back to first available key.
728+
729+
Returns:
730+
pandas.DataFrame: The loaded data
731+
732+
Raises:
733+
RuntimeError: If both pandas and h5py fail to load the file
734+
"""
735+
736+
return pd.read_hdf(filename, key=key)

cebra/integrations/sklearn/cebra.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import numpy as np
2929
import numpy.typing as npt
30+
import packaging.version
3031
import pkg_resources
3132
import sklearn
3233
import sklearn.utils.validation as sklearn_utils_validation
@@ -43,12 +44,37 @@
4344
import cebra.models
4445
import cebra.solver
4546

47+
# NOTE(stes): From torch 2.6 onwards, we need to specify the following list
48+
# when loading CEBRA models to allow weights_only = True.
49+
CEBRA_LOAD_SAFE_GLOBALS = [
50+
cebra.data.Offset, torch.torch_version.TorchVersion, np.dtype,
51+
np.dtypes.Float64DType, np.dtypes.Int64DType
52+
]
4653

4754
def check_version(estimator):
4855
# NOTE(stes): required as a check for the old way of specifying tags
4956
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
50-
from packaging import version
51-
return version.parse(sklearn.__version__) < version.parse("1.6.dev")
57+
return packaging.version.parse(
58+
sklearn.__version__) < packaging.version.parse("1.6.dev")
59+
60+
61+
def _safe_torch_load(filename, weights_only, **kwargs):
62+
if weights_only is None:
63+
if packaging.version.parse(
64+
torch.__version__) >= packaging.version.parse("2.6.0"):
65+
weights_only = True
66+
else:
67+
weights_only = False
68+
69+
if not weights_only:
70+
checkpoint = torch.load(filename, weights_only=False, **kwargs)
71+
else:
72+
# NOTE(stes): This is only supported for torch 2.6+
73+
with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
74+
checkpoint = torch.load(filename, weights_only=True, **kwargs)
75+
76+
return checkpoint
77+
5278

5379

5480
def _init_loader(
@@ -1411,15 +1437,22 @@ def save(self,
14111437
def load(cls,
14121438
filename: str,
14131439
backend: Literal["auto", "sklearn", "torch"] = "auto",
1440+
weights_only: bool = None,
14141441
**kwargs) -> "CEBRA":
14151442
"""Load a model from disk.
14161443
14171444
Args:
14181445
filename: The path to the file in which to save the trained model.
14191446
backend: A string identifying the used backend.
1447+
weights_only: Indicates whether unpickler should be restricted to loading only tensors, primitive types,
1448+
dictionaries and any types added via :py:func:`torch.serialization.add_safe_globals`.
1449+
See :py:func:`torch.load` with ``weights_only=True`` for more details. It it recommended to leave this
1450+
at the default value of ``None``, which sets the argument to ``False`` for torch<2.6, and ``True`` for
1451+
higher versions of torch. If you experience issues with loading custom models (specified outside
1452+
of the CEBRA package), you can try to set this to ``False`` if you trust the source of the model.
14201453
kwargs: Optional keyword arguments passed directly to the loader.
14211454
1422-
Return:
1455+
Returns:
14231456
The model to load.
14241457
14251458
Note:
@@ -1429,7 +1462,6 @@ def load(cls,
14291462
For information about the file format please refer to :py:meth:`cebra.CEBRA.save`.
14301463
14311464
Example:
1432-
14331465
>>> import cebra
14341466
>>> import numpy as np
14351467
>>> import tempfile
@@ -1443,16 +1475,14 @@ def load(cls,
14431475
>>> loaded_model = cebra.CEBRA.load(tmp_file)
14441476
>>> embedding = loaded_model.transform(dataset)
14451477
>>> tmp_file.unlink()
1446-
14471478
"""
1448-
14491479
supported_backends = ["auto", "sklearn", "torch"]
14501480
if backend not in supported_backends:
14511481
raise NotImplementedError(
14521482
f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}"
14531483
)
14541484

1455-
checkpoint = torch.load(filename, **kwargs)
1485+
checkpoint = _safe_torch_load(filename, weights_only, **kwargs)
14561486

14571487
if backend == "auto":
14581488
backend = "sklearn" if isinstance(checkpoint, dict) else "torch"

setup.cfg

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,13 @@ where =
3131
python_requires = >=3.9
3232
install_requires =
3333
joblib
34-
numpy<2.0.0
34+
numpy<2.0;platform_system=="Windows"
35+
numpy<2.0;platform_system!="Windows" and python_version<"3.10"
36+
numpy;platform_system!="Windows" and python_version>="3.10"
3537
literate-dataclasses
3638
scikit-learn
3739
scipy
38-
torch
40+
torch>=2.4.0
3941
tqdm
4042
matplotlib
4143
requests

tests/test_dlc.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import cebra.integrations.deeplabcut as cebra_dlc
3030
from cebra import CEBRA
3131
from cebra import load_data
32+
from cebra.data.load import read_hdf
3233

3334
# NOTE(stes): The original data URL is
3435
# https://github.com/DeepLabCut/DeepLabCut/blob/main/examples
@@ -54,11 +55,7 @@ def test_imports():
5455

5556

5657
def _load_dlc_dataframe(filename):
57-
try:
58-
df = pd.read_hdf(filename, "df_with_missing")
59-
except KeyError:
60-
df = pd.read_hdf(filename)
61-
return df
58+
return read_hdf(filename)
6259

6360

6461
def _get_annotated_data(url, keypoints):

tests/test_load.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def generate_h5_no_array(filename, dtype):
248248
def generate_h5_dataframe(filename, dtype):
249249
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
250250
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
251-
df_A.to_hdf(filename, "df_A")
251+
df_A.to_hdf(filename, key="df_A")
252252
loaded_A = cebra_load.load(filename, key="df_A")
253253
return A, loaded_A
254254

@@ -258,7 +258,7 @@ def generate_h5_dataframe_columns(filename, dtype):
258258
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
259259
A_col = A[:, :2]
260260
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
261-
df_A.to_hdf(filename, "df_A")
261+
df_A.to_hdf(filename, key="df_A")
262262
loaded_A = cebra_load.load(filename, key="df_A", columns=["a", "b"])
263263
return A_col, loaded_A
264264

@@ -269,8 +269,8 @@ def generate_h5_multi_dataframe(filename, dtype):
269269
B = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
270270
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
271271
df_B = pd.DataFrame(np.array(B), columns=["c", "d", "e"])
272-
df_A.to_hdf(filename, "df_A")
273-
df_B.to_hdf(filename, "df_B")
272+
df_A.to_hdf(filename, key="df_A")
273+
df_B.to_hdf(filename, key="df_B")
274274
loaded_A = cebra_load.load(filename, key="df_A")
275275
return A, loaded_A
276276

@@ -279,7 +279,7 @@ def generate_h5_multi_dataframe(filename, dtype):
279279
def generate_h5_single_dataframe_no_key(filename, dtype):
280280
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
281281
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
282-
df_A.to_hdf(filename, "df_A")
282+
df_A.to_hdf(filename, key="df_A")
283283
loaded_A = cebra_load.load(filename)
284284
return A, loaded_A
285285

@@ -290,8 +290,8 @@ def generate_h5_multi_dataframe_no_key(filename, dtype):
290290
B = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
291291
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
292292
df_B = pd.DataFrame(np.array(B), columns=["c", "d", "e"])
293-
df_A.to_hdf(filename, "df_A")
294-
df_B.to_hdf(filename, "df_B")
293+
df_A.to_hdf(filename, key="df_A")
294+
df_B.to_hdf(filename, key="df_B")
295295
_ = cebra_load.load(filename)
296296

297297

@@ -304,7 +304,7 @@ def generate_h5_multicol_dataframe(filename, dtype):
304304
df_A = pd.DataFrame(A,
305305
columns=pd.MultiIndex.from_product([animals,
306306
keypoints]))
307-
df_A.to_hdf(filename, "df_A")
307+
df_A.to_hdf(filename, key="df_A")
308308
loaded_A = cebra_load.load(filename, key="df_A")
309309
return A, loaded_A
310310

@@ -313,15 +313,15 @@ def generate_h5_multicol_dataframe(filename, dtype):
313313
def generate_h5_dataframe_invalid_key(filename, dtype):
314314
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
315315
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
316-
df_A.to_hdf(filename, "df_A")
316+
df_A.to_hdf(filename, key="df_A")
317317
_ = cebra_load.load(filename, key="df_B")
318318

319319

320320
@register_error("h5", "hdf", "hdf5", "h")
321321
def generate_h5_dataframe_invalid_column(filename, dtype):
322322
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
323323
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
324-
df_A.to_hdf(filename, "df_A")
324+
df_A.to_hdf(filename, key="df_A")
325325
_ = cebra_load.load(filename, key="df_A", columns=["d", "b"])
326326

327327

@@ -334,7 +334,7 @@ def generate_h5_multicol_dataframe_columns(filename, dtype):
334334
df_A = pd.DataFrame(A,
335335
columns=pd.MultiIndex.from_product([animals,
336336
keypoints]))
337-
df_A.to_hdf(filename, "df_A")
337+
df_A.to_hdf(filename, key="df_A")
338338
_ = cebra_load.load(filename, key="df_A", columns=["a", "b"])
339339

340340

0 commit comments

Comments
 (0)