Skip to content

Commit a720f8c

Browse files
authored
Merge branch 'main' into stes/add-split
2 parents 46e50db + ac020a9 commit a720f8c

File tree

9 files changed

+24
-21
lines changed

9 files changed

+24
-21
lines changed

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,12 @@ To receive updates on code releases, please 👀 watch or ⭐️ star this repos
3535
It can jointly use behavioral and neural data in a hypothesis- or discovery-driven manner to produce consistent, high-performance latent spaces. While it is not specific to neural and behavioral data, this is the first domain we used the tool in. This application case is to obtain a consistent representation of latent variables driving activity and behavior, improving decoding accuracy of behavioral variables over standard supervised learning, and obtaining embeddings which are robust to domain shifts.
3636

3737

38-
# Reference
38+
# References
39+
40+
- 📄 **Publication April 2025**:
41+
[Time-series attribution maps with regularized contrastive learning.](https://arxiv.org/abs/2502.12977)
42+
Steffen Schneider, Rodrigo González Laiz, Anastasiia Filipova, Markus Frey, Mackenzie Weygandt Mathis. AISTATS 2025.
43+
3944

4045
- 📄 **Publication May 2023**:
4146
[Learnable latent embeddings for joint behavioural and neural analysis.](https://doi.org/10.1038/s41586-023-06031-6)

cebra/helper.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
import numpy as np
3434
import numpy.typing as npt
35-
import pkg_resources
35+
import packaging.version
3636
import requests
3737
import torch
3838

@@ -75,8 +75,8 @@ def download_file_from_zip_url(url, *, file):
7575

7676
def _is_mps_availabe(torch):
7777
available = False
78-
if pkg_resources.parse_version(
79-
torch.__version__) >= pkg_resources.parse_version("1.12"):
78+
if packaging.version.parse(
79+
torch.__version__) >= packaging.version.parse("1.12"):
8080
if torch.backends.mps.is_available():
8181
if torch.backends.mps.is_built():
8282
available = True
@@ -159,17 +159,17 @@ def requires_package_version(module, version: str):
159159
the required ``version``.
160160
"""
161161

162-
required_version = pkg_resources.parse_version(version)
162+
required_version = packaging.version.parse(version)
163163

164164
def _requires_package_version(function):
165165

166166
@wraps(function)
167167
def wrapper(*args, patched_version=None, **kwargs):
168168
if patched_version is not None:
169-
installed_version = pkg_resources.parse_version(
169+
installed_version = packaging.version.parse(
170170
patched_version) # Use the patched version if provided
171171
else:
172-
installed_version = pkg_resources.parse_version(
172+
installed_version = packaging.version.parse(
173173
module.__version__)
174174

175175
if installed_version < required_version:

cebra/integrations/matplotlib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1235,7 +1235,7 @@ def compare_models(
12351235

12361236
# check the color of the traces
12371237
if color is None:
1238-
cebra_map = plt.get_cmap(cmap)
1238+
cebra_map = matplotlib.colormaps.get_cmap(cmap)
12391239
colors = matplotlib.colors.ListedColormap(
12401240
cebra_map.resampled(n_models)(np.arange(n_models))).colors
12411241
else:

cebra/integrations/plotly.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def _define_colorscale(self, cmap: str):
8787
Returns:
8888
colorscale: List of scaled colors to plot the embeddings
8989
"""
90-
colorscale = _convert_cmap2colorscale(matplotlib.cm.get_cmap(cmap))
90+
colorscale = _convert_cmap2colorscale(matplotlib.colormaps.get_cmap(cmap))
9191

9292
return colorscale
9393

cebra/integrations/sklearn/cebra.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import numpy as np
2929
import numpy.typing as npt
3030
import packaging.version
31-
import pkg_resources
31+
import importlib.metadata
3232
import sklearn
3333
import sklearn.utils.validation as sklearn_utils_validation
3434
import torch
@@ -1397,7 +1397,7 @@ def save(self,
13971397
'numpy_version':
13981398
np.__version__,
13991399
'sklearn_version':
1400-
pkg_resources.get_distribution("scikit-learn"
1400+
importlib.metadata.distribution("scikit-learn"
14011401
).version
14021402
}
14031403
}, filename)

setup.cfg

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@ install_requires =
3939
scipy
4040
torch>=2.4.0
4141
tqdm
42-
# NOTE(stes): Remove pin once https://github.com/AdaptiveMotorControlLab/CEBRA/issues/240
43-
# is resolved.
44-
matplotlib<3.11
42+
matplotlib
4543
requests
4644

4745
[options.extras_require]

tests/test_plot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import matplotlib
2626
import matplotlib.pyplot as plt
2727
import numpy as np
28-
import pkg_resources
28+
import packaging.version
2929
import pytest
3030
import torch
3131
from sklearn.exceptions import NotFittedError
@@ -190,8 +190,8 @@ def test_compare_models_with_different_versions(matplotlib_version):
190190
# minimum version of matplotlib
191191
minimum_version = "3.6"
192192

193-
if pkg_resources.parse_version(
194-
matplotlib_version) < pkg_resources.parse_version(minimum_version):
193+
if packaging.version.parse(
194+
matplotlib_version) < packaging.version.parse(minimum_version):
195195
with pytest.raises(ImportError):
196196
cebra_plot.compare_models(models=fitted_models,
197197
patched_version=matplotlib_version)

tests/test_plotly.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
@pytest.mark.parametrize("cmap", ["viridis", "plasma", "inferno", "magma"])
3333
def test_colorscale(cmap):
34-
cmap = matplotlib.cm.get_cmap(cmap)
34+
cmap = matplotlib.colormaps.get_cmap(cmap)
3535
colorscale = cebra_plotly._convert_cmap2colorscale(cmap)
3636
assert isinstance(colorscale, list)
3737

tests/test_sklearn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import _util
2727
import _utils_deprecated
2828
import numpy as np
29-
import pkg_resources
29+
import packaging.version
3030
import pytest
3131
import sklearn.utils.estimator_checks
3232
import torch
@@ -1320,8 +1320,8 @@ def test_check_device():
13201320
with pytest.raises(ValueError):
13211321
cebra_sklearn_utils.check_device(device)
13221322

1323-
if pkg_resources.parse_version(
1324-
torch.__version__) >= pkg_resources.parse_version("1.12"):
1323+
if packaging.version.parse(
1324+
torch.__version__) >= packaging.version.parse("1.12"):
13251325

13261326
device = "mps"
13271327
torch.backends.mps.is_available = lambda: True

0 commit comments

Comments
 (0)