|
22 | 22 | import warnings |
23 | 23 |
|
24 | 24 | import numpy.typing as npt |
| 25 | +import packaging |
| 26 | +import sklearn |
25 | 27 | import sklearn.utils.validation as sklearn_utils_validation |
26 | 28 | import torch |
27 | 29 |
|
28 | 30 | import cebra.helper |
29 | 31 |
|
30 | | -from packaging import version |
31 | | -import sklearn |
32 | 32 |
|
33 | 33 | def _check_array_ensure_all_finite(array, **kwargs): |
34 | | - if version.parse(sklearn.__version__) < version.parse("1.8"): |
| 34 | + # NOTE(stes): See discussion in https://github.com/AdaptiveMotorControlLab/CEBRA/pull/206 |
| 35 | + if packaging.version.parse( |
| 36 | + sklearn.__version__) < packaging.version.parse("1.8"): |
35 | 37 | key = "force_all_finite" |
36 | 38 | else: |
37 | 39 | key = "ensure_all_finite" |
38 | 40 | kwargs[key] = True |
39 | 41 | return sklearn_utils_validation.check_array(array, **kwargs) |
40 | 42 |
|
| 43 | + |
41 | 44 | def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple: |
42 | 45 | """Handle deprecated arguments of a function until they are replaced. |
43 | 46 |
|
@@ -85,17 +88,17 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray: |
85 | 88 | The converted and validated array. |
86 | 89 | """ |
87 | 90 | return _check_array_ensure_all_finite( |
88 | | - X, |
89 | | - accept_sparse=False, |
90 | | - accept_large_sparse=False, |
91 | | - dtype=("float16", "float32", "float64"), |
92 | | - order=None, |
93 | | - copy=False, |
94 | | - ensure_2d=True, |
95 | | - allow_nd=False, |
96 | | - ensure_min_samples=min_samples, |
97 | | - ensure_min_features=1, |
98 | | - ) |
| 91 | + X, |
| 92 | + accept_sparse=False, |
| 93 | + accept_large_sparse=False, |
| 94 | + dtype=("float16", "float32", "float64"), |
| 95 | + order=None, |
| 96 | + copy=False, |
| 97 | + ensure_2d=True, |
| 98 | + allow_nd=False, |
| 99 | + ensure_min_samples=min_samples, |
| 100 | + ensure_min_features=1, |
| 101 | + ) |
99 | 102 |
|
100 | 103 |
|
101 | 104 | def check_label_array(y: npt.NDArray, *, min_samples: int): |
|
0 commit comments