Skip to content

Commit 5f46c32

Browse files
authored
Add support for new __sklearn_tags__ (#205)
* Add support for new __sklearn_tags__ * fix inheritance order * Add more tests * fix added test
1 parent 36a91c7 commit 5f46c32

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

.github/workflows/build.yml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,16 @@ jobs:
1919
# as well as selected previous versions on
2020
# https://pytorch.org/get-started/previous-versions/
2121
torch-version: ["2.2.2", "2.4.0"]
22+
sklearn-version: ["latest"]
2223
include:
2324
- os: windows-latest
2425
torch-version: 2.4.0
2526
python-version: "3.10"
27+
sklearn-version: "latest"
28+
- os: ubuntu-latest
29+
torch-version: 2.4.0
30+
python-version: "3.10"
31+
sklearn-version: "legacy"
2632

2733
runs-on: ${{ matrix.os }}
2834

@@ -32,7 +38,7 @@ jobs:
3238
uses: actions/cache@v3
3339
with:
3440
path: ~/.cache/pip
35-
key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}
41+
key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}-sklearn_${{ matrix.sklearn-version }}
3642

3743
- name: Checkout code
3844
uses: actions/checkout@v2
@@ -48,6 +54,11 @@ jobs:
4854
python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
4955
pip install '.[dev,datasets,integrations]'
5056
57+
- name: Check sklearn legacy version
58+
if: matrix.sklearn-version == 'legacy'
59+
run: |
60+
pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]'
61+
5162
- name: Run the formatter
5263
run: |
5364
make format

cebra/integrations/sklearn/cebra.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@
3030
import pkg_resources
3131
import sklearn.utils.validation as sklearn_utils_validation
3232
import torch
33+
import sklearn
3334
from sklearn.base import BaseEstimator
3435
from sklearn.base import TransformerMixin
36+
from sklearn.utils.metaestimators import available_if
3537
from torch import nn
3638

3739
import cebra.data
@@ -41,6 +43,11 @@
4143
import cebra.models
4244
import cebra.solver
4345

46+
def check_version(estimator):
47+
# NOTE(stes): required as a check for the old way of specifying tags
48+
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
49+
from packaging import version
50+
return version.parse(sklearn.__version__) < version.parse("1.6.dev")
4451

4552
def _init_loader(
4653
is_cont: bool,
@@ -364,7 +371,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
364371
return cebra_
365372

366373

367-
class CEBRA(BaseEstimator, TransformerMixin):
374+
class CEBRA(TransformerMixin, BaseEstimator):
368375
"""CEBRA model defined as part of a ``scikit-learn``-like API.
369376
370377
Attributes:
@@ -1294,6 +1301,15 @@ def fit_transform(
12941301
callback_frequency=callback_frequency)
12951302
return self.transform(X)
12961303

1304+
def __sklearn_tags__(self):
1305+
# NOTE(stes): from 1.6.dev, this is the new way to specify tags
1306+
# https://scikit-learn.org/dev/developers/develop.html
1307+
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
1308+
tags = super().__sklearn_tags__()
1309+
tags.non_deterministic = True
1310+
return tags
1311+
1312+
@available_if(check_version)
12971313
def _more_tags(self):
12981314
# NOTE(stes): This tag is needed as seeding is not fully implemented in the
12991315
# current version of CEBRA.

0 commit comments

Comments
 (0)