Skip to content

Commit 69259e7

Browse files
Add AD kNN (#469)
1 parent a2d7ead commit 69259e7

File tree

6 files changed

+359
-10
lines changed

6 files changed

+359
-10
lines changed

docs/modules/applicability_domain.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@ Classes for checking applicability domain.
1818
ConvexHullADChecker
1919
DistanceToCentroidADChecker
2020
HotellingT2TestADChecker
21+
KNNADChecker
2122
LeverageADChecker
2223
PCABoundingBoxADChecker

examples/09_molecular_filters.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@
773773
],
774774
"metadata": {
775775
"kernelspec": {
776-
"display_name": "Python 3 (ipykernel)",
776+
"display_name": "Python 3",
777777
"language": "python",
778778
"name": "python3"
779779
},
@@ -787,7 +787,7 @@
787787
"name": "python",
788788
"nbconvert_exporter": "python",
789789
"pygments_lexer": "ipython3",
790-
"version": "3.10.18"
790+
"version": "3.11.4"
791791
}
792792
},
793793
"nbformat": 4,

skfp/applicability_domain/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
from .convex_hull import ConvexHullADChecker
33
from .distance_to_centroid import DistanceToCentroidADChecker
44
from .hotelling_t2_test import HotellingT2TestADChecker
5+
from .knn import KNNADChecker
56
from .leverage import LeverageADChecker
67
from .pca_bounding_box import PCABoundingBoxADChecker

