|
2 | 2 | from pathlib import Path |
3 | 3 | from typing import Dict, Any, Optional, Union, List |
4 | 4 | from warnings import warn |
| 5 | +from packaging.version import Version |
5 | 6 |
|
6 | 7 | import numpy as np |
7 | 8 | import pandas as pd |
@@ -49,6 +50,32 @@ def concat_arrays(x1, x2) -> Any: |
49 | 50 | return np.concatenate([x1, x2], axis=0) |
50 | 51 |
|
51 | 52 |
|
def check_X_y_wrapper(*args, **kwargs):
    """Version-portable wrapper around sklearn's ``check_X_y``.

    scikit-learn renamed the ``force_all_finite`` parameter to
    ``ensure_all_finite`` (deprecated in 1.6, removed in 1.8). This wrapper
    translates whichever spelling the caller used into the one the installed
    sklearn version accepts, then delegates to ``check_X_y``.

    Returns whatever ``check_X_y`` returns (the validated ``X, y`` pair), so
    this can be used as a drop-in replacement.
    """
    if Version(sklearn.__version__) >= Version("1.8.0"):
        # sklearn >= 1.8 removed force_all_finite entirely; the key must be
        # *popped*, not set to None — passing it at all raises a TypeError.
        if 'force_all_finite' in kwargs:
            kwargs['ensure_all_finite'] = kwargs.pop('force_all_finite')
    else:
        # Older sklearn only understands force_all_finite; likewise the new
        # name must be removed, since pre-1.6 versions reject it outright.
        if 'ensure_all_finite' in kwargs:
            kwargs['force_all_finite'] = kwargs.pop('ensure_all_finite')

    return check_X_y(*args, **kwargs)
| 64 | + |
| 65 | + |
def check_array_wrapper(*args, **kwargs):
    """Version-portable wrapper around sklearn's ``check_array``.

    scikit-learn renamed the ``force_all_finite`` parameter to
    ``ensure_all_finite`` (deprecated in 1.6, removed in 1.8). This wrapper
    translates whichever spelling the caller used into the one the installed
    sklearn version accepts, then delegates to ``check_array``.

    Returns whatever ``check_array`` returns (the validated array), so this
    can be used as a drop-in replacement.
    """
    if Version(sklearn.__version__) >= Version("1.8.0"):
        # sklearn >= 1.8 removed force_all_finite entirely; the key must be
        # *popped*, not set to None — passing it at all raises a TypeError.
        if 'force_all_finite' in kwargs:
            kwargs['ensure_all_finite'] = kwargs.pop('force_all_finite')
    else:
        # Older sklearn only understands force_all_finite; likewise the new
        # name must be removed, since pre-1.6 versions reject it outright.
        if 'ensure_all_finite' in kwargs:
            kwargs['force_all_finite'] = kwargs.pop('ensure_all_finite')

    return check_array(*args, **kwargs)
| 77 | + |
| 78 | + |
52 | 79 | class AlgInterfaceEstimator(BaseEstimator): |
53 | 80 | """ |
54 | 81 | Base class for wrapping AlgInterface subclasses with a scikit-learn compatible interface. |
@@ -138,7 +165,7 @@ def fit(self, X, y, X_val: Optional = None, y_val: Optional = None, val_idxs: Op |
138 | 165 | """ |
139 | 166 |
|
140 | 167 | # do a first check, this includes to check if X or y are not None before other things are done to them |
141 | | - check_X_y(X, y, force_all_finite='allow-nan', multi_output=True, dtype=None) |
| 168 | + check_X_y_wrapper(X, y, force_all_finite='allow-nan', multi_output=True, dtype=None) |
142 | 169 |
|
143 | 170 | # if X is None: |
144 | 171 | # raise ValueError(f'This estimator requires X to be passed, but X is None') |
@@ -184,7 +211,7 @@ def fit(self, X, y, X_val: Optional = None, y_val: Optional = None, val_idxs: Op |
184 | 211 | y = concat_arrays(y, y_val) |
185 | 212 |
|
186 | 213 | # check again with the validation set concatenated |
187 | | - check_X_y(X, y, force_all_finite='allow-nan', multi_output=True, dtype=None) |
| 214 | + check_X_y_wrapper(X, y, force_all_finite='allow-nan', multi_output=True, dtype=None) |
188 | 215 |
|
189 | 216 | if self._is_classification(): |
190 | 217 | # classes_ is overridden later, but this raises an error when y is a regression target, so it is useful |
@@ -446,7 +473,7 @@ def _predict_raw(self, X) -> torch.Tensor: |
446 | 473 |
|
447 | 474 | # Input validation |
448 | 475 | # if isinstance(X, np.ndarray): |
449 | | - check_array(X, force_all_finite='allow-nan', dtype=None) |
| 476 | + check_array_wrapper(X, force_all_finite='allow-nan', dtype=None) |
450 | 477 |
|
451 | 478 | x_ds = self.x_converter_.transform(to_df(X)) |
452 | 479 | if torch.any(torch.isnan(x_ds.tensors['x_cont'])): |
|
0 commit comments