Add AD kNN #469
Merged
Changes shown from 4 of 6 commits.

Commits (6):

- c77fafe: Add AD kNN (StaniszewskiA)
- fe0e5ae: Revert "Add AD kNN" (StaniszewskiA)
- 19cc847: Revert "Revert "Add AD kNN"" (StaniszewskiA)
- 5f5d636: Fix moleculenet tests (StaniszewskiA)
- 95ef4fa: Add support for more metrics, applied requested changes (StaniszewskiA)
- c3b7517: Remove support for bulk metrics, adjusted docs (StaniszewskiA)
New file (+228 lines):

```python
from collections.abc import Callable
from numbers import Integral

import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.utils._param_validation import Interval, InvalidParameterError, StrOptions
from sklearn.utils.validation import check_is_fitted, validate_data

from skfp.bases.base_ad_checker import BaseADChecker
from skfp.distances import (
    tanimoto_binary_distance,
    tanimoto_count_distance,
)

METRIC_FUNCTIONS = {
    "tanimoto_binary": tanimoto_binary_distance,
    "tanimoto_count": tanimoto_count_distance,
}


class KNNADChecker(BaseADChecker):
    r"""
    k-Nearest Neighbor applicability domain checker.

    This method determines whether a query molecule falls within the applicability
    domain by comparing its distance to its k nearest neighbors [1]_ [2]_ in the
    training set, using a threshold derived from the training data.

    The applicability domain is defined by one of the following aggregated
    distances, computed for each training sample:

    - the mean distance to the k nearest neighbors,
    - the max distance among the k nearest neighbors,
    - the min [3]_ distance among the k nearest neighbors (effectively kNN
      with k of 1).

    A threshold is then set at the 95th percentile of these aggregated distances.
    Query molecules with an aggregated distance to their k nearest neighbors below
    this threshold are considered within the applicability domain.

    This implementation supports binary and count Tanimoto distance metrics.

    Parameters
    ----------
    k : int
        Number of nearest neighbors to consider for distance calculations.
        Must be smaller than the number of training samples.

    metric : {"tanimoto_binary", "tanimoto_count"}, default="tanimoto_binary"
        Distance metric to use.

    agg : {"mean", "max", "min"}, default="mean"
        Aggregation method for distances to k nearest neighbors:

        - "mean": use the mean distance to k neighbors,
        - "max": use the maximum distance among k neighbors,
        - "min": use the distance to the closest neighbor.

    n_jobs : int, default=None
        The number of jobs to run in parallel. :meth:`transform_x_y` and
        :meth:`transform` are parallelized over the input molecules. ``None`` means 1
        unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all
        processors. See scikit-learn documentation on ``n_jobs`` for more details.

    verbose : int or dict, default=0
        Controls the verbosity when filtering molecules.
        If a dictionary is passed, it is treated as kwargs for ``tqdm()``,
        and can be used to control the progress bar.

    References
    ----------
    .. [1] `Klingspohn, W., Mathea, M., ter Laak, A. et al.
        "Efficiency of different measures for defining the applicability domain of
        classification models."
        J Cheminform 9, 44 (2017)
        <https://doi.org/10.1186/s13321-017-0230-2>`_

    .. [2] `Harmeling, S., Dornhege, G., Tax, D., Meinecke, F., Müller, K.R.
        "From outliers to prototypes: Ordering data."
        Neurocomputing, 69, 13, pages 1608-1618 (2006)
        <https://doi.org/10.1016/j.neucom.2005.05.015>`_

    .. [3] `Kar, S., Roy, K., Leszczynski, J.
        "Applicability Domain: A Step Toward Confident Predictions and Decidability
        for QSAR Modeling"
        Methods Mol Biol, 1800, pages 141-169 (2018)
        <https://doi.org/10.1007/978-1-4939-7899-1_6>`_

    Examples
    --------
    >>> from skfp.applicability_domain import KNNADChecker
    >>> import numpy as np
    >>> X_train_binary = np.array([
    ...     [1, 1, 1],
    ...     [0, 1, 1],
    ...     [0, 0, 1]
    ... ])
    >>> X_test_binary = 1 - X_train_binary
    >>> knn_ad_checker_binary = KNNADChecker(k=2, metric="tanimoto_binary", agg="mean")
    >>> knn_ad_checker_binary
    KNNADChecker()

    >>> knn_ad_checker_binary.fit(X_train_binary)
    KNNADChecker()

    >>> knn_ad_checker_binary.predict(X_test_binary)
    array([False, False, False])

    >>> X_train_count = np.array([
    ...     [1.2, 2.3],
    ...     [3.4, 4.5],
    ...     [5.6, 6.7]
    ... ])
    >>> X_test_count = X_train_count + 10
    >>> knn_ad_checker_count = KNNADChecker(k=2, metric="tanimoto_count", agg="min")
    >>> knn_ad_checker_count
    KNNADChecker()

    >>> knn_ad_checker_count.fit(X_train_count)
    KNNADChecker()

    >>> knn_ad_checker_count.predict(X_test_count)
    array([False, False, False])
    """

    _parameter_constraints: dict = {
        **BaseADChecker._parameter_constraints,
        "k": [Interval(Integral, 1, None, closed="left")],
        "metric": [
            callable,
            StrOptions(set(METRIC_FUNCTIONS.keys())),
        ],
    }

    def __init__(
        self,
        k: int,
        metric: str | Callable = "tanimoto_binary",
        agg: str = "mean",
        n_jobs: int | None = None,
        verbose: int | dict = 0,
    ):
        super().__init__(
            n_jobs=n_jobs,
            verbose=verbose,
        )
        self.metric = metric
        self.agg = agg
        self.k = k

    def _validate_params(self) -> None:
        super()._validate_params()
        if isinstance(self.metric, str) and self.metric not in METRIC_FUNCTIONS:
            raise InvalidParameterError(
                f"The metric parameter must be one of the Tanimoto variants. "
                f"Allowed Tanimoto metrics: {list(METRIC_FUNCTIONS.keys())}. "
                f"Got: {self.metric}"
            )
        if isinstance(self.agg, str) and self.agg not in ["mean", "max", "min"]:
            raise InvalidParameterError(
                f"Unknown aggregation method: {self.agg}. "
                f"Allowed values: 'mean', 'max', 'min'."
            )

    def fit(  # noqa: D102
        self,
        X: np.ndarray,
        y: np.ndarray | None = None,  # noqa: ARG002
    ):
        X = validate_data(self, X=X)
        if self.k >= X.shape[0]:
            raise ValueError(
                f"k ({self.k}) must be smaller than the number of training "
                f"samples ({X.shape[0]})"
            )

        # "min" aggregation only ever needs the single closest neighbor
        k_used = 1 if self.agg == "min" else self.k

        if callable(self.metric):
            metric_func = self.metric
        elif isinstance(self.metric, str) and self.metric in METRIC_FUNCTIONS:
            metric_func = METRIC_FUNCTIONS[self.metric]
        else:
            raise InvalidParameterError(
                f"Unknown metric: {self.metric}. Must be a callable or one of "
                f"{list(METRIC_FUNCTIONS.keys())}"
            )

        self.knn_ = NearestNeighbors(
            n_neighbors=k_used, metric=metric_func, n_jobs=self.n_jobs
        )
        self.knn_.fit(X)

        dists, _ = self.knn_.kneighbors(X)

        if self.agg == "mean":
            agg_dists = np.mean(dists, axis=1)
        elif self.agg == "max":
            agg_dists = np.max(dists, axis=1)
        elif self.agg == "min":
            agg_dists = np.min(dists, axis=1)

        self.threshold_ = np.percentile(agg_dists, 95)
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:  # noqa: D102
        check_is_fitted(self)
        X = validate_data(self, X=X, reset=False)

        k_used = 1 if self.agg == "min" else self.k
        dists, _ = self.knn_.kneighbors(X, n_neighbors=k_used)
        if self.agg == "mean":
            agg_dists = np.mean(dists, axis=1)
        elif self.agg == "max":
            agg_dists = np.max(dists, axis=1)
        elif self.agg == "min":
            agg_dists = np.min(dists, axis=1)

        return agg_dists <= self.threshold_

    def score_samples(self, X: np.ndarray) -> np.ndarray:
        """
        Calculate the applicability domain score of samples. It is simply a 0/1
        decision equal to ``.predict()``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data matrix.

        Returns
        -------
        scores : ndarray of shape (n_samples,)
            Applicability domain scores of samples.
        """
        return self.predict(X)
```
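The fit-time thresholding that this class implements can be sketched independently of skfp. In the sketch below, plain Euclidean distance stands in for the Tanimoto metrics purely so the snippet is self-contained; the k, "mean" aggregation, and 95th-percentile choices mirror the defaults above, and the random data is invented for illustration.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
X_train = rng.random((50, 8))      # 50 training "fingerprints"
X_test = X_train[:5] + 5.0         # shifted far outside the training region

k = 3
knn = NearestNeighbors(n_neighbors=k).fit(X_train)

# Aggregated (mean) distance of each training sample to its k nearest
# neighbors; note the nearest neighbor of a training point is itself
train_dists, _ = knn.kneighbors(X_train)
threshold = np.percentile(train_dists.mean(axis=1), 95)

# A query is inside the AD if its aggregated distance is below the threshold
test_dists, _ = knn.kneighbors(X_test)
inside = test_dists.mean(axis=1) <= threshold
print(inside)  # all False: the shifted points fall outside the AD
```

This also makes the role of the self-match visible: since `kneighbors(X_train)` returns each training point as its own nearest neighbor at distance 0, the "min" aggregation yields a threshold of 0 on the training set, which is why that mode is the strictest.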
New test file (+115 lines):

```python
import numpy as np
import pytest
from sklearn.utils._param_validation import InvalidParameterError

from skfp.applicability_domain import KNNADChecker
from tests.applicability_domain.utils import get_data_inside_ad, get_data_outside_ad

ALLOWED_METRICS = [
    "tanimoto_binary",
    "tanimoto_count",
]

ALLOWED_AGGS = ["mean", "max", "min"]


@pytest.mark.parametrize("metric", ALLOWED_METRICS)
@pytest.mark.parametrize("agg", ALLOWED_AGGS)
def test_inside_knn_ad(metric, agg):
    if metric == "tanimoto_binary":
        X_train, X_test = get_data_inside_ad(binarize=True)
    else:
        X_train, X_test = get_data_inside_ad()

    # pass the parametrized metric; the original omitted it, so every
    # case silently ran with the default "tanimoto_binary"
    ad_checker = KNNADChecker(k=3, metric=metric, agg=agg)
    ad_checker.fit(X_train)

    scores = ad_checker.score_samples(X_test)
    assert scores.shape == (len(X_test),)

    preds = ad_checker.predict(X_test)
    assert isinstance(preds, np.ndarray)
    assert np.issubdtype(preds.dtype, np.bool_)
    assert preds.shape == (len(X_test),)
    assert np.all(preds == 1)


@pytest.mark.parametrize("metric", ALLOWED_METRICS)
@pytest.mark.parametrize("agg", ALLOWED_AGGS)
def test_outside_knn_ad(metric, agg):
    if metric == "tanimoto_binary":
        X_train, X_test = get_data_outside_ad(binarize=True)
    else:
        X_train, X_test = get_data_outside_ad()

    ad_checker = KNNADChecker(k=3, metric=metric, agg=agg)
    ad_checker.fit(X_train)

    scores = ad_checker.score_samples(X_test)
    assert np.all(scores >= 0)

    preds = ad_checker.predict(X_test)
    assert isinstance(preds, np.ndarray)
    assert np.issubdtype(preds.dtype, np.bool_)
    assert preds.shape == (len(X_test),)
    assert np.all(preds == 0)


@pytest.mark.parametrize("metric", ALLOWED_METRICS)
@pytest.mark.parametrize("agg", ALLOWED_AGGS)
def test_knn_different_k_values(metric, agg):
    if metric == "tanimoto_binary":
        X_train, X_test = get_data_inside_ad(binarize=True)
    else:
        X_train, X_test = get_data_inside_ad()

    # smaller k, stricter check
    ad_checker_k1 = KNNADChecker(k=1, metric=metric, agg=agg)
    ad_checker_k1.fit(X_train)
    passed_k1 = ad_checker_k1.predict(X_test).sum()

    # larger k, potentially less strict
    ad_checker_k5 = KNNADChecker(k=5, metric=metric, agg=agg)
    ad_checker_k5.fit(X_train)
    passed_k5 = ad_checker_k5.predict(X_test).sum()

    # both should be valid results
    assert isinstance(passed_k1, (int, np.integer))
    assert isinstance(passed_k5, (int, np.integer))


@pytest.mark.parametrize("metric", ALLOWED_METRICS)
@pytest.mark.parametrize("agg", ALLOWED_AGGS)
def test_knn_pass_y_train(metric, agg):
    # smoke test, should not throw errors
    if metric == "tanimoto_binary":
        X_train, _ = get_data_inside_ad(binarize=True)
    else:
        X_train, _ = get_data_inside_ad()

    y_train = np.zeros(len(X_train))
    ad_checker = KNNADChecker(k=3, metric=metric, agg=agg)
    ad_checker.fit(X_train, y_train)


@pytest.mark.parametrize("metric", ALLOWED_METRICS)
@pytest.mark.parametrize("agg", ALLOWED_AGGS)
def test_knn_invalid_k(metric, agg):
    # parametrize over the real metric names; the original listed the
    # aggregation names "mean" and "max" here by mistake
    if metric == "tanimoto_binary":
        X_train, _ = get_data_inside_ad(binarize=True)
    else:
        X_train, _ = get_data_inside_ad()

    with pytest.raises(
        ValueError,
        match=r"k \(\d+\) must be smaller than the number of training samples \(\d+\)",
    ):
        ad_checker = KNNADChecker(k=len(X_train), metric=metric, agg=agg)
        ad_checker.fit(X_train)


def test_knn_invalid_metric():
    X_train, _ = get_data_inside_ad()
    ad_checker = KNNADChecker(k=3, metric="euclidean")
    # fit() raises InvalidParameterError for unknown string metrics; it never
    # indexes METRIC_FUNCTIONS directly, so the KeyError originally expected
    # here cannot occur
    with pytest.raises(InvalidParameterError):
        ad_checker.fit(X_train)
```
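Since `KNNADChecker` also accepts a callable metric, it is worth noting how scikit-learn's `NearestNeighbors` handles that case. The sketch below re-implements a Tanimoto (Jaccard) distance on binary vectors locally for illustration; this is an assumed-equivalent formula, not the skfp `tanimoto_binary_distance` function itself.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors


def tanimoto_binary_distance(a, b):
    # 1 - |a AND b| / |a OR b|; distance 0 means identical fingerprints
    a = np.asarray(a).astype(bool)
    b = np.asarray(b).astype(bool)
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 0.0  # two all-zero vectors: treated as identical
    return 1.0 - np.logical_and(a, b).sum() / union


X = np.array([
    [1, 1, 0, 1],
    [1, 0, 0, 1],
    [0, 1, 1, 0],
])

# passing a Python callable forces the brute-force neighbor search backend,
# which is what a non-standard chemical distance requires anyway
knn = NearestNeighbors(n_neighbors=2, metric=tanimoto_binary_distance)
knn.fit(X)
dists, _ = knn.kneighbors(X)
print(dists[:, 0])  # each point's nearest neighbor is itself, at distance 0
```

Callable metrics are much slower than built-in ones, which may be one reason the PR discussion settled on a small set of named Tanimoto variants as the documented options.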