Skip to content

Commit 5770dfa

Browse files
icarosaderostes
andauthored
Update cebra/integrations/sklearn/utils.py
Co-authored-by: Steffen Schneider <[email protected]>
1 parent 128257b commit 5770dfa

File tree

1 file changed

+11
-26
lines changed

1 file changed

+11
-26
lines changed

cebra/integrations/sklearn/utils.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -114,32 +114,17 @@ def check_label_array(y: npt.NDArray, *, min_samples: int):
114114
Returns:
115115
The converted and validated labels.
116116
"""
117-
if sklearn_version < version.parse("1.8"):
118-
return sklearn_utils_validation.check_array(
119-
y,
120-
accept_sparse=False,
121-
accept_large_sparse=False,
122-
dtype="numeric",
123-
order=None,
124-
copy=False,
125-
force_all_finite=True,
126-
ensure_2d=False,
127-
allow_nd=False,
128-
ensure_min_samples=min_samples,
129-
)
130-
else:
131-
return sklearn_utils_validation.check_array(
132-
y,
133-
accept_sparse=False,
134-
accept_large_sparse=False,
135-
dtype="numeric",
136-
order=None,
137-
copy=False,
138-
ensure_all_finite=True,
139-
ensure_2d=False,
140-
allow_nd=False,
141-
ensure_min_samples=min_samples,
142-
)
117+
return _check_array_ensure_all_finite(
118+
y,
119+
accept_sparse=False,
120+
accept_large_sparse=False,
121+
dtype="numeric",
122+
order=None,
123+
copy=False,
124+
ensure_2d=False,
125+
allow_nd=False,
126+
ensure_min_samples=min_samples,
127+
)
143128

144129

145130
def check_device(device: str) -> str:

0 commit comments

Comments
 (0)