Skip to content

Commit fb0593d

Browse files
committed
improve backwards compatibility implementation
1 parent 3697d1a commit fb0593d

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

cebra/integrations/sklearn/utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,13 @@
3030
import cebra.helper
3131

3232

33-
def _check_array_ensure_all_finite(array, **kwargs):
33+
def _sklearn_check_array(array, **kwargs):
3434
# NOTE(stes): See discussion in https://github.com/AdaptiveMotorControlLab/CEBRA/pull/206
3535
if packaging.version.parse(
3636
sklearn.__version__) < packaging.version.parse("1.8"):
37-
key = "force_all_finite"
38-
else:
39-
key = "ensure_all_finite"
40-
kwargs[key] = True
37+
if "ensure_all_finite" in kwargs:
38+
kwargs["force_all_finite"] = kwargs["ensure_all_finite"]
39+
del kwargs["ensure_all_finite"]
4140
return sklearn_utils_validation.check_array(array, **kwargs)
4241

4342

@@ -87,14 +86,15 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray:
8786
Returns:
8887
The converted and validated array.
8988
"""
90-
return _check_array_ensure_all_finite(
89+
return _sklearn_check_array(
9190
X,
9291
accept_sparse=False,
9392
accept_large_sparse=False,
9493
dtype=("float16", "float32", "float64"),
9594
order=None,
9695
copy=False,
9796
ensure_2d=True,
97+
ensure_all_finite=True,
9898
allow_nd=False,
9999
ensure_min_samples=min_samples,
100100
ensure_min_features=1,
@@ -117,14 +117,15 @@ def check_label_array(y: npt.NDArray, *, min_samples: int):
117117
Returns:
118118
The converted and validated labels.
119119
"""
120-
return _check_array_ensure_all_finite(
120+
return _sklearn_check_array(
121121
y,
122122
accept_sparse=False,
123123
accept_large_sparse=False,
124124
dtype="numeric",
125125
order=None,
126126
copy=False,
127127
ensure_2d=False,
128+
ensure_all_finite=True,
128129
allow_nd=False,
129130
ensure_min_samples=min_samples,
130131
)

0 commit comments

Comments
 (0)