Skip to content

Commit 2ca1163

Browse files
authored
Added version compatibility for sklearn 1.8+
1 parent 74fc232 commit 2ca1163

File tree

1 file changed

+58
-25
lines changed

1 file changed

+58
-25
lines changed

cebra/integrations/sklearn/utils.py

Lines changed: 58 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828
import cebra.helper
2929

30+
from packaging import version
31+
from sklearn import __version__ as sklearn_version
32+
3033

3134
def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple:
3235
"""Handle deprecated arguments of a function until they are replaced.
@@ -74,19 +77,35 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray:
7477
Returns:
7578
The converted and validated array.
7679
"""
77-
return sklearn_utils_validation.check_array(
78-
X,
79-
accept_sparse=False,
80-
accept_large_sparse=False,
81-
dtype=("float16", "float32", "float64"),
82-
order=None,
83-
copy=False,
84-
ensure_all_finite=True,
85-
ensure_2d=True,
86-
allow_nd=False,
87-
ensure_min_samples=min_samples,
88-
ensure_min_features=1,
89-
)
80+
81+
if sklearn_version < version.parse("1.8"):
82+
return sklearn_utils_validation.check_array(
83+
X,
84+
accept_sparse=False,
85+
accept_large_sparse=False,
86+
dtype=("float16", "float32", "float64"),
87+
order=None,
88+
copy=False,
89+
force_all_finite=True,
90+
ensure_2d=True,
91+
allow_nd=False,
92+
ensure_min_samples=min_samples,
93+
ensure_min_features=1,
94+
)
95+
else:
96+
return sklearn_utils_validation.check_array(
97+
X,
98+
accept_sparse=False,
99+
accept_large_sparse=False,
100+
dtype=("float16", "float32", "float64"),
101+
order=None,
102+
copy=False,
103+
ensure_all_finite=True,
104+
ensure_2d=True,
105+
allow_nd=False,
106+
ensure_min_samples=min_samples,
107+
ensure_min_features=1,
108+
)
90109

91110

92111
def check_label_array(y: npt.NDArray, *, min_samples: int):
@@ -105,18 +124,32 @@ def check_label_array(y: npt.NDArray, *, min_samples: int):
105124
Returns:
106125
The converted and validated labels.
107126
"""
108-
return sklearn_utils_validation.check_array(
109-
y,
110-
accept_sparse=False,
111-
accept_large_sparse=False,
112-
dtype="numeric",
113-
order=None,
114-
copy=False,
115-
ensure_all_finite=True,
116-
ensure_2d=False,
117-
allow_nd=False,
118-
ensure_min_samples=min_samples,
119-
)
127+
if sklearn_version < version.parse("1.8"):
128+
return sklearn_utils_validation.check_array(
129+
y,
130+
accept_sparse=False,
131+
accept_large_sparse=False,
132+
dtype="numeric",
133+
order=None,
134+
copy=False,
135+
force_all_finite=True,
136+
ensure_2d=False,
137+
allow_nd=False,
138+
ensure_min_samples=min_samples,
139+
)
140+
else:
141+
return sklearn_utils_validation.check_array(
142+
y,
143+
accept_sparse=False,
144+
accept_large_sparse=False,
145+
dtype="numeric",
146+
order=None,
147+
copy=False,
148+
ensure_all_finite=True,
149+
ensure_2d=False,
150+
allow_nd=False,
151+
ensure_min_samples=min_samples,
152+
)
120153

121154

122155
def check_device(device: str) -> str:

0 commit comments

Comments
 (0)