Skip to content

feat: Add a predict_proba method on SKLearnClassifier #21556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions keras/src/wrappers/sklearn_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,14 @@ def dynamic_model(X, y, loss, layers=[10]):
```
"""

def predict_proba(self, X):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem here is that since the model is configurable, we have no way to know whether the model outputs probabilities or not. This method serves no additional purpose over just predict().

Copy link
Author

@divakaivan divakaivan Aug 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fchollet Could you elaborate, please? I'm not sure I understand your comment. In the case when the user expects probas they will get probas. The only difference between this predict_proba and predict is that the target is not transformed back.

If the user expects probabilities, then they will get them. Although predict_proba might not always return proper probabilities, its inclusion allows users to interoperate with sklearn workflows that expect it. Some examples are in the original issue request.

"""Predict class probabilities of the input samples X."""
from sklearn.utils.validation import check_is_fitted

check_is_fitted(self)
X = _validate_data(self, X, reset=False)
return self.model_.predict(X)

def _process_target(self, y, reset=False):
"""Classifiers do OHE."""
target_type = type_of_target(y, raise_unknown=True)
Expand Down