Skip to content

Commit 70b00a6

Browse files
committed
improve sklearn 1.8 compatibility
1 parent edc1ef0 commit 70b00a6

File tree

1 file changed

+30
-3
lines changed

1 file changed

+30
-3
lines changed

pytabkit/models/sklearn/sklearn_base.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pathlib import Path
33
from typing import Dict, Any, Optional, Union, List
44
from warnings import warn
5+
from packaging.version import Version
56

67
import numpy as np
78
import pandas as pd
@@ -49,6 +50,32 @@ def concat_arrays(x1, x2) -> Any:
4950
return np.concatenate([x1, x2], axis=0)
5051

5152

53+
def check_X_y_wrapper(*args, **kwargs):
54+
if Version(sklearn.__version__) >= Version("1.8.0"):
55+
if 'force_all_finite' in kwargs:
56+
kwargs['ensure_all_finite'] = kwargs['force_all_finite']
57+
kwargs['force_all_finite'] = None
58+
else:
59+
if 'ensure_all_finite' in kwargs:
60+
kwargs['force_all_finite'] = kwargs['ensure_all_finite']
61+
kwargs['ensure_all_finite'] = None
62+
63+
check_X_y(*args, **kwargs)
64+
65+
66+
def check_array_wrapper(*args, **kwargs):
67+
if Version(sklearn.__version__) >= Version("1.8.0"):
68+
if 'force_all_finite' in kwargs:
69+
kwargs['ensure_all_finite'] = kwargs['force_all_finite']
70+
kwargs['force_all_finite'] = None
71+
else:
72+
if 'ensure_all_finite' in kwargs:
73+
kwargs['force_all_finite'] = kwargs['ensure_all_finite']
74+
kwargs['ensure_all_finite'] = None
75+
76+
check_array(*args, **kwargs)
77+
78+
5279
class AlgInterfaceEstimator(BaseEstimator):
5380
"""
5481
Base class for wrapping AlgInterface subclasses with a scikit-learn compatible interface.
@@ -138,7 +165,7 @@ def fit(self, X, y, X_val: Optional = None, y_val: Optional = None, val_idxs: Op
138165
"""
139166

140167
# do a first check, this includes to check if X or y are not None before other things are done to them
141-
check_X_y(X, y, force_all_finite='allow-nan', multi_output=True, dtype=None)
168+
check_X_y_wrapper(X, y, force_all_finite='allow-nan', multi_output=True, dtype=None)
142169

143170
# if X is None:
144171
# raise ValueError(f'This estimator requires X to be passed, but X is None')
@@ -184,7 +211,7 @@ def fit(self, X, y, X_val: Optional = None, y_val: Optional = None, val_idxs: Op
184211
y = concat_arrays(y, y_val)
185212

186213
# check again with the validation set concatenated
187-
check_X_y(X, y, force_all_finite='allow-nan', multi_output=True, dtype=None)
214+
check_X_y_wrapper(X, y, force_all_finite='allow-nan', multi_output=True, dtype=None)
188215

189216
if self._is_classification():
190217
# classes_ is overridden later, but this raises an error when y is a regression target, so it is useful
@@ -446,7 +473,7 @@ def _predict_raw(self, X) -> torch.Tensor:
446473

447474
# Input validation
448475
# if isinstance(X, np.ndarray):
449-
check_array(X, force_all_finite='allow-nan', dtype=None)
476+
check_array_wrapper(X, force_all_finite='allow-nan', dtype=None)
450477

451478
x_ds = self.x_converter_.transform(to_df(X))
452479
if torch.any(torch.isnan(x_ds.tensors['x_cont'])):

0 commit comments

Comments
 (0)