|
| 1 | +from collections.abc import Callable |
| 2 | +from numbers import Integral, Real |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +from sklearn.neighbors import NearestNeighbors |
| 6 | +from sklearn.utils._param_validation import Interval, InvalidParameterError, StrOptions |
| 7 | +from sklearn.utils.validation import check_is_fitted, validate_data |
| 8 | + |
| 9 | +from skfp.bases.base_ad_checker import BaseADChecker |
| 10 | +from skfp.distances import ( |
| 11 | + _BULK_METRIC_NAMES as SKFP_BULK_METRIC_NAMES, |
| 12 | +) |
| 13 | +from skfp.distances import ( |
| 14 | + _BULK_METRICS as SKFP_BULK_METRICS, |
| 15 | +) |
| 16 | +from skfp.distances import ( |
| 17 | + _METRIC_NAMES as SKFP_METRIC_NAMES, |
| 18 | +) |
| 19 | +from skfp.distances import ( |
| 20 | + _METRICS as SKFP_METRICS, |
| 21 | +) |
| 22 | + |
# Combined name -> function lookup covering both per-pair and bulk (matrix)
# distance implementations exported by skfp.distances.
METRIC_FUNCTIONS = {**SKFP_METRICS, **SKFP_BULK_METRICS}
# Every metric name accepted for the string form of the ``metric`` parameter.
METRIC_NAMES = set(SKFP_METRIC_NAMES) | set(SKFP_BULK_METRIC_NAMES)
| 26 | + |
class KNNADChecker(BaseADChecker):
    r"""
    k-nearest neighbors applicability domain checker.

    This method determines whether a query molecule falls within the applicability
    domain by comparing its distance to k nearest neighbors [1]_ [2]_ [3]_ in the training set,
    using a threshold derived from the training data.

    The applicability domain is defined as one of:

    - the mean distance to k nearest neighbors,
    - the distance to k-th nearest neighbor (max distance),
    - the distance to the closest neighbor from the training set (min distance)

    A threshold is then set at the 95th percentile of these aggregated distances.
    Query molecules with an aggregated distance to their k nearest neighbors below
    this threshold are considered within the applicability domain.

    This implementation supports all distance metrics from ``skfp.distances``
    (selected by name) as well as arbitrary callables, e.g. binary and count
    Tanimoto distances.

    Parameters
    ----------
    k : int, default=1
        Number of nearest neighbors to consider for distance calculations.
        Must not exceed the number of training samples.

    metric: Callable or string, default="tanimoto_binary_distance"
        Distance metric to use.

    agg: "mean" or "max" or "min", default="mean"
        Aggregation method for distances to k nearest neighbors:
        - "mean": use the mean distance to k neighbors,
        - "max": use the maximum distance among k neighbors,
        - "min": use the distance to the closest neighbor.

    threshold : float or None, default=0.95
        Quantile (as a fraction in ``[0, 1]``) of the aggregated training-set
        distances used as the applicability domain threshold, e.g. ``0.95``
        means the 95th percentile. ``None`` uses the default 95th percentile.

    n_jobs : int, default=None
        The number of jobs to run in parallel. :meth:`transform_x_y` and
        :meth:`transform` are parallelized over the input molecules. ``None`` means 1
        unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all
        processors. See scikit-learn documentation on ``n_jobs`` for more details.

    verbose : int or dict, default=0
        Controls the verbosity when filtering molecules.
        If a dictionary is passed, it is treated as kwargs for ``tqdm()``,
        and can be used to control the progress bar.

    References
    ----------
    .. [1] `Klingspohn, W., Mathea, M., ter Laak, A. et al.
        "Efficiency of different measures for defining the applicability domain of
        classification models."
        J Cheminform 9, 44 (2017)
        <https://doi.org/10.1186/s13321-017-0230-2>`_

    .. [2] `Harmeling, S., Dornhege G., Tax D., Meinecke F., Müller KR.
        "From outliers to prototypes: Ordering data."
        Neurocomputing, 69, 13, pages 1608-1618, (2006)
        <https://doi.org/10.1016/j.neucom.2005.05.015>`_

    .. [3] `Kar S., Roy K., Leszczynski J.
        "Applicability Domain: A Step Toward Confident Predictions and Decidability for QSAR Modeling"
        Methods Mol Biol, 1800, pages 141-169, (2018)
        <https://doi.org/10.1007/978-1-4939-7899-1_6>`_

    Examples
    --------
    >>> from skfp.applicability_domain import KNNADChecker
    >>> import numpy as np
    >>> X_train_binary = np.array([
    ...     [1, 1, 1],
    ...     [0, 1, 1],
    ...     [0, 0, 1]
    ... ])
    >>> X_test_binary = 1 - X_train_binary
    >>> knn_ad_checker_binary = KNNADChecker(k=2, metric="tanimoto_binary_distance", agg="mean")
    >>> knn_ad_checker_binary
    KNNADChecker()

    >>> knn_ad_checker_binary.fit(X_train_binary)
    KNNADChecker()

    >>> knn_ad_checker_binary.predict(X_test_binary)
    array([False, False, False])
    """

    _parameter_constraints: dict = {
        **BaseADChecker._parameter_constraints,
        "k": [Interval(Integral, 1, None, closed="left")],
        "metric": [callable, StrOptions(METRIC_NAMES)],
        "agg": [StrOptions({"mean", "max", "min"})],
        "threshold": [None, Interval(Real, 0, 1, closed="both")],
    }

    def __init__(
        self,
        k: int = 1,
        metric: str | Callable = "tanimoto_binary_distance",
        agg: str = "mean",
        threshold: float | None = 0.95,
        n_jobs: int | None = None,
        verbose: int | dict = 0,
    ):
        super().__init__(
            n_jobs=n_jobs,
            verbose=verbose,
        )
        self.k = k
        self.metric = metric
        self.agg = agg
        self.threshold = threshold

    def _validate_params(self) -> None:
        super()._validate_params()
        # StrOptions already restricts string metrics, but raise a clearer
        # message listing the allowed names.
        if isinstance(self.metric, str) and self.metric not in METRIC_FUNCTIONS:
            raise InvalidParameterError(
                f"Allowed metrics: {METRIC_NAMES}. Got: {self.metric}"
            )

    def fit(  # noqa: D102
        self,
        X: np.ndarray,
        y: np.ndarray | None = None,  # noqa: ARG002
    ):
        X = validate_data(self, X=X)
        if self.k > X.shape[0]:
            raise ValueError(
                f"k ({self.k}) must be smaller than or equal to the number of training samples ({X.shape[0]})"
            )

        self.X_train_ = X
        # For "min" aggregation only the single closest neighbor matters,
        # so querying one neighbor is sufficient regardless of k.
        self.k_used = 1 if self.agg == "min" else self.k

        if callable(self.metric):
            metric_func = self.metric
        elif isinstance(self.metric, str) and self.metric in METRIC_FUNCTIONS:
            metric_func = METRIC_FUNCTIONS[self.metric]
        else:
            raise KeyError(
                f"Unknown metric: {self.metric}. Must be a callable or one of {list(METRIC_FUNCTIONS.keys())}"
            )

        self.knn_ = NearestNeighbors(
            n_neighbors=self.k_used, metric=metric_func, n_jobs=self.n_jobs
        )
        self.knn_.fit(X)
        # Querying the training set against itself; each point's nearest
        # neighbor is itself at distance 0, which is included in aggregation.
        k_nearest, _ = self.knn_.kneighbors(X)

        agg_dists = self._get_agg_dists(k_nearest)
        # ``threshold`` is a fraction in [0, 1] (None -> default 0.95), while
        # np.percentile expects a value in [0, 100] - scale accordingly.
        quantile_pct = 95.0 if self.threshold is None else 100.0 * self.threshold
        self.threshold_ = np.percentile(agg_dists, quantile_pct)

        # Return the fitted estimator, per scikit-learn convention (and as
        # shown in the class doctest).
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:  # noqa: D102
        # Inside the applicability domain <=> aggregated distance does not
        # exceed the fitted threshold.
        return self.score_samples(X) <= self.threshold_

    def score_samples(self, X: np.ndarray) -> np.ndarray:
        """
        Calculate the applicability domain score of samples, i.e. the
        aggregated distance of each sample to its k nearest training
        neighbors (lower means closer to the training data).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data matrix.

        Returns
        -------
        scores : ndarray of shape (n_samples,)
            Applicability domain scores of samples.
        """
        check_is_fitted(self)
        X = validate_data(self, X=X, reset=False)
        k_nearest, _ = self.knn_.kneighbors(X, n_neighbors=self.k_used)

        return self._get_agg_dists(k_nearest)

    def _get_agg_dists(self, k_nearest: np.ndarray) -> np.ndarray:
        # Collapse the (n_samples, k_used) distance matrix into one score per
        # sample; ``agg`` is guaranteed by _parameter_constraints to be one of
        # "mean", "max" or "min".
        if self.agg == "mean":
            agg_dists = np.mean(k_nearest, axis=1)
        elif self.agg == "max":
            agg_dists = np.max(k_nearest, axis=1)
        else:
            agg_dists = np.min(k_nearest, axis=1)

        return agg_dists
0 commit comments