Skip to content

Commit a79c2de

Browse files
authored
Fix deprecation warning force_all_finite -> ensure_all_finite for sklearn>=1.6 (#206)
1 parent 7a4d3fc commit a79c2de

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

cebra/integrations/sklearn/utils.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,26 @@
2222
import warnings
2323

2424
import numpy.typing as npt
25+
import packaging
26+
import sklearn
2527
import sklearn.utils.validation as sklearn_utils_validation
2628
import torch
2729

2830
import cebra.helper
2931

3032

33+
def _sklearn_check_array(array, **kwargs):
34+
# NOTE(stes): See discussion in https://github.com/AdaptiveMotorControlLab/CEBRA/pull/206
35+
# https://scikit-learn.org/1.6/modules/generated/sklearn.utils.check_array.html
36+
# force_all_finite was renamed to ensure_all_finite and will be removed in 1.8.
37+
if packaging.version.parse(
38+
sklearn.__version__) < packaging.version.parse("1.6"):
39+
if "ensure_all_finite" in kwargs:
40+
kwargs["force_all_finite"] = kwargs["ensure_all_finite"]
41+
del kwargs["ensure_all_finite"]
42+
return sklearn_utils_validation.check_array(array, **kwargs)
43+
44+
3145
def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple:
3246
"""Handle deprecated arguments of a function until they are replaced.
3347
@@ -74,15 +88,15 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray:
7488
Returns:
7589
The converted and validated array.
7690
"""
77-
return sklearn_utils_validation.check_array(
91+
return _sklearn_check_array(
7892
X,
7993
accept_sparse=False,
8094
accept_large_sparse=False,
8195
dtype=("float16", "float32", "float64"),
8296
order=None,
8397
copy=False,
84-
force_all_finite=True,
8598
ensure_2d=True,
99+
ensure_all_finite=True,
86100
allow_nd=False,
87101
ensure_min_samples=min_samples,
88102
ensure_min_features=1,
@@ -105,15 +119,15 @@ def check_label_array(y: npt.NDArray, *, min_samples: int):
105119
Returns:
106120
The converted and validated labels.
107121
"""
108-
return sklearn_utils_validation.check_array(
122+
return _sklearn_check_array(
109123
y,
110124
accept_sparse=False,
111125
accept_large_sparse=False,
112126
dtype="numeric",
113127
order=None,
114128
copy=False,
115-
force_all_finite=True,
116129
ensure_2d=False,
130+
ensure_all_finite=True,
117131
allow_nd=False,
118132
ensure_min_samples=min_samples,
119133
)

0 commit comments

Comments
 (0)