3030import cebra .helper
3131
3232
33- def _check_array_ensure_all_finite (array , ** kwargs ):
33+ def _sklearn_check_array (array , ** kwargs ):
3434 # NOTE(stes): See discussion in https://github.com/AdaptiveMotorControlLab/CEBRA/pull/206
3535 if packaging .version .parse (
3636 sklearn .__version__ ) < packaging .version .parse ("1.8" ):
37- key = "force_all_finite"
38- else :
39- key = "ensure_all_finite"
40- kwargs [key ] = True
37+ if "ensure_all_finite" in kwargs :
38+ kwargs ["force_all_finite" ] = kwargs ["ensure_all_finite" ]
39+ del kwargs ["ensure_all_finite" ]
4140 return sklearn_utils_validation .check_array (array , ** kwargs )
4241
4342
@@ -87,14 +86,15 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray:
8786 Returns:
8887 The converted and validated array.
8988 """
90- return _check_array_ensure_all_finite (
89+ return _sklearn_check_array (
9190 X ,
9291 accept_sparse = False ,
9392 accept_large_sparse = False ,
9493 dtype = ("float16" , "float32" , "float64" ),
9594 order = None ,
9695 copy = False ,
9796 ensure_2d = True ,
97+ ensure_all_finite = True ,
9898 allow_nd = False ,
9999 ensure_min_samples = min_samples ,
100100 ensure_min_features = 1 ,
@@ -117,14 +117,15 @@ def check_label_array(y: npt.NDArray, *, min_samples: int):
117117 Returns:
118118 The converted and validated labels.
119119 """
120- return _check_array_ensure_all_finite (
120+ return _sklearn_check_array (
121121 y ,
122122 accept_sparse = False ,
123123 accept_large_sparse = False ,
124124 dtype = "numeric" ,
125125 order = None ,
126126 copy = False ,
127127 ensure_2d = False ,
128+ ensure_all_finite = True ,
128129 allow_nd = False ,
129130 ensure_min_samples = min_samples ,
130131 )
0 commit comments