diff --git a/keras/src/wrappers/sklearn_wrapper.py b/keras/src/wrappers/sklearn_wrapper.py index 90d36c669792..8103f6a46c26 100644 --- a/keras/src/wrappers/sklearn_wrapper.py +++ b/keras/src/wrappers/sklearn_wrapper.py @@ -278,6 +278,14 @@ def dynamic_model(X, y, loss, layers=[10]): ``` """ + def predict_proba(self, X): + """Predict class probabilities of the input samples X.""" + from sklearn.utils.validation import check_is_fitted + + check_is_fitted(self) + X = _validate_data(self, X, reset=False) + return self.model_.predict(X) + def _process_target(self, y, reset=False): """Classifiers do OHE.""" target_type = type_of_target(y, raise_unknown=True)