From 01571757387f5bb41fd89190a38a6bf3217c43f9 Mon Sep 17 00:00:00 2001 From: divakaivan Date: Thu, 7 Aug 2025 17:36:11 +0100 Subject: [PATCH] feat: Add predict_proba on SKLearnClassifier --- keras/src/wrappers/sklearn_wrapper.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/keras/src/wrappers/sklearn_wrapper.py b/keras/src/wrappers/sklearn_wrapper.py index 90d36c669792..8103f6a46c26 100644 --- a/keras/src/wrappers/sklearn_wrapper.py +++ b/keras/src/wrappers/sklearn_wrapper.py @@ -278,6 +278,14 @@ def dynamic_model(X, y, loss, layers=[10]): ``` """ + def predict_proba(self, X): + """Predict class probabilities of the input samples X.""" + from sklearn.utils.validation import check_is_fitted + + check_is_fitted(self) + X = _validate_data(self, X, reset=False) + return self.model_.predict(X) + def _process_target(self, y, reset=False): """Classifiers do OHE.""" target_type = type_of_target(y, raise_unknown=True)