Skip to content

Commit b1980cd

Browse files
authored
Merge branch 'main' into batched-inference-and-padding
2 parents 0eac868 + 5f46c32 commit b1980cd

19 files changed

+47
-39
lines changed

.github/workflows/build.yml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,16 @@ jobs:
1919
# as well as selected previous versions on
2020
# https://pytorch.org/get-started/previous-versions/
2121
torch-version: ["2.2.2", "2.4.0"]
22+
sklearn-version: ["latest"]
2223
include:
2324
- os: windows-latest
2425
torch-version: 2.4.0
2526
python-version: "3.10"
27+
sklearn-version: "latest"
28+
- os: ubuntu-latest
29+
torch-version: 2.4.0
30+
python-version: "3.10"
31+
sklearn-version: "legacy"
2632

2733
runs-on: ${{ matrix.os }}
2834

@@ -32,7 +38,7 @@ jobs:
3238
uses: actions/cache@v3
3339
with:
3440
path: ~/.cache/pip
35-
key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}
41+
key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}-sklearn_${{ matrix.sklearn-version }}
3642

3743
- name: Checkout code
3844
uses: actions/checkout@v2
@@ -48,6 +54,11 @@ jobs:
4854
python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
4955
pip install '.[dev,datasets,integrations]'
5056
57+
- name: Check sklearn legacy version
58+
if: matrix.sklearn-version == 'legacy'
59+
run: |
60+
pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]'
61+
5162
- name: Run the formatter
5263
run: |
5364
make format

cebra/integrations/sklearn/cebra.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@
3030
import pkg_resources
3131
import sklearn.utils.validation as sklearn_utils_validation
3232
import torch
33+
import sklearn
3334
from sklearn.base import BaseEstimator
3435
from sklearn.base import TransformerMixin
36+
from sklearn.utils.metaestimators import available_if
3537
from torch import nn
3638

3739
import cebra.data
@@ -41,6 +43,11 @@
4143
import cebra.models
4244
import cebra.solver
4345

46+
def check_version(estimator):
47+
# NOTE(stes): required as a check for the old way of specifying tags
48+
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
49+
from packaging import version
50+
return version.parse(sklearn.__version__) < version.parse("1.6.dev")
4451

4552
def _init_loader(
4653
is_cont: bool,
@@ -364,7 +371,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
364371
return cebra_
365372

366373

367-
class CEBRA(BaseEstimator, TransformerMixin):
374+
class CEBRA(TransformerMixin, BaseEstimator):
368375
"""CEBRA model defined as part of a ``scikit-learn``-like API.
369376
370377
Attributes:
@@ -1317,6 +1324,15 @@ def fit_transform(
13171324
callback_frequency=callback_frequency)
13181325
return self.transform(X)
13191326

1327+
def __sklearn_tags__(self):
1328+
# NOTE(stes): from 1.6.dev, this is the new way to specify tags
1329+
# https://scikit-learn.org/dev/developers/develop.html
1330+
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
1331+
tags = super().__sklearn_tags__()
1332+
tags.non_deterministic = True
1333+
return tags
1334+
1335+
@available_if(check_version)
13201336
def _more_tags(self):
13211337
# NOTE(stes): This tag is needed as seeding is not fully implemented in the
13221338
# current version of CEBRA.

conda/cebra_paper.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ dependencies:
3939
- "cebra[dev,integrations,datasets,demos]"
4040
- joblib
4141
- literate-dataclasses
42-
- sklearn
42+
- scikit-learn
4343
- scipy
4444
- torch
4545
- keras==2.3.1

conda/cebra_paper_m1.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ dependencies:
4848
- tensorflow-metal
4949
- joblib
5050
- literate-dataclasses
51-
- sklearn
51+
- scikit-learn
5252
- scipy
5353
- torch
5454
- umap-learn

tests/test_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,5 @@
2121
#
2222
def test_api():
2323
import cebra.distributions
24-
from cebra.distributions import TimedeltaDistribution
2524

2625
cebra.distributions.TimedeltaDistribution

tests/test_cli.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,3 @@
1919
# See the License for the specific language governing permissions and
2020
# limitations under the License.
2121
#
22-
import argparse
23-
24-
import pytest

tests/test_criterions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
# See the License for the specific language governing permissions and
2020
# limitations under the License.
2121
#
22-
import numpy as np
2322
import pytest
2423
import torch
2524
from torch import nn
@@ -294,7 +293,7 @@ def _sample_dist_matrices(seed):
294293

295294

296295
@pytest.mark.parametrize("seed", [42, 4242, 424242])
297-
def test_infonce(seed):
296+
def test_infonce_check_output_parts(seed):
298297
pos_dist, neg_dist = _sample_dist_matrices(seed)
299298

300299
ref_loss, ref_align, ref_uniform = _reference_infonce(pos_dist, neg_dist)

tests/test_datasets.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def test_hippocampus():
9999

100100
@pytest.mark.requires_dataset
101101
def test_monkey():
102-
103102
dataset = cebra.datasets.init(
104103
"area2-bump-pos-active-passive",
105104
path=pathlib.Path(_DEFAULT_DATADIR) / "monkey_reaching_preload_smth_40",
@@ -110,7 +109,6 @@ def test_monkey():
110109

111110
@pytest.mark.requires_dataset
112111
def test_allen():
113-
114112
pytest.skip("Test takes too long")
115113

116114
ca_dataset = cebra.datasets.init("allen-movie-one-ca-VISp-100-train-10-111")

tests/test_demo.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#
2222
import glob
2323
import re
24-
import sys
2524

2625
import pytest
2726

tests/test_distributions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def prepare(N=1000, n=128, d=5, probs=[0.3, 0.1, 0.6], device="cpu"):
4343
continuous = torch.randn(N, d).to(device)
4444

4545
rand = torch.from_numpy(np.random.randint(0, N, (n,))).to(device)
46-
qidx = discrete[rand].to(device)
46+
_ = discrete[rand].to(device)
4747
query = continuous[rand] + 0.1 * torch.randn(n, d).to(device)
4848
query = query.to(device)
4949

@@ -173,7 +173,7 @@ def test_mixed():
173173
discrete, continuous)
174174

175175
reference_idx = distribution.sample_prior(10)
176-
positive_idx = distribution.sample_conditional(reference_idx)
176+
_ = distribution.sample_conditional(reference_idx)
177177

178178
# The conditional distribution p(· | disc, cont) should yield
179179
# samples where the label exactly matches the reference sample.
@@ -193,7 +193,7 @@ def test_continuous(benchmark):
193193
def _test_distribution(dist):
194194
distribution = dist(continuous)
195195
reference_idx = distribution.sample_prior(10)
196-
positive_idx = distribution.sample_conditional(reference_idx)
196+
_ = distribution.sample_conditional(reference_idx)
197197
return distribution
198198

199199
distribution = _test_distribution(

0 commit comments

Comments
 (0)