File tree Expand file tree Collapse file tree 1 file changed +16
-0
lines changed
cebra/integrations/sklearn Expand file tree Collapse file tree 1 file changed +16
-0
lines changed Original file line number Diff line number Diff line change 3030import pkg_resources
3131import sklearn .utils .validation as sklearn_utils_validation
3232import torch
33+ import sklearn
3334from sklearn .base import BaseEstimator
3435from sklearn .base import TransformerMixin
36+ from sklearn .utils .metaestimators import available_if
3537from torch import nn
3638
3739import cebra .data
4143import cebra .models
4244import cebra .solver
4345
46+ def check_version (estimator ):
47+ # NOTE(stes): required as a check for the old way of specifying tags
48+ # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
49+ from packaging import version
50+ return version .parse (sklearn .__version__ ) < version .parse ("1.6.dev" )
4451
4552def _init_loader (
4653 is_cont : bool ,
@@ -1294,6 +1301,15 @@ def fit_transform(
12941301 callback_frequency = callback_frequency )
12951302 return self .transform (X )
12961303
1304+ def __sklearn_tags__ (self ):
1305+ # NOTE(stes): from 1.6.dev, this is the new way to specify tags
1306+ # https://scikit-learn.org/dev/developers/develop.html
1307+ # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
1308+ tags = super ().__sklearn_tags__ ()
1309+ tags .non_deterministic = True
1310+ return tags
1311+
1312+ @available_if (check_version )
12971313 def _more_tags (self ):
12981314 # NOTE(stes): This tag is needed as seeding is not fully implemented in the
12991315 # current version of CEBRA.
You can’t perform that action at this time.
0 commit comments