|
30 | 30 | import pkg_resources |
31 | 31 | import sklearn.utils.validation as sklearn_utils_validation |
32 | 32 | import torch |
| 33 | +import sklearn |
33 | 34 | from sklearn.base import BaseEstimator |
34 | 35 | from sklearn.base import TransformerMixin |
| 36 | +from sklearn.utils.metaestimators import available_if |
35 | 37 | from torch import nn |
36 | 38 |
|
37 | 39 | import cebra.data |
|
41 | 43 | import cebra.models |
42 | 44 | import cebra.solver |
43 | 45 |
|
| 46 | +def check_version(estimator): |
| 47 | + # NOTE(stes): required as a check for the old way of specifying tags |
| 48 | + # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165 |
| 49 | + from packaging import version |
| 50 | + return version.parse(sklearn.__version__) < version.parse("1.6.dev") |
44 | 51 |
|
45 | 52 | def _init_loader( |
46 | 53 | is_cont: bool, |
@@ -364,7 +371,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": |
364 | 371 | return cebra_ |
365 | 372 |
|
366 | 373 |
|
367 | | -class CEBRA(BaseEstimator, TransformerMixin): |
| 374 | +class CEBRA(TransformerMixin, BaseEstimator): |
368 | 375 | """CEBRA model defined as part of a ``scikit-learn``-like API. |
369 | 376 |
|
370 | 377 | Attributes: |
@@ -1294,6 +1301,15 @@ def fit_transform( |
1294 | 1301 | callback_frequency=callback_frequency) |
1295 | 1302 | return self.transform(X) |
1296 | 1303 |
|
| 1304 | + def __sklearn_tags__(self): |
| 1305 | + # NOTE(stes): from 1.6.dev, this is the new way to specify tags |
| 1306 | + # https://scikit-learn.org/dev/developers/develop.html |
| 1307 | + # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165 |
| 1308 | + tags = super().__sklearn_tags__() |
| 1309 | + tags.non_deterministic = True |
| 1310 | + return tags |
| 1311 | + |
| 1312 | + @available_if(check_version) |
1297 | 1313 | def _more_tags(self): |
1298 | 1314 | # NOTE(stes): This tag is needed as seeding is not fully implemented in the |
1299 | 1315 | # current version of CEBRA. |
|
0 commit comments