
Commit 545c9a0

Merge branch 'main' into add_options_mixed
2 parents 0326fb9 + f99530c

26 files changed: 521 additions, 71 deletions

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ jobs:
         # We aim to support the versions on pytorch.org
         # as well as selected previous versions on
         # https://pytorch.org/get-started/previous-versions/
-        torch-version: ["2.2.2", "2.4.0"]
+        torch-version: ["2.4.0", "2.6.0"]
         sklearn-version: ["latest"]
         include:
           - os: windows-latest

.github/workflows/release-pypi.yml

Lines changed: 7 additions & 0 deletions
@@ -28,6 +28,13 @@ jobs:
           path: ~/.cache/pip
           key: ${{ runner.os }}-pip

+      - name: Install dependencies
+        run: |
+          pip install --upgrade pip
+          pip install wheel
+          # NOTE(stes) see https://github.com/pypa/twine/issues/1216#issuecomment-2629069669
+          pip install "packaging>=24.2"
+
       - name: Checkout code
         uses: actions/checkout@v3

Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ RUN make dist
 FROM cebra-base

 # install the cebra wheel
-ENV WHEEL=cebra-0.4.0-py2.py3-none-any.whl
+ENV WHEEL=cebra-0.5.0rc1-py3-none-any.whl
 WORKDIR /build
 COPY --from=wheel /build/dist/${WHEEL} .
 RUN pip install --no-cache-dir ${WHEEL}'[dev,integrations,datasets]'

Makefile

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-CEBRA_VERSION := 0.4.0
+CEBRA_VERSION := 0.5.0rc1

 dist:
 	python3 -m pip install virtualenv

PKGBUILD

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 # Maintainer: Steffen Schneider <[email protected]>
 pkgname=python-cebra
 _pkgname=cebra
-pkgver=0.4.0
+pkgver=0.5.0rc1
 pkgrel=1
 pkgdesc="Consistent Embeddings of high-dimensional Recordings using Auxiliary variables"
 url="https://cebra.ai"
@@ -40,7 +40,7 @@ build() {

 package() {
   cd $srcdir/${_pkgname}-${pkgver}
-  pip install --ignore-installed --no-deps --root="${pkgdir}" dist/${_pkgname}-${pkgver}-py2.py3-none-any.whl
+  pip install --ignore-installed --no-deps --root="${pkgdir}" dist/${_pkgname}-${pkgver}-py3-none-any.whl
   find ${pkgdir} -iname __pycache__ -exec rm -r {} \; 2>/dev/null || echo
   install -Dm 644 LICENSE.md $pkgdir/usr/share/licenses/${pkgname}/LICENSE
 }

cebra/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@

 import cebra.integrations.sklearn as sklearn

-__version__ = "0.4.0"
+__version__ = "0.5.0rc1"
 __all__ = ["CEBRA"]
 __allow_lazy_imports = False
 __lazy_imports = {}

cebra/data/load.py

Lines changed: 22 additions & 4 deletions
@@ -275,11 +275,11 @@ def _is_dlc_df(h5_file: IO[bytes], df_keys: List[str]) -> bool:
     """
     try:
         if ["_i_table", "table"] in df_keys:
-            df = pd.read_hdf(h5_file, key="table")
+            df = read_hdf(h5_file, key="table")
         else:
-            df = pd.read_hdf(h5_file, key=df_keys[0])
+            df = read_hdf(h5_file, key=df_keys[0])
     except KeyError:
-        df = pd.read_hdf(h5_file)
+        df = read_hdf(h5_file)
     return all(value in df.columns.names
                for value in ["scorer", "bodyparts", "coords"])

@@ -348,7 +348,7 @@ def load_from_h5(file: Union[pathlib.Path, str], key: str,
     Returns:
         A :py:func:`numpy.array` containing the data of interest extracted from the :py:class:`pandas.DataFrame`.
     """
-    df = pd.read_hdf(file, key=key)
+    df = read_hdf(file, key=key)
     if columns is None:
         loaded_array = df.values
     elif isinstance(columns, list) and df.columns.nlevels == 1:

@@ -716,3 +716,21 @@ def _get_loader(file_ending: str) -> _BaseLoader:
     if file_ending not in __loaders.keys() or file_ending == "":
         raise OSError(f"File ending {file_ending} not supported.")
     return __loaders[file_ending]
+
+
+def read_hdf(filename, key=None):
+    """Read HDF5 file using pandas, with fallback to h5py if pandas fails.
+
+    Args:
+        filename: Path to HDF5 file
+        key: Optional key to read from HDF5 file. If None, tries "df_with_missing"
+            then falls back to first available key.
+
+    Returns:
+        pandas.DataFrame: The loaded data
+
+    Raises:
+        RuntimeError: If both pandas and h5py fail to load the file
+    """
+
+    return pd.read_hdf(filename, key=key)
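
The new read_hdf helper centralizes HDF5 reading so that the loaders above no longer call pd.read_hdf directly. A minimal usage sketch, assuming a hypothetical tracking file (not part of this commit):

    import cebra.data.load as cebra_load

    # With an explicit key, this behaves like pd.read_hdf(filename, key=...).
    df_table = cebra_load.read_hdf("tracking.h5", key="table")

    # With key=None, pandas resolves the stored object itself (single-key files).
    df = cebra_load.read_hdf("tracking.h5")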

cebra/integrations/sklearn/cebra.py

Lines changed: 40 additions & 8 deletions
@@ -27,10 +27,11 @@

 import numpy as np
 import numpy.typing as npt
+import packaging.version
 import pkg_resources
+import sklearn
 import sklearn.utils.validation as sklearn_utils_validation
 import torch
-import sklearn
 from sklearn.base import BaseEstimator
 from sklearn.base import TransformerMixin
 from sklearn.utils.metaestimators import available_if

@@ -43,11 +44,38 @@
 import cebra.models
 import cebra.solver

+# NOTE(stes): From torch 2.6 onwards, we need to specify the following list
+# when loading CEBRA models to allow weights_only = True.
+CEBRA_LOAD_SAFE_GLOBALS = [
+    cebra.data.Offset, torch.torch_version.TorchVersion, np.dtype,
+    np.dtypes.Float64DType, np.dtypes.Int64DType
+]
+
 def check_version(estimator):
     # NOTE(stes): required as a check for the old way of specifying tags
     # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
-    from packaging import version
-    return version.parse(sklearn.__version__) < version.parse("1.6.dev")
+    return packaging.version.parse(
+        sklearn.__version__) < packaging.version.parse("1.6.dev")
+
+
+def _safe_torch_load(filename, weights_only, **kwargs):
+    if weights_only is None:
+        if packaging.version.parse(
+                torch.__version__) >= packaging.version.parse("2.6.0"):
+            weights_only = True
+        else:
+            weights_only = False
+
+    if not weights_only:
+        checkpoint = torch.load(filename, weights_only=False, **kwargs)
+    else:
+        # NOTE(stes): This is only supported for torch 2.6+
+        with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
+            checkpoint = torch.load(filename, weights_only=True, **kwargs)
+
+    return checkpoint
+
+

 def _init_loader(
     is_cont: bool,

@@ -1409,15 +1437,22 @@ def save(self,
     def load(cls,
              filename: str,
              backend: Literal["auto", "sklearn", "torch"] = "auto",
+             weights_only: bool = None,
              **kwargs) -> "CEBRA":
         """Load a model from disk.

         Args:
             filename: The path to the file in which to save the trained model.
             backend: A string identifying the used backend.
+            weights_only: Indicates whether unpickler should be restricted to loading only tensors, primitive types,
+                dictionaries and any types added via :py:func:`torch.serialization.add_safe_globals`.
+                See :py:func:`torch.load` with ``weights_only=True`` for more details. It is recommended to leave this
+                at the default value of ``None``, which sets the argument to ``False`` for torch<2.6, and ``True`` for
+                higher versions of torch. If you experience issues with loading custom models (specified outside
+                of the CEBRA package), you can try to set this to ``False`` if you trust the source of the model.
             kwargs: Optional keyword arguments passed directly to the loader.

-        Return:
+        Returns:
             The model to load.

         Note:

@@ -1427,7 +1462,6 @@ def load(cls,
         For information about the file format please refer to :py:meth:`cebra.CEBRA.save`.

         Example:
-
             >>> import cebra
             >>> import numpy as np
             >>> import tempfile

@@ -1441,16 +1475,14 @@ def load(cls,
             >>> loaded_model = cebra.CEBRA.load(tmp_file)
             >>> embedding = loaded_model.transform(dataset)
             >>> tmp_file.unlink()
-
         """
-
         supported_backends = ["auto", "sklearn", "torch"]
         if backend not in supported_backends:
            raise NotImplementedError(
                f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}"
            )

-        checkpoint = torch.load(filename, **kwargs)
+        checkpoint = _safe_torch_load(filename, weights_only, **kwargs)

         if backend == "auto":
             backend = "sklearn" if isinstance(checkpoint, dict) else "torch"

cebra/integrations/sklearn/metrics.py

Lines changed: 143 additions & 0 deletions
@@ -108,6 +108,149 @@ def infonce_loss(
     return avg_loss


+def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
+                          X: Union[npt.NDArray, torch.Tensor],
+                          *y,
+                          session_id: Optional[int] = None,
+                          num_batches: int = 500) -> float:
+    """Compute the goodness of fit score on a *single session* dataset on the model.
+
+    This function uses the :func:`infonce_loss` function to compute the InfoNCE loss
+    for a given `cebra_model` and the :func:`infonce_to_goodness_of_fit` function
+    to derive the goodness of fit from the InfoNCE loss.
+
+    Args:
+        cebra_model: The model to use to compute the InfoNCE loss on the samples.
+        X: A 2D data matrix, corresponding to a *single session* recording.
+        y: An arbitrary amount of continuous indices passed as 2D matrices, and up to one
+            discrete index passed as a 1D array. Each index has to match the length of ``X``.
+        session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`cebra.CEBRA.num_sessions`
+            for multisession, set to ``None`` for single session.
+        num_batches: The number of iterations to consider to evaluate the model on the new data.
+            Higher values will give a more accurate estimate. Set it to at least 500 iterations.
+
+    Returns:
+        The average GoF score estimated over ``num_batches`` batches from the data distribution.
+
+    Related:
+        :func:`infonce_to_goodness_of_fit`
+
+    Example:
+
+        >>> import cebra
+        >>> import numpy as np
+        >>> neural_data = np.random.uniform(0, 1, (1000, 20))
+        >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512)
+        >>> cebra_model.fit(neural_data)
+        CEBRA(batch_size=512, max_iterations=10)
+        >>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data)
+    """
+    loss = infonce_loss(cebra_model,
+                        X,
+                        *y,
+                        session_id=session_id,
+                        num_batches=num_batches,
+                        correct_by_batchsize=False)
+    return infonce_to_goodness_of_fit(loss, cebra_model)
+
+
+def goodness_of_fit_history(model: cebra_sklearn_cebra.CEBRA) -> np.ndarray:
+    """Return the history of the goodness of fit score.
+
+    Args:
+        model: A trained CEBRA model.
+
+    Returns:
+        A numpy array containing the goodness of fit values, measured in bits.
+
+    Related:
+        :func:`infonce_to_goodness_of_fit`
+
+    Example:
+
+        >>> import cebra
+        >>> import numpy as np
+        >>> neural_data = np.random.uniform(0, 1, (1000, 20))
+        >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512)
+        >>> cebra_model.fit(neural_data)
+        CEBRA(batch_size=512, max_iterations=10)
+        >>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model)
+    """
+    infonce = np.array(model.state_dict_["log"]["total"])
+    return infonce_to_goodness_of_fit(infonce, model)
+
+
+def infonce_to_goodness_of_fit(
+        infonce: Union[float, np.ndarray],
+        model: Optional[cebra_sklearn_cebra.CEBRA] = None,
+        batch_size: Optional[int] = None,
+        num_sessions: Optional[int] = None) -> Union[float, np.ndarray]:
+    """Given a trained CEBRA model, return goodness of fit metric.
+
+    The goodness of fit ranges from 0 (lowest meaningful value)
+    to a positive number with the unit "bits", the higher the
+    better.
+
+    Values lower than 0 bits are possible, but these only occur
+    due to numerical effects. A perfectly collapsed embedding
+    (e.g., because the data cannot be fit with the provided
+    auxiliary variables) will have a goodness of fit of 0.
+
+    The conversion between the generalized InfoNCE metric that
+    CEBRA is trained with and the goodness of fit computed with this
+    function is
+
+    .. math::
+
+        S = \\log N - \\text{InfoNCE}
+
+    To use this function, either provide a trained CEBRA model or the
+    batch size and number of sessions.
+
+    Args:
+        infonce: The InfoNCE loss, either a single value or an iterable of values.
+        model: The trained CEBRA model.
+        batch_size: The batch size used to train the model.
+        num_sessions: The number of sessions used to train the model.
+
+    Returns:
+        Numpy array containing the goodness of fit values, measured in bits
+
+    Raises:
+        RuntimeError: If the provided model is not fit to data.
+        ValueError: If both ``model`` and ``(batch_size, num_sessions)`` are provided.
+    """
+    if model is not None:
+        if batch_size is not None or num_sessions is not None:
+            raise ValueError(
+                "batch_size and num_sessions should not be provided if model is provided."
+            )
+        if not hasattr(model, "state_dict_"):
+            raise RuntimeError("Fit the CEBRA model first.")
+        if model.batch_size is None:
+            raise ValueError(
+                "Computing the goodness of fit is not yet supported for "
+                "models trained on the full dataset (batchsize = None). ")
+        batch_size = model.batch_size
+        num_sessions = model.num_sessions_
+        if num_sessions is None:
+            num_sessions = 1
+
+        if model.batch_size is None:
+            raise ValueError(
+                "Computing the goodness of fit is not yet supported for "
+                "models trained on the full dataset (batchsize = None). ")
+    else:
+        if batch_size is None or num_sessions is None:
+            raise ValueError(
+                f"batch_size ({batch_size}) and num_sessions ({num_sessions})"
+                f"should be provided if model is not provided.")
+
+    nats_to_bits = np.log2(np.e)
+    chance_level = np.log(batch_size * num_sessions)
+    return (chance_level - infonce) * nats_to_bits
+
+
 def _consistency_scores(
     embeddings: List[Union[npt.NDArray, torch.Tensor]],
     datasets: List[Union[int, str]],
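
As a quick numeric sketch of the conversion implemented above (the loss value is made up): with batch_size=512, a single session, and an InfoNCE loss of 5.5 nats, the goodness of fit is (log(512) - 5.5) * log2(e), roughly 1.07 bits; a loss at the chance level of log(512) nats would give 0 bits.

    import numpy as np

    batch_size, num_sessions = 512, 1
    infonce = 5.5  # hypothetical InfoNCE loss in nats
    chance_level = np.log(batch_size * num_sessions)  # log(512), about 6.24 nats
    gof_bits = (chance_level - infonce) * np.log2(np.e)
    print(round(gof_bits, 2))  # about 1.07 bits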

cebra/solver/base.py

Lines changed: 2 additions & 1 deletion
@@ -210,7 +210,8 @@ def fit(
                     self.decoding(loader, valid_loader))
             if save_hook is not None:
                 save_hook(num_steps, self)
-            self.save(logdir, f"checkpoint_{num_steps:#07d}.pth")
+            if logdir is not None:
+                self.save(logdir, f"checkpoint_{num_steps:#07d}.pth")

     def step(self, batch: cebra.data.Batch) -> dict:
         """Perform a single gradient update.
