Add AD kNN #469
Merged
Commits (6), all by StaniszewskiA:

- c77fafe Add AD kNN
- fe0e5ae Revert "Add AD kNN"
- 19cc847 Revert "Revert "Add AD kNN""
- 5f5d636 Fix moleculenet tests
- 95ef4fa Add support for more metrics, applied requested changes
- c3b7517 Remove support for bulk metrics, adjusted docs
New file (+210 lines):

```python
from collections.abc import Callable
from numbers import Integral, Real

import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.utils._param_validation import Interval, InvalidParameterError, StrOptions
from sklearn.utils.validation import check_is_fitted, validate_data

from skfp.bases.base_ad_checker import BaseADChecker
from skfp.distances import (
    _BULK_METRIC_NAMES as SKFP_BULK_METRIC_NAMES,
)
from skfp.distances import (
    _BULK_METRICS as SKFP_BULK_METRICS,
)
from skfp.distances import (
    _METRIC_NAMES as SKFP_METRIC_NAMES,
)
from skfp.distances import (
    _METRICS as SKFP_METRICS,
)

METRIC_FUNCTIONS = {**SKFP_METRICS, **SKFP_BULK_METRICS}
METRIC_NAMES = set(SKFP_METRIC_NAMES) | set(SKFP_BULK_METRIC_NAMES)


class KNNADChecker(BaseADChecker):
    r"""
    k-nearest neighbors applicability domain checker.

    This method determines whether a query molecule falls within the applicability
    domain by comparing its distance to the k nearest neighbors [1]_ [2]_ [3]_ in
    the training set, using a threshold derived from the training data.

    The applicability domain score is defined as one of:

    - the mean distance to the k nearest neighbors,
    - the distance to the k-th nearest neighbor (max distance),
    - the distance to the closest neighbor from the training set (min distance).

    A threshold is then set at the 95th percentile of these aggregated distances.
    Query molecules with an aggregated distance to their k nearest neighbors below
    this threshold are considered within the applicability domain.

    This implementation supports binary and count Tanimoto similarity metrics.

    Parameters
    ----------
    k : int, default=1
        Number of nearest neighbors to consider for distance calculations.
        Must not exceed the number of training samples.

    metric : Callable or str, default="tanimoto_binary_distance"
        Distance metric to use.

    agg : {"mean", "max", "min"}, default="mean"
        Aggregation method for distances to the k nearest neighbors:

        - "mean": use the mean distance to k neighbors,
        - "max": use the maximum distance among k neighbors,
        - "min": use the distance to the closest neighbor.

    threshold : float, default=0.95
        Percentile of aggregated training distances, given as a fraction
        in [0, 1], used as the applicability domain threshold.

    n_jobs : int, default=None
        The number of jobs to run in parallel. :meth:`transform_x_y` and
        :meth:`transform` are parallelized over the input molecules. ``None`` means 1
        unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all
        processors. See scikit-learn documentation on ``n_jobs`` for more details.

    verbose : int or dict, default=0
        Controls the verbosity when filtering molecules.
        If a dictionary is passed, it is treated as kwargs for ``tqdm()``,
        and can be used to control the progress bar.

    References
    ----------
    .. [1] `Klingspohn, W., Mathea, M., ter Laak, A. et al.
        "Efficiency of different measures for defining the applicability domain of
        classification models."
        J Cheminform 9, 44 (2017)
        <https://doi.org/10.1186/s13321-017-0230-2>`_

    .. [2] `Harmeling, S., Dornhege, G., Tax, D., Meinecke, F., Müller, K.-R.
        "From outliers to prototypes: Ordering data."
        Neurocomputing, 69, 13, pages 1608-1618, (2006)
        <https://doi.org/10.1016/j.neucom.2005.05.015>`_

    .. [3] `Kar, S., Roy, K., Leszczynski, J.
        "Applicability Domain: A Step Toward Confident Predictions and Decidability
        for QSAR Modeling"
        Methods Mol Biol, 1800, pages 141-169, (2018)
        <https://doi.org/10.1007/978-1-4939-7899-1_6>`_

    Examples
    --------
    >>> from skfp.applicability_domain import KNNADChecker
    >>> import numpy as np
    >>> X_train_binary = np.array([
    ...     [1, 1, 1],
    ...     [0, 1, 1],
    ...     [0, 0, 1]
    ... ])
    >>> X_test_binary = 1 - X_train_binary
    >>> knn_ad_checker_binary = KNNADChecker(k=2, metric="tanimoto_binary_distance", agg="mean")
    >>> knn_ad_checker_binary
    KNNADChecker()

    >>> knn_ad_checker_binary.fit(X_train_binary)
    KNNADChecker()

    >>> knn_ad_checker_binary.predict(X_test_binary)
    array([False, False, False])
    """

    _parameter_constraints: dict = {
        **BaseADChecker._parameter_constraints,
        "k": [Interval(Integral, 1, None, closed="left")],
        "metric": [callable, StrOptions(METRIC_NAMES)],
        "agg": [StrOptions({"mean", "max", "min"})],
        "threshold": [None, Interval(Real, 0, 1, closed="both")],
    }

    def __init__(
        self,
        k: int = 1,
        metric: str | Callable = "tanimoto_binary_distance",
        agg: str = "mean",
        threshold: float = 0.95,
        n_jobs: int | None = None,
        verbose: int | dict = 0,
    ):
        super().__init__(
            n_jobs=n_jobs,
            verbose=verbose,
        )
        self.k = k
        self.metric = metric
        self.agg = agg
        self.threshold = threshold

    def _validate_params(self) -> None:
        super()._validate_params()
        if isinstance(self.metric, str) and self.metric not in METRIC_FUNCTIONS:
            raise InvalidParameterError(
                f"Allowed metrics: {METRIC_NAMES}. Got: {self.metric}"
            )

    def fit(  # noqa: D102
        self,
        X: np.ndarray,
        y: np.ndarray | None = None,  # noqa: ARG002
    ):
        X = validate_data(self, X=X)
        if self.k > X.shape[0]:
            raise ValueError(
                f"k ({self.k}) must be smaller than or equal to the number of training samples ({X.shape[0]})"
            )

        self.X_train_ = X
        self.k_used = 1 if self.agg == "min" else self.k

        if callable(self.metric):
            metric_func = self.metric
        elif isinstance(self.metric, str) and self.metric in METRIC_FUNCTIONS:
            metric_func = METRIC_FUNCTIONS[self.metric]
        else:
            raise KeyError(
                f"Unknown metric: {self.metric}. Must be a callable or one of {list(METRIC_FUNCTIONS.keys())}"
            )

        self.knn_ = NearestNeighbors(
            n_neighbors=self.k_used, metric=metric_func, n_jobs=self.n_jobs
        )
        self.knn_.fit(X)
        k_nearest, _ = self.knn_.kneighbors(X)

        agg_dists = self._get_agg_dists(k_nearest)
        # threshold is a fraction in [0, 1], while np.percentile expects [0, 100]
        self.threshold_ = np.percentile(agg_dists, self.threshold * 100)

        return self

    def predict(self, X: np.ndarray) -> np.ndarray:  # noqa: D102
        return self.score_samples(X) <= self.threshold_

    def score_samples(self, X: np.ndarray) -> np.ndarray:
        """
        Calculate the applicability domain score of samples, i.e. the aggregated
        distance to their k nearest training neighbors. Lower scores mean that a
        sample lies closer to the training data; ``.predict()`` compares these
        scores against the fitted threshold.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data matrix.

        Returns
        -------
        scores : ndarray of shape (n_samples,)
            Applicability domain scores of samples.
        """
        check_is_fitted(self)
        X = validate_data(self, X=X, reset=False)
        k_nearest, _ = self.knn_.kneighbors(X, n_neighbors=self.k_used)

        return self._get_agg_dists(k_nearest)

    def _get_agg_dists(self, k_nearest: np.ndarray) -> np.ndarray:
        if self.agg == "mean":
            return np.mean(k_nearest, axis=1)
        elif self.agg == "max":
            return np.max(k_nearest, axis=1)
        else:  # "min", guaranteed by _parameter_constraints
            return np.min(k_nearest, axis=1)
```
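The core of the method can be sketched in plain NumPy: aggregate each molecule's distances to its k nearest training neighbors, then threshold at a percentile of the training-set distribution. The sketch below is illustrative and not the skfp code: it hard-codes binary Tanimoto distance and excludes self-neighbors when fitting, whereas `KNNADChecker` delegates neighbor search to `sklearn.neighbors.NearestNeighbors`.

```python
import numpy as np

def tanimoto_distance_matrix(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Pairwise Tanimoto (Jaccard) distances between 0/1 fingerprint matrices."""
    inter = A @ B.T  # |a & b| for binary rows
    union = A.sum(axis=1)[:, None] + B.sum(axis=1)[None, :] - inter
    # two all-zero fingerprints are treated as identical (similarity 1)
    sim = np.where(union > 0, inter / np.maximum(union, 1), 1.0)
    return 1.0 - sim

def fit_knn_ad(X_train: np.ndarray, k: int = 2, percentile: float = 95) -> float:
    """Threshold = percentile of mean k-NN distances within the training set."""
    d = tanimoto_distance_matrix(X_train, X_train)
    np.fill_diagonal(d, np.inf)  # exclude each sample's zero self-distance
    k_nearest = np.sort(d, axis=1)[:, :k]  # k smallest distances per row
    return float(np.percentile(k_nearest.mean(axis=1), percentile))

def predict_knn_ad(X_train: np.ndarray, X_query: np.ndarray,
                   threshold: float, k: int = 2) -> np.ndarray:
    """True where the mean k-NN distance is within the fitted threshold."""
    d = tanimoto_distance_matrix(X_query, X_train)
    k_nearest = np.sort(d, axis=1)[:, :k]
    return k_nearest.mean(axis=1) <= threshold
```

Queries close to the training fingerprints fall below the threshold, while complemented (maximally dissimilar) fingerprints fall outside, mirroring the doctest in the class docstring.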
New test file (+115 lines):

```python
import numpy as np
import pytest

from skfp.applicability_domain import KNNADChecker
from tests.applicability_domain.utils import get_data_inside_ad, get_data_outside_ad

ALLOWED_METRICS = [
    "tanimoto_binary_distance",
    "tanimoto_count_distance",
]

ALLOWED_AGGS = ["mean", "max", "min"]


@pytest.mark.parametrize("metric", ALLOWED_METRICS)
@pytest.mark.parametrize("agg", ALLOWED_AGGS)
def test_inside_knn_ad(metric, agg):
    if "binary" in metric:
        X_train, X_test = get_data_inside_ad(binarize=True)
    else:
        X_train, X_test = get_data_inside_ad()

    # pass the parametrized metric explicitly, otherwise only the default is tested
    ad_checker = KNNADChecker(k=3, metric=metric, agg=agg)
    ad_checker.fit(X_train)

    scores = ad_checker.score_samples(X_test)
    assert scores.shape == (len(X_test),)

    preds = ad_checker.predict(X_test)
    assert isinstance(preds, np.ndarray)
    assert np.issubdtype(preds.dtype, np.bool_)
    assert preds.shape == (len(X_test),)
    assert np.all(preds == 1)


@pytest.mark.parametrize("metric", ALLOWED_METRICS)
@pytest.mark.parametrize("agg", ALLOWED_AGGS)
def test_outside_knn_ad(metric, agg):
    if "binary" in metric:
        X_train, X_test = get_data_outside_ad(binarize=True)
    else:
        X_train, X_test = get_data_outside_ad()

    ad_checker = KNNADChecker(k=3, metric=metric, agg=agg)
    ad_checker.fit(X_train)

    scores = ad_checker.score_samples(X_test)
    assert np.all(scores >= 0)

    preds = ad_checker.predict(X_test)
    assert isinstance(preds, np.ndarray)
    assert np.issubdtype(preds.dtype, np.bool_)
    assert preds.shape == (len(X_test),)
    assert np.all(preds == 0)


@pytest.mark.parametrize("metric", ALLOWED_METRICS)
@pytest.mark.parametrize("agg", ALLOWED_AGGS)
def test_knn_different_k_values(metric, agg):
    if "binary" in metric:
        X_train, X_test = get_data_inside_ad(binarize=True)
    else:
        X_train, X_test = get_data_inside_ad()

    # smaller k, stricter check
    ad_checker_k1 = KNNADChecker(k=1, metric=metric, agg=agg)
    ad_checker_k1.fit(X_train)
    passed_k1 = ad_checker_k1.predict(X_test).sum()

    # larger k, potentially less strict
    ad_checker_k5 = KNNADChecker(k=5, metric=metric, agg=agg)
    ad_checker_k5.fit(X_train)
    passed_k5 = ad_checker_k5.predict(X_test).sum()

    # both should be valid results
    assert isinstance(passed_k1, (int, np.integer))
    assert isinstance(passed_k5, (int, np.integer))


@pytest.mark.parametrize("metric", ALLOWED_METRICS)
@pytest.mark.parametrize("agg", ALLOWED_AGGS)
def test_knn_pass_y_train(metric, agg):
    # smoke test, should not throw errors
    if "binary" in metric:
        X_train, _ = get_data_inside_ad(binarize=True)
    else:
        X_train, _ = get_data_inside_ad()

    y_train = np.zeros(len(X_train))
    ad_checker = KNNADChecker(k=3, metric=metric, agg=agg)
    ad_checker.fit(X_train, y_train)


@pytest.mark.parametrize("metric", ALLOWED_METRICS)
@pytest.mark.parametrize("agg", ALLOWED_AGGS)
def test_knn_invalid_k(metric, agg):
    if "binary" in metric:
        X_train, _ = get_data_inside_ad(binarize=True)
    else:
        X_train, _ = get_data_inside_ad()

    with pytest.raises(
        ValueError,
        match=r"k \(\d+\) must be smaller than or equal to the number of training samples \(\d+\)",
    ):
        ad_checker = KNNADChecker(k=len(X_train) + 1, metric=metric, agg=agg)
        ad_checker.fit(X_train)


def test_knn_invalid_metric():
    X_train, _ = get_data_inside_ad()
    ad_checker = KNNADChecker(k=3, metric="euclidean")
    with pytest.raises(KeyError):
        ad_checker.fit(X_train)
```
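As a small illustration of the three `agg` options of `KNNADChecker` ("mean", "max", "min"), here is how the aggregations differ on the same k-NN distance matrix (hypothetical distances, plain NumPy):

```python
import numpy as np

# distances from 2 query molecules to their k=3 nearest training neighbors
k_nearest = np.array([
    [0.10, 0.20, 0.60],  # one distant neighbor inflates "max"
    [0.05, 0.05, 0.05],  # uniformly close neighborhood
])

agg_dists = {
    "mean": k_nearest.mean(axis=1),  # average distance to the k neighbors
    "max": k_nearest.max(axis=1),    # distance to the k-th (farthest) neighbor
    "min": k_nearest.min(axis=1),    # distance to the closest neighbor
}
```

"min" is the most permissive choice (a single close neighbor suffices to pass), while "max" requires the entire neighborhood to be close.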