Add TOPKAT and PROB-STD methods with tests #480

Merged
j-adamczyk merged 11 commits into master from ad_probstd_topkat
Aug 19, 2025

Conversation

Contributor

@Kacper-Kozubowski commented Aug 8, 2025

Changes

Added PROB-STD and TOPKAT applicability domain (AD) checkers, and provided tests. Part of #424

Checklist before requesting a review

  • Docstrings added/updated in public functions and classes
  • Tests added, reasonable test coverage (at least ~90%, make test-coverage)
  • Sphinx docs added/updated and render properly (make docs and see docs/_build/index.html)

distribution that lies on the wrong side of the classification threshold (0.5).

This approach requires a fitted ensemble model exposing the ``estimators_``
attribute (e.g., RandomForestRegressor or BaggingRegressor), where each
Member

Add backticks `` for model names. Also, RandomForestRegressor alone is enough as an example.

Member

We should also support classifiers with a .predict_proba() method here.
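For readers following the thread, here is a minimal sketch of the PROB-STD score as the docstring above describes it: the per-estimator outputs define a normal distribution, and the score is the mass on the wrong side of 0.5. The function name `prob_std_score`, the zero-spread guard, and the input layout are assumptions made for illustration; this is not the code that was merged.

```python
import math

import numpy as np


def prob_std_score(member_preds: np.ndarray, cutoff: float = 0.5) -> np.ndarray:
    """Hypothetical helper. member_preds: (n_estimators, n_samples), values in [0, 1]."""
    mean = member_preds.mean(axis=0)
    std = member_preds.std(axis=0)
    # Assumed guard: if all members agree, the spread is zero; clip it
    # to a tiny value so the division below stays finite.
    std = np.maximum(std, 1e-12)

    # Normal CDF evaluated at the cutoff: probability mass below the cutoff.
    z = (cutoff - mean) / (std * math.sqrt(2.0))
    erf = np.vectorize(math.erf, otypes=[float])
    mass_below = 0.5 * (1.0 + erf(z))

    # The "wrong side" is the side the mean prediction is not on,
    # which is always the smaller of the two tails.
    return np.minimum(mass_below, 1.0 - mass_below)


# Made-up predictions from 4 ensemble members for 3 samples.
member_preds = np.array(
    [
        [0.90, 0.40, 0.10],
        [0.95, 0.60, 0.10],
        [0.85, 0.55, 0.10],
        [0.90, 0.45, 0.10],
    ]
)
scores = prob_std_score(member_preds)
```

Under this reading the score is close to 0 when the members agree far from the cutoff, and it cannot exceed 0.5, which it reaches when the mean prediction sits exactly on the cutoff. The same function would apply to classifiers if `member_preds` holds each member's positive-class probability.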

References
----------
.. [1] `Klingspohn, W., Mathea, M., ter Laak, A. et al.
Efficiency of different measures for defining the applicability
Member

Add quotes (") around the paper name, and move the journal name to the next line.

def _compute_prob_std(self, X: np.ndarray) -> np.ndarray:
    X = validate_data(self, X=X, reset=False)

    preds = np.array([est.predict(X) for est in self.model.estimators_]).T
Member

Why transpose this?
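To make the question concrete, the sketch below shows what the transpose changes. Stacking per-estimator outputs gives an array of shape (n_estimators, n_samples); `.T` only moves the estimator axis last. The prediction values here are made up, and the per-sample statistic is the same either way as long as the reduction axis matches.

```python
import numpy as np

# Made-up outputs standing in for est.predict(X): 3 estimators, 5 samples.
per_estimator = [np.linspace(0.0, 1.0, 5) * k for k in (0.8, 1.0, 1.2)]

stacked = np.array(per_estimator)  # shape (n_estimators, n_samples) == (3, 5)
transposed = stacked.T             # shape (n_samples, n_estimators) == (5, 3)

# Reducing over the estimator axis yields one value per sample in both layouts.
std_without_transpose = stacked.std(axis=0)
std_with_transpose = transposed.std(axis=1)

same_result = bool(np.allclose(std_without_transpose, std_with_transpose))
```

So the transpose is a layout choice rather than a requirement; dropping it only means reducing over `axis=0` instead of `axis=1`.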

from skfp.bases.base_ad_checker import BaseADChecker


class TopKatADChecker(BaseADChecker):
Member

TOPKAT, all capital letters

and a weighted distance (dOPS) from the center is computed.

Samples are considered in-domain if their dOPS is below a threshold. By default,
this threshold is computed as ``5 * D / (2 * N)``, where:
Member

Use :math: instead of backticks for math formulas, in the entire docstring.
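As an illustration of the requested style, the sketch below writes the default threshold formula with the Sphinx `:math:` role inside a docstring and computes it. The function name `default_topkat_threshold` and its parameter names are hypothetical; only the formula comes from the docstring under review.

```python
def default_topkat_threshold(n_features: int, n_samples: int) -> float:
    r"""
    Default dOPS threshold, :math:`5 * D / (2 * N)`, where :math:`D` is the
    number of input features and :math:`N` is the number of training samples.
    """
    return (5 * n_features) / (2 * n_samples)


# Example: 10 features and 100 training samples give 50 / 200.
example_threshold = default_topkat_threshold(n_features=10, n_samples=100)
```

Because the threshold scales with D / N, it loosens as features are added and tightens as the training set grows.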

Comment on lines +96 to +98
self.S_ = (2 * X - self.X_max_ - self.X_min_) / np.where(
    (self.X_max_ - self.X_min_) != 0, (self.X_max_ - self.X_min_), 1.0
)
Member

Break this into separate variables. You are also using max - min at least 3 times here.


threshold = self.threshold
if threshold is None:
    threshold = (5 * self.num_dims) / (2 * self.num_points)
Member

One empty line after if for readability

self.num_points = X.shape[0]
self.num_dims = X.shape[1]

self.S_ = (2 * X - self.X_max_ - self.X_min_) / np.where(
Member

S_ might be a confusing name, especially for future maintainers without a mathematical background. Please add code comments that briefly explain what we are computing. It's very technical, so there's no need to include this information in the documentation.

Member

@my-alaska Aug 11, 2025

Also, if I'm seeing this correctly, S_ is not used in other methods. I think there's no need to make it a member of the class. Correct me if I'm wrong.
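Taken together, the comments in this thread suggest computing the scaling with local variables and short explanatory comments. A minimal sketch, assuming NumPy inputs and the hypothetical helper name `scale_to_s_space`; it is one possible refactor and not the code that was merged.

```python
import numpy as np


def scale_to_s_space(X: np.ndarray, X_min: np.ndarray, X_max: np.ndarray) -> np.ndarray:
    """Hypothetical helper: scale each feature to [-1, 1] using training min/max."""
    # Feature-wise range, computed once and reused.
    feature_range = X_max - X_min
    # Constant features have zero range; dividing by 1 instead avoids a
    # division by zero, and their training values then scale to 0.
    denom = np.where(feature_range != 0, feature_range, 1.0)
    # For non-constant features this maps X_min to -1 and X_max to +1.
    return (2 * X - X_max - X_min) / denom


# Small made-up training set; the middle feature is constant.
X_train = np.array([[0.0, 5.0, 1.0], [2.0, 5.0, 3.0], [4.0, 5.0, 2.0]])
S = scale_to_s_space(X_train, X_train.min(axis=0), X_train.max(axis=0))
```

The scaled matrix stays a local variable here, which matches the observation that nothing outside `fit()` appears to need it.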

either ``.predict(X)`` or ``.predict_proba(X)`` method on each sub-estimator.
If not provided, a default :class:`~sklearn.ensemble.RandomForestRegressor` will be created.

threshold : float, default=0.2
Member

Why this threshold?

Contributor Author

It's just a heuristic I assumed to be a generally good starting point. Should we use something like 0.5 or 1.0 instead?

Comment on lines +22 to +25
This approach supports both regression models (using ``.predict(X)`` with outputs
interpretable as positive-class probabilities in [0, 1], e.g., regressors trained
on binary targets) and binary classifiers (using ``.predict_proba(X)`` and the
probability of the positive class). The ensemble model must expose the ``estimators_``
Member

The text in parentheses is too long. Just make these regular sentences.

):
    if self.model is None:
        X, y = validate_data(self, X, y, ensure_2d=False)
        self.model_ = RandomForestRegressor(n_estimators=10, random_state=0)
Member

Do not set n_estimators; rely on the default sklearn value.

Comment on lines +146 to +149
if preds.shape[2] == 2:
    preds = preds[:, :, 1]  # shape: (n_estimators, n_samples)
else:
    raise ValueError("Only binary classifiers are supported.")
Member

Move to validate_params
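One way to read this request is to separate the binary-only check from the slicing, so the prediction code contains no error branch. The sketch below does that with a standalone function; the name `positive_class_proba` is hypothetical, and where exactly the check should live in the project's validation hooks is not settled by this thread.

```python
import numpy as np


def positive_class_proba(stacked_proba: np.ndarray) -> np.ndarray:
    """
    Hypothetical helper. stacked_proba holds per-estimator predict_proba
    outputs with shape (n_estimators, n_samples, n_classes).
    """
    # Validation first: only the binary case has a single "positive" column.
    if stacked_proba.ndim != 3 or stacked_proba.shape[2] != 2:
        raise ValueError("Only binary classifiers are supported.")

    # Keep the positive-class probability: shape (n_estimators, n_samples).
    return stacked_proba[:, :, 1]


# Made-up probabilities: 2 estimators, 3 samples, 2 classes.
proba = np.array(
    [
        [[0.2, 0.8], [0.6, 0.4], [0.5, 0.5]],
        [[0.1, 0.9], [0.7, 0.3], [0.4, 0.6]],
    ]
)
positive = positive_class_proba(proba)
```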

Comment on lines +22 to +23
- ``D`` is the number of input features,
- ``N`` is the number of training samples.
Member

Use :math:`D` here, and similarly :math:`N`.

Comment on lines +47 to +49
.. [1] Gombar, V. K. (1996).
Method and apparatus for validation of model-based predictions.
U.S. Patent No. 6,036,349. Washington, DC: U.S. Patent and Trademark Office.
Member

Backticks, quotation marks around name, add link to Google Patents page, similar to other citations, remove year from author list

# TOPKAT S-space: feature-wise scaling of X to [-1, 1].
# Avoid division by zero: where range==0, denom=1 => scaled value will be 0.
self.denom_ = np.where((self.range_) != 0, (self.range_), 1.0)
S = (2 * X - self.X_max_ - self.X_min_) / self.denom_
Member

Here you can also use range_ right?

Contributor Author

range_ is X_max_ - X_min_, but here we essentially have - (X_max + X_min) so we can't directly replace it with range_.
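The point in this reply can be checked numerically. In the sketch below (made-up data, illustrative variable names), the numerator subtracts the sum of the per-feature maximum and minimum, while the range is their difference, so the range can stand in for the denominator only. The expression is equivalent to min-max scaling shifted to [-1, 1].

```python
import numpy as np

X = np.array([[1.0, 10.0], [2.0, 30.0], [4.0, 20.0]])
X_min, X_max = X.min(axis=0), X.max(axis=0)

value_range = X_max - X_min  # difference: what range_ stores
max_plus_min = X_max + X_min  # sum: what the numerator subtracts

scaled = (2 * X - max_plus_min) / value_range

# Equivalent form: min-max scale to [0, 1], then stretch and shift to [-1, 1].
equivalent = 2 * (X - X_min) / value_range - 1
forms_agree = bool(np.allclose(scaled, equivalent))

# Substituting the range into the numerator gives a different result.
wrong = (2 * X - value_range) / value_range
substitution_differs = not np.allclose(scaled, wrong)
```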

X = validate_data(self, X=X, reset=False)

# Apply the same S-space transform as in fit().
Ssample = (2 * X - self.X_max_ - self.X_min_) / self.denom_
Member

self.range_ can be used here, right?

j-adamczyk previously approved these changes Aug 13, 2025
j-adamczyk previously approved these changes Aug 13, 2025
my-alaska previously approved these changes Aug 15, 2025
@j-adamczyk dismissed stale reviews from my-alaska and themself via 2f4d79e August 15, 2025 18:07
@j-adamczyk self-requested a review August 15, 2025 18:08
j-adamczyk previously approved these changes Aug 15, 2025
@j-adamczyk merged commit db8de97 into master Aug 19, 2025
13 checks passed
@j-adamczyk deleted the ad_probstd_topkat branch August 19, 2025 15:43

3 participants