Skip to content

Commit 32a642d

Browse files
authored
FEAT add scikit-learn wrappers (#20599)
* FEAT add scikit-learn wrappers
* import cleanup
* run black
* linters
* lint
* add scikit-learn to requirements-common
* generate public api
* fix tests for sklearn 1.5
* check fixes
* skip numpy tests
* xfail instead of skip
* apply review comments
* change names to SKL* and add transformer example
* fix API and imports
* fix for new sklearn
* sklearn1.6 test
* review comments and remove random_state
* add another skipped test
* rename file
* change imports
* unindent
* docstrings
1 parent 8465c3d commit 32a642d

File tree

11 files changed

+806
-0
lines changed

11 files changed

+806
-0
lines changed

keras/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from keras.api import utils
5050
from keras.api import version
5151
from keras.api import visualization
52+
from keras.api import wrappers
5253

5354
# END DO NOT EDIT.
5455

keras/api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from keras.api import tree
3333
from keras.api import utils
3434
from keras.api import visualization
35+
from keras.api import wrappers
3536
from keras.src.backend import Variable
3637
from keras.src.backend import device
3738
from keras.src.backend import name_scope

keras/api/_tf_keras/keras/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from keras.api import tree
2626
from keras.api import utils
2727
from keras.api import visualization
28+
from keras.api import wrappers
2829
from keras.api._tf_keras.keras import backend
2930
from keras.api._tf_keras.keras import layers
3031
from keras.api._tf_keras.keras import losses
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""DO NOT EDIT.
2+
3+
This file was autogenerated. Do not edit it by hand,
4+
since your modifications would be overwritten.
5+
"""
6+
7+
from keras.src.wrappers.sklearn_wrapper import SKLearnClassifier
8+
from keras.src.wrappers.sklearn_wrapper import SKLearnRegressor
9+
from keras.src.wrappers.sklearn_wrapper import SKLearnTransformer

keras/api/wrappers/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""DO NOT EDIT.
2+
3+
This file was autogenerated. Do not edit it by hand,
4+
since your modifications would be overwritten.
5+
"""
6+
7+
from keras.src.wrappers.sklearn_wrapper import SKLearnClassifier
8+
from keras.src.wrappers.sklearn_wrapper import SKLearnRegressor
9+
from keras.src.wrappers.sklearn_wrapper import SKLearnTransformer

keras/src/wrappers/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from keras.src.wrappers.sklearn_wrapper import SKLearnClassifier
2+
from keras.src.wrappers.sklearn_wrapper import SKLearnRegressor
3+
from keras.src.wrappers.sklearn_wrapper import SKLearnTransformer
4+
5+
__all__ = ["SKLearnClassifier", "SKLearnRegressor", "SKLearnTransformer"]

keras/src/wrappers/fixes.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import sklearn
2+
from packaging.version import parse as parse_version
3+
from sklearn import get_config
4+
5+
# Reduce the installed version to its base release before comparing it against
# the 1.6 threshold below.
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

if sklearn_version < parse_version("1.6"):

    def patched_more_tags(estimator, expected_failed_checks):
        """Expose expected failures through the ``_xfail_checks`` tag.

        Args:
            estimator: Estimator instance whose class receives a replacement
                ``_more_tags`` method.
            expected_failed_checks: Value stored under the
                ``"_xfail_checks"`` key of the returned tags.

        Returns:
            The same ``estimator`` object that was passed in.
        """
        import copy

        from sklearn.utils._tags import _safe_tags

        # Snapshot the tags before `_more_tags` is replaced below.
        original_tags = copy.deepcopy(_safe_tags(estimator))

        def patched_more_tags(self):
            original_tags.update({"_xfail_checks": expected_failed_checks})
            return original_tags

        # NOTE: assigned on the class, so every instance of that class sees
        # the patched tags, not only `estimator`.
        estimator.__class__._more_tags = patched_more_tags
        return estimator

    def parametrize_with_checks(
        estimators,
        *,
        legacy: bool = True,
        expected_failed_checks=None,
    ):
        """Shim accepting the ``expected_failed_checks`` keyword.

        Args:
            estimators: Iterable of estimator instances to check.
            legacy: Accepted for signature compatibility; not supported and
                ignored.
            expected_failed_checks: Optional callable that takes an estimator
                and returns the value to store under its ``_xfail_checks``
                tag. When `None`, the estimators are passed through
                unpatched.

        Returns:
            Whatever scikit-learn's own ``parametrize_with_checks`` returns
            for the (possibly patched) estimators.
        """
        # legacy is not supported and ignored
        from sklearn.utils.estimator_checks import parametrize_with_checks  # noqa: F401, I001

        # The default `None` is not callable; only patch when a callable was
        # actually supplied.
        if expected_failed_checks is not None:
            estimators = [
                patched_more_tags(estimator, expected_failed_checks(estimator))
                for estimator in estimators
            ]

        return parametrize_with_checks(estimators)
else:
    from sklearn.utils.estimator_checks import parametrize_with_checks  # noqa: F401, I001
40+
41+
42+
def _validate_data(estimator, *args, **kwargs):
43+
"""Validate the input data.
44+
45+
wrapper for sklearn.utils.validation.validate_data or
46+
BaseEstimator._validate_data depending on the scikit-learn version.
47+
48+
TODO: remove when minimum scikit-learn version is 1.6
49+
"""
50+
try:
51+
# scikit-learn >= 1.6
52+
from sklearn.utils.validation import validate_data
53+
54+
return validate_data(estimator, *args, **kwargs)
55+
except ImportError:
56+
return estimator._validate_data(*args, **kwargs)
57+
except:
58+
raise
59+
60+
61+
def type_of_target(y, input_name="", *, raise_unknown=False):
    """Return the target type of `y`, optionally rejecting `"unknown"`.

    Shim for the `raise_unknown` option, which is introduced in
    scikit-learn 1.6.

    Args:
        y: The target data to inspect.
        input_name: Name of the input, forwarded to scikit-learn and used in
            the error message. An empty name is reported as "data".
        raise_unknown: If true, raise instead of returning `"unknown"`.

    Returns:
        The target type reported by `sklearn.utils.multiclass.type_of_target`.

    Raises:
        ValueError: If `raise_unknown` is true and the type is `"unknown"`.
    """
    from sklearn.utils import multiclass

    detected = multiclass.type_of_target(y, input_name=input_name)
    if not (raise_unknown and detected == "unknown"):
        return detected

    raise ValueError(f"Unknown label type for {input_name or 'data'}: {y!r}")
77+
78+
79+
def _routing_enabled():
    """Tell whether scikit-learn metadata routing is switched on.

    Returns:
        The value of the `"enable_metadata_routing"` key in scikit-learn's
        global config, or `False` when that key is absent.

    TODO: remove when the config key is no longer available in scikit-learn
    """
    config = get_config()
    return config.get("enable_metadata_routing", False)
90+
91+
92+
def _raise_for_params(params, owner, method):
    """Reject extra metadata when metadata routing is disabled.

    Args:
        params: dict. The metadata passed to a method.
        owner: object. The object to which the method belongs.
        method: str. The name of the method, e.g. "fit". When falsy, only
            the owner's class name is used in the error message.

    Raises:
        ValueError: If metadata routing is not enabled and params are passed.
    """
    owner_name = owner.__class__.__name__
    if method:
        caller = f"{owner_name}.{method}"
    else:
        caller = owner_name

    # Nothing to complain about when routing is on or no params were given.
    if _routing_enabled() or not params:
        return

    raise ValueError(
        f"Passing extra keyword arguments to {caller} is only supported if"
        " enable_metadata_routing=True, which you can set using"
        " `sklearn.set_config`. See the User Guide"
        " <https://scikit-learn.org/stable/metadata_routing.html> for more"
        f" details. Extra parameters passed are: {set(params)}"
    )

keras/src/wrappers/sklearn_test.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""Tests using Scikit-Learn's bundled estimator_checks."""
2+
3+
from contextlib import contextmanager
4+
5+
import pytest
6+
7+
import keras
8+
from keras.src.backend import floatx
9+
from keras.src.backend import set_floatx
10+
from keras.src.layers import Dense
11+
from keras.src.layers import Input
12+
from keras.src.models import Model
13+
from keras.src.wrappers import SKLearnClassifier
14+
from keras.src.wrappers import SKLearnRegressor
15+
from keras.src.wrappers import SKLearnTransformer
16+
from keras.src.wrappers.fixes import parametrize_with_checks
17+
18+
19+
def dynamic_model(X, y, loss, layers=None):
    """Build and compile a basic MLP sized from the shapes of `X` and `y`.

    Args:
        X: Input data; `X.shape[1]` sets the number of input features.
        y: Target data; `y.shape[1]` sets the number of output units when
            `y` has more than one dimension, otherwise one unit is used.
        loss: Loss passed to `model.compile`.
        layers: Iterable of hidden layer sizes. Defaults to a single hidden
            layer with 10 units.

    Returns:
        A compiled `Model` with ReLU hidden layers, a softmax output layer
        and the "rmsprop" optimizer.
    """
    # Avoid a mutable default argument shared across calls.
    if layers is None:
        layers = [10]

    n_features_in = X.shape[1]
    inp = Input(shape=(n_features_in,))

    hidden = inp
    for layer_size in layers:
        hidden = Dense(layer_size, activation="relu")(hidden)

    n_outputs = y.shape[1] if len(y.shape) > 1 else 1
    out = [Dense(n_outputs, activation="softmax")(hidden)]
    model = Model(inp, out)
    model.compile(loss=loss, optimizer="rmsprop")

    return model
36+
37+
38+
@contextmanager
def use_floatx(x: str):
    """Temporarily switch the keras backend float precision to `x`.

    The precision that was active on entry is restored on exit, including
    when the managed block raises.
    """
    previous = floatx()
    set_floatx(x)
    try:
        yield
    finally:
        # Always put the original precision back.
        set_floatx(previous)
49+
50+
51+
# Reasons shared by several entries below.
_NOT_AN_ISSUE_IN_1_6 = "not an issue in sklearn>=1.6"
_NEEDS_REPRODUCIBLE_FIT = "This test assumes reproducibility in fit."

# Wrapper class name -> {scikit-learn check name: reason it is expected to fail}.
EXPECTED_FAILED_CHECKS = {
    "SKLearnClassifier": {
        "check_classifiers_regression_target": _NOT_AN_ISSUE_IN_1_6,
        "check_parameters_default_constructible": _NOT_AN_ISSUE_IN_1_6,
        "check_classifiers_one_label_sample_weights": (
            "0 sample weight is not ignored"
        ),
        "check_classifiers_classes": (
            "with small test cases the estimator returns not all classes "
            "sometimes"
        ),
        "check_classifier_data_not_an_array": _NEEDS_REPRODUCIBLE_FIT,
        "check_supervised_y_2d": _NEEDS_REPRODUCIBLE_FIT,
        "check_fit_idempotent": _NEEDS_REPRODUCIBLE_FIT,
    },
    "SKLearnRegressor": {
        "check_parameters_default_constructible": _NOT_AN_ISSUE_IN_1_6,
    },
    "SKLearnTransformer": {
        "check_parameters_default_constructible": _NOT_AN_ISSUE_IN_1_6,
    },
}
81+
82+
83+
@parametrize_with_checks(
    estimators=[
        SKLearnClassifier(
            model=dynamic_model,
            model_kwargs={
                "loss": "categorical_crossentropy",
                "layers": [20, 20, 20],
            },
            fit_kwargs={"epochs": 5},
        ),
        SKLearnRegressor(
            model=dynamic_model,
            model_kwargs={"loss": "mse"},
        ),
        SKLearnTransformer(
            model=dynamic_model,
            model_kwargs={"loss": "mse"},
        ),
    ],
    # Look up the expected failures by the wrapper's class name.
    expected_failed_checks=lambda estimator: EXPECTED_FAILED_CHECKS[
        type(estimator).__name__
    ],
)
def test_sklearn_estimator_checks(estimator, check):
    """Run one of scikit-learn's bundled estimator checks on a wrapper.

    On the numpy backend, a failure that is (or mentions) a
    `NotImplementedError` is reported as an expected failure; every other
    exception is re-raised.
    """
    try:
        check(estimator)
    except Exception as exc:
        # Only the numpy backend gets the xfail treatment.
        if keras.config.backend() != "numpy":
            raise
        not_implemented = isinstance(
            exc, NotImplementedError
        ) or "NotImplementedError" in str(exc)
        if not_implemented:
            pytest.xfail("Backend not implemented")
        raise

0 commit comments

Comments
 (0)