2727
2828import cebra .helper
2929
30+ from packaging import version
31+ from sklearn import __version__ as sklearn_version
32+
3033
3134def update_old_param (old : dict , new : dict , kwargs : dict , default ) -> tuple :
3235 """Handle deprecated arguments of a function until they are replaced.
@@ -74,19 +77,35 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray:
7477 Returns:
7578 The converted and validated array.
7679 """
77- return sklearn_utils_validation .check_array (
78- X ,
79- accept_sparse = False ,
80- accept_large_sparse = False ,
81- dtype = ("float16" , "float32" , "float64" ),
82- order = None ,
83- copy = False ,
84- ensure_all_finite = True ,
85- ensure_2d = True ,
86- allow_nd = False ,
87- ensure_min_samples = min_samples ,
88- ensure_min_features = 1 ,
89- )
80+
81+ if sklearn_version < version .parse ("1.8" ):
82+ return sklearn_utils_validation .check_array (
83+ X ,
84+ accept_sparse = False ,
85+ accept_large_sparse = False ,
86+ dtype = ("float16" , "float32" , "float64" ),
87+ order = None ,
88+ copy = False ,
89+ force_all_finite = True ,
90+ ensure_2d = True ,
91+ allow_nd = False ,
92+ ensure_min_samples = min_samples ,
93+ ensure_min_features = 1 ,
94+ )
95+ else :
96+ return sklearn_utils_validation .check_array (
97+ X ,
98+ accept_sparse = False ,
99+ accept_large_sparse = False ,
100+ dtype = ("float16" , "float32" , "float64" ),
101+ order = None ,
102+ copy = False ,
103+ ensure_all_finite = True ,
104+ ensure_2d = True ,
105+ allow_nd = False ,
106+ ensure_min_samples = min_samples ,
107+ ensure_min_features = 1 ,
108+ )
90109
91110
92111def check_label_array (y : npt .NDArray , * , min_samples : int ):
@@ -105,18 +124,32 @@ def check_label_array(y: npt.NDArray, *, min_samples: int):
105124 Returns:
106125 The converted and validated labels.
107126 """
108- return sklearn_utils_validation .check_array (
109- y ,
110- accept_sparse = False ,
111- accept_large_sparse = False ,
112- dtype = "numeric" ,
113- order = None ,
114- copy = False ,
115- ensure_all_finite = True ,
116- ensure_2d = False ,
117- allow_nd = False ,
118- ensure_min_samples = min_samples ,
119- )
127+ if sklearn_version < version .parse ("1.8" ):
128+ return sklearn_utils_validation .check_array (
129+ y ,
130+ accept_sparse = False ,
131+ accept_large_sparse = False ,
132+ dtype = "numeric" ,
133+ order = None ,
134+ copy = False ,
135+ force_all_finite = True ,
136+ ensure_2d = False ,
137+ allow_nd = False ,
138+ ensure_min_samples = min_samples ,
139+ )
140+ else :
141+ return sklearn_utils_validation .check_array (
142+ y ,
143+ accept_sparse = False ,
144+ accept_large_sparse = False ,
145+ dtype = "numeric" ,
146+ order = None ,
147+ copy = False ,
148+ ensure_all_finite = True ,
149+ ensure_2d = False ,
150+ allow_nd = False ,
151+ ensure_min_samples = min_samples ,
152+ )
120153
121154
122155def check_device (device : str ) -> str :
0 commit comments