Skip to content

Commit 77b5c85

Browse files
committed
Add support for new __sklearn_tags__
1 parent 36a91c7 commit 77b5c85

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

cebra/integrations/sklearn/cebra.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@
3030
import pkg_resources
3131
import sklearn.utils.validation as sklearn_utils_validation
3232
import torch
33+
import sklearn
3334
from sklearn.base import BaseEstimator
3435
from sklearn.base import TransformerMixin
36+
from sklearn.utils.metaestimators import available_if
3537
from torch import nn
3638

3739
import cebra.data
@@ -41,6 +43,11 @@
4143
import cebra.models
4244
import cebra.solver
4345

46+
def check_version(estimator):
47+
# NOTE(stes): required as a check for the old way of specifying tags
48+
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
49+
from packaging import version
50+
return version.parse(sklearn.__version__) < version.parse("1.6.dev")
4451

4552
def _init_loader(
4653
is_cont: bool,
@@ -1294,6 +1301,15 @@ def fit_transform(
12941301
callback_frequency=callback_frequency)
12951302
return self.transform(X)
12961303

1304+
def __sklearn_tags__(self):
1305+
# NOTE(stes): from 1.6.dev, this is the new way to specify tags
1306+
# https://scikit-learn.org/dev/developers/develop.html
1307+
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
1308+
tags = super().__sklearn_tags__()
1309+
tags.non_deterministic = True
1310+
return tags
1311+
1312+
@available_if(check_version)
12971313
def _more_tags(self):
12981314
# NOTE(stes): This tag is needed as seeding is not fully implemented in the
12991315
# current version of CEBRA.

0 commit comments

Comments
 (0)