skfp/applicability_domain/knn.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
from collections.abc import Callable
2+
from numbers import Integral, Real
3+
4+
import numpy as np
5+
from sklearn.neighbors import NearestNeighbors
6+
from sklearn.utils._param_validation import Interval, InvalidParameterError, StrOptions
7+
from sklearn.utils.validation import check_is_fitted, validate_data
8+
9+
from skfp.bases.base_ad_checker import BaseADChecker
10+
from skfp.distances import (
11+
_BULK_METRIC_NAMES as SKFP_BULK_METRIC_NAMES,
12+
)
13+
from skfp.distances import (
14+
_BULK_METRICS as SKFP_BULK_METRICS,
15+
)
16+
from skfp.distances import (
17+
_METRIC_NAMES as SKFP_METRIC_NAMES,
18+
)
19+
from skfp.distances import (
20+
_METRICS as SKFP_METRICS,
21+
)
22+
23+
# combined name -> callable lookup, covering both element-wise and bulk
# (matrix) distance implementations from skfp.distances
METRIC_FUNCTIONS = {**SKFP_METRICS, **SKFP_BULK_METRICS}
# all metric names accepted for the ``metric`` string parameter
METRIC_NAMES = set(SKFP_METRIC_NAMES) | set(SKFP_BULK_METRIC_NAMES)
25+
26+
27+
class KNNADChecker(BaseADChecker):
    r"""
    k-nearest neighbors applicability domain checker.

    This method determines whether a query molecule falls within the applicability
    domain by comparing its distance to k nearest neighbors [1]_ [2]_ [3]_ in the
    training set, using a threshold derived from the training data.

    The applicability domain is defined as one of:

    - the mean distance to k nearest neighbors,
    - the distance to k-th nearest neighbor (max distance),
    - the distance to the closest neighbor from the training set (min distance)

    A threshold is then set at the ``threshold`` percentile (95th by default) of
    these aggregated distances. Query molecules with an aggregated distance to
    their k nearest neighbors below this threshold are considered within the
    applicability domain.

    This implementation supports binary and count Tanimoto similarity metrics.

    Parameters
    ----------
    k : int, default=1
        Number of nearest neighbors to consider for distance calculations.
        Must be smaller than or equal to the number of training samples.

    metric : Callable or string, default="tanimoto_binary_distance"
        Distance metric to use.

    agg : {"mean", "max", "min"}, default="mean"
        Aggregation method for distances to k nearest neighbors:

        - "mean": use the mean distance to k neighbors,
        - "max": use the maximum distance among k neighbors,
        - "min": use the distance to the closest neighbor.

    threshold : float, default=0.95
        Fraction in [0, 1] giving the percentile of aggregated training
        distances used as the decision threshold, e.g. 0.95 selects the
        95th percentile.

    n_jobs : int, default=None
        The number of jobs to run in parallel. :meth:`transform_x_y` and
        :meth:`transform` are parallelized over the input molecules. ``None`` means 1
        unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all
        processors. See scikit-learn documentation on ``n_jobs`` for more details.

    verbose : int or dict, default=0
        Controls the verbosity when filtering molecules.
        If a dictionary is passed, it is treated as kwargs for ``tqdm()``,
        and can be used to control the progress bar.

    References
    ----------
    .. [1] `Klingspohn, W., Mathea, M., ter Laak, A. et al.
        "Efficiency of different measures for defining the applicability domain of
        classification models."
        J Cheminform 9, 44 (2017)
        <https://doi.org/10.1186/s13321-017-0230-2>`_

    .. [2] `Harmeling, S., Dornhege G., Tax D., Meinecke F., Müller KR.
        "From outliers to prototypes: Ordering data."
        Neurocomputing, 69, 13, pages 1608-1618, (2006)
        <https://doi.org/10.1016/j.neucom.2005.05.015>`_

    .. [3] `Kar S., Roy K., Leszczynski J.
        "Applicability Domain: A Step Toward Confident Predictions and Decidability for QSAR Modeling"
        Methods Mol Biol, 1800, pages 141-169, (2018)
        <https://doi.org/10.1007/978-1-4939-7899-1_6>`_

    Examples
    --------
    >>> from skfp.applicability_domain import KNNADChecker
    >>> import numpy as np
    >>> X_train_binary = np.array([
    ...     [1, 1, 1],
    ...     [0, 1, 1],
    ...     [0, 0, 1]
    ... ])
    >>> X_test_binary = 1 - X_train_binary
    >>> knn_ad_checker_binary = KNNADChecker(k=2, metric="tanimoto_binary_distance", agg="mean")
    >>> knn_ad_checker_binary
    KNNADChecker()

    >>> knn_ad_checker_binary.fit(X_train_binary)
    KNNADChecker()

    >>> knn_ad_checker_binary.predict(X_test_binary)
    array([False, False, False])

    """

    _parameter_constraints: dict = {
        **BaseADChecker._parameter_constraints,
        "k": [Interval(Integral, 1, None, closed="left")],
        "metric": [callable, StrOptions(METRIC_NAMES)],
        "agg": [StrOptions({"mean", "max", "min"})],
        # NOTE(review): None is accepted here but fit() does not handle a
        # None threshold - verify whether None should be allowed at all
        "threshold": [None, Interval(Real, 0, 1, closed="both")],
    }

    def __init__(
        self,
        k: int = 1,
        metric: str | Callable = "tanimoto_binary_distance",
        agg: str = "mean",
        threshold: float = 0.95,
        n_jobs: int | None = None,
        verbose: int | dict = 0,
    ):
        super().__init__(
            n_jobs=n_jobs,
            verbose=verbose,
        )
        self.k = k
        self.metric = metric
        self.agg = agg
        self.threshold = threshold

    def _validate_params(self) -> None:
        super()._validate_params()
        # string metrics must name one of the skfp distance functions
        if isinstance(self.metric, str) and self.metric not in METRIC_FUNCTIONS:
            raise InvalidParameterError(
                f"Allowed metrics: {METRIC_NAMES}. Got: {self.metric}"
            )

    def fit(  # noqa: D102
        self,
        X: np.ndarray,
        y: np.ndarray | None = None,  # noqa: ARG002
    ):
        X = validate_data(self, X=X)
        if self.k > X.shape[0]:
            raise ValueError(
                f"k ({self.k}) must be smaller than or equal to the number of training samples ({X.shape[0]})"
            )

        self.X_train_ = X
        # "min" aggregation only needs the single closest neighbor
        self.k_used = 1 if self.agg == "min" else self.k

        if callable(self.metric):
            metric_func = self.metric
        elif isinstance(self.metric, str) and self.metric in METRIC_FUNCTIONS:
            metric_func = METRIC_FUNCTIONS[self.metric]
        else:
            raise KeyError(
                f"Unknown metric: {self.metric}. Must be a callable or one of {list(METRIC_FUNCTIONS.keys())}"
            )

        self.knn_ = NearestNeighbors(
            n_neighbors=self.k_used, metric=metric_func, n_jobs=self.n_jobs
        )
        self.knn_.fit(X)
        k_nearest, _ = self.knn_.kneighbors(X)

        agg_dists = self._get_agg_dists(k_nearest)
        # np.percentile expects q in [0, 100], while self.threshold is a
        # fraction in [0, 1] - scale it, otherwise threshold=0.95 selects
        # the 0.95th percentile instead of the documented 95th
        self.threshold_ = np.percentile(agg_dists, 100 * self.threshold)

        # scikit-learn convention: fit() returns the fitted estimator
        # (this also makes the doctest in the class docstring pass)
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:  # noqa: D102
        # a sample is inside the AD when its aggregated neighbor distance
        # does not exceed the percentile threshold learned in fit()
        return self.score_samples(X) <= self.threshold_

    def score_samples(self, X: np.ndarray) -> np.ndarray:
        """
        Calculate the applicability domain score of samples, i.e. the
        aggregated (mean/max/min, depending on ``agg``) distance of each
        sample to its k nearest training neighbors. Lower scores mean the
        sample is closer to the training data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data matrix.

        Returns
        -------
        scores : ndarray of shape (n_samples,)
            Applicability domain scores of samples.
        """
        check_is_fitted(self)
        X = validate_data(self, X=X, reset=False)
        k_nearest, _ = self.knn_.kneighbors(X, n_neighbors=self.k_used)

        return self._get_agg_dists(k_nearest)

    def _get_agg_dists(self, k_nearest) -> np.ndarray[float]:
        # aggregate the (n_samples, k_used) distance matrix row-wise;
        # self.agg is guaranteed valid by _parameter_constraints
        if self.agg == "mean":
            agg_dists = np.mean(k_nearest, axis=1)
        elif self.agg == "max":
            agg_dists = np.max(k_nearest, axis=1)
        elif self.agg == "min":
            agg_dists = np.min(k_nearest, axis=1)

        return agg_dists

tests/applicability_domain/knn.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import numpy as np
2+
import pytest
3+
4+
from skfp.applicability_domain import KNNADChecker
5+
from tests.applicability_domain.utils import get_data_inside_ad, get_data_outside_ad
6+
7+
ALLOWED_METRICS = [
8+
"tanimoto_binary_distance",
9+
"tanimoto_count_distance",
10+
]
11+
12+
ALLOWED_AGGS = ["mean", "max", "min"]
13+
14+
15+
@pytest.mark.parametrize("metric", ALLOWED_METRICS)
@pytest.mark.parametrize("agg", ALLOWED_AGGS)
def test_inside_knn_ad(metric, agg):
    # samples drawn from the training distribution should all be in-domain
    if "binary" in metric:
        X_train, X_test = get_data_inside_ad(binarize=True)
    else:
        X_train, X_test = get_data_inside_ad()

    # bug fix: metric was parametrized but never passed to the checker,
    # so every variant silently tested the default metric
    ad_checker = KNNADChecker(k=3, metric=metric, agg=agg)
    ad_checker.fit(X_train)

    scores = ad_checker.score_samples(X_test)
    assert scores.shape == (len(X_test),)

    preds = ad_checker.predict(X_test)
    assert isinstance(preds, np.ndarray)
    assert np.issubdtype(preds.dtype, np.bool_)
    assert preds.shape == (len(X_test),)
    assert np.all(preds == 1)
34+
35+
36+
@pytest.mark.parametrize("metric", ALLOWED_METRICS)
@pytest.mark.parametrize("agg", ALLOWED_AGGS)
def test_outside_knn_ad(metric, agg):
    # samples far from the training distribution should all be out-of-domain
    if "binary" in metric:
        X_train, X_test = get_data_outside_ad(binarize=True)
    else:
        X_train, X_test = get_data_outside_ad()

    ad_checker = KNNADChecker(k=3, metric=metric, agg=agg)
    ad_checker.fit(X_train)

    scores = ad_checker.score_samples(X_test)
    assert np.all(scores >= 0)

    # removed leftover debug print(preds)
    preds = ad_checker.predict(X_test)
    assert isinstance(preds, np.ndarray)
    assert np.issubdtype(preds.dtype, np.bool_)
    assert preds.shape == (len(X_test),)
    assert np.all(preds == 0)
56+
57+
58+
@pytest.mark.parametrize("metric", ALLOWED_METRICS)
@pytest.mark.parametrize("agg", ALLOWED_AGGS)
def test_knn_different_k_values(metric, agg):
    # smoke test: both a strict (k=1) and a looser (k=5) checker should
    # fit and predict without errors and yield an integer pass count
    if "binary" in metric:
        X_train, X_test = get_data_inside_ad(binarize=True)
    else:
        X_train, X_test = get_data_inside_ad()

    pass_counts = {}
    for k in (1, 5):
        checker = KNNADChecker(k=k, metric=metric, agg=agg)
        checker.fit(X_train)
        pass_counts[k] = checker.predict(X_test).sum()

    # both should be valid results
    assert isinstance(pass_counts[1], (int, np.integer))
    assert isinstance(pass_counts[5], (int, np.integer))
79+
80+
81+
@pytest.mark.parametrize("metric", ALLOWED_METRICS)
@pytest.mark.parametrize("agg", ALLOWED_AGGS)
def test_knn_pass_y_train(metric, agg):
    # smoke test: providing labels to fit() must not raise, even though
    # the checker ignores them
    if "binary" in metric:
        X_train, _ = get_data_inside_ad(binarize=True)
    else:
        X_train, _ = get_data_inside_ad()

    labels = np.zeros(len(X_train))
    checker = KNNADChecker(k=3, metric=metric, agg=agg)
    checker.fit(X_train, labels)
93+
94+
95+
@pytest.mark.parametrize("metric", ALLOWED_METRICS)
@pytest.mark.parametrize("agg", ALLOWED_AGGS)
def test_knn_invalid_k(metric, agg):
    # k larger than the training set size must be rejected at fit time
    if "binary" in metric:
        X_train, _ = get_data_inside_ad(binarize=True)
    else:
        X_train, _ = get_data_inside_ad()

    too_large_k = len(X_train) + 1
    err_msg = (
        r"k \(\d+\) must be smaller than or equal to "
        r"the number of training samples \(\d+\)"
    )
    with pytest.raises(ValueError, match=err_msg):
        KNNADChecker(k=too_large_k, metric=metric, agg=agg).fit(X_train)
109+
110+
111+
def test_knn_invalid_metric():
    # a metric name outside the skfp distance registry raises KeyError
    X_train, _ = get_data_inside_ad()
    checker = KNNADChecker(k=3, metric="euclidean")
    with pytest.raises(KeyError):
        checker.fit(X_train)

0 commit comments

Comments
 (0)