Skip to content

Commit 3697d1a

Browse files
committed
Source formatting
1 parent 5770dfa commit 3697d1a

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

cebra/integrations/sklearn/utils.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,25 @@
2222
import warnings
2323

2424
import numpy.typing as npt
25+
import packaging
26+
import sklearn
2527
import sklearn.utils.validation as sklearn_utils_validation
2628
import torch
2729

2830
import cebra.helper
2931

30-
from packaging import version
31-
import sklearn
3232

3333
def _check_array_ensure_all_finite(array, **kwargs):
34-
if version.parse(sklearn.__version__) < version.parse("1.8"):
34+
# NOTE(stes): See discussion in https://github.com/AdaptiveMotorControlLab/CEBRA/pull/206
35+
if packaging.version.parse(
36+
sklearn.__version__) < packaging.version.parse("1.8"):
3537
key = "force_all_finite"
3638
else:
3739
key = "ensure_all_finite"
3840
kwargs[key] = True
3941
return sklearn_utils_validation.check_array(array, **kwargs)
4042

43+
4144
def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple:
4245
"""Handle deprecated arguments of a function until they are replaced.
4346
@@ -85,17 +88,17 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray:
8588
The converted and validated array.
8689
"""
8790
return _check_array_ensure_all_finite(
88-
X,
89-
accept_sparse=False,
90-
accept_large_sparse=False,
91-
dtype=("float16", "float32", "float64"),
92-
order=None,
93-
copy=False,
94-
ensure_2d=True,
95-
allow_nd=False,
96-
ensure_min_samples=min_samples,
97-
ensure_min_features=1,
98-
)
91+
X,
92+
accept_sparse=False,
93+
accept_large_sparse=False,
94+
dtype=("float16", "float32", "float64"),
95+
order=None,
96+
copy=False,
97+
ensure_2d=True,
98+
allow_nd=False,
99+
ensure_min_samples=min_samples,
100+
ensure_min_features=1,
101+
)
99102

100103

101104
def check_label_array(y: npt.NDArray, *, min_samples: int):

0 commit comments

Comments
 (0)