2222import warnings
2323
2424import numpy .typing as npt
25+ import packaging
26+ import sklearn
2527import sklearn .utils .validation as sklearn_utils_validation
2628import torch
2729
2830import cebra .helper
2931
3032
33+ def _sklearn_check_array (array , ** kwargs ):
34+ # NOTE(stes): See discussion in https://github.com/AdaptiveMotorControlLab/CEBRA/pull/206
35+ # https://scikit-learn.org/1.6/modules/generated/sklearn.utils.check_array.html
36+ # force_all_finite was renamed to ensure_all_finite and will be removed in 1.8.
37+ if packaging .version .parse (
38+ sklearn .__version__ ) < packaging .version .parse ("1.6" ):
39+ if "ensure_all_finite" in kwargs :
40+ kwargs ["force_all_finite" ] = kwargs ["ensure_all_finite" ]
41+ del kwargs ["ensure_all_finite" ]
42+ return sklearn_utils_validation .check_array (array , ** kwargs )
43+
44+
3145def update_old_param (old : dict , new : dict , kwargs : dict , default ) -> tuple :
3246 """Handle deprecated arguments of a function until they are replaced.
3347
@@ -74,15 +88,15 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray:
7488 Returns:
7589 The converted and validated array.
7690 """
77- return sklearn_utils_validation . check_array (
91+ return _sklearn_check_array (
7892 X ,
7993 accept_sparse = False ,
8094 accept_large_sparse = False ,
8195 dtype = ("float16" , "float32" , "float64" ),
8296 order = None ,
8397 copy = False ,
84- force_all_finite = True ,
8598 ensure_2d = True ,
99+ ensure_all_finite = True ,
86100 allow_nd = False ,
87101 ensure_min_samples = min_samples ,
88102 ensure_min_features = 1 ,
@@ -105,15 +119,15 @@ def check_label_array(y: npt.NDArray, *, min_samples: int):
105119 Returns:
106120 The converted and validated labels.
107121 """
108- return sklearn_utils_validation . check_array (
122+ return _sklearn_check_array (
109123 y ,
110124 accept_sparse = False ,
111125 accept_large_sparse = False ,
112126 dtype = "numeric" ,
113127 order = None ,
114128 copy = False ,
115- force_all_finite = True ,
116129 ensure_2d = False ,
130+ ensure_all_finite = True ,
117131 allow_nd = False ,
118132 ensure_min_samples = min_samples ,
119133 )
0 commit comments