Skip to content

Commit b91ed57

Browse files
authored
Optimize tests speed (#495)
1 parent b7eb38d commit b91ed57

File tree

2 files changed

+13
-13
lines changed
  • skfp/applicability_domain
  • tests/applicability_domain

2 files changed

+13
-13
lines changed

skfp/applicability_domain/knn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,12 @@ def score_samples(self, X: np.ndarray) -> np.ndarray:
199199

200200
return self._get_agg_dists(k_nearest)
201201

202-
def _get_agg_dists(self, k_nearest) -> np.ndarray[float]:
202+
def _get_agg_dists(self, k_nearest) -> np.ndarray:
203203
if self.agg == "mean":
204204
agg_dists = np.mean(k_nearest, axis=1)
205205
elif self.agg == "max":
206206
agg_dists = np.max(k_nearest, axis=1)
207-
elif self.agg == "min":
207+
else: # "min"
208208
agg_dists = np.min(k_nearest, axis=1)
209209

210210
return agg_dists

tests/applicability_domain/knn.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
@pytest.mark.parametrize("agg", ALLOWED_AGGREGATIONS)
1818
def test_inside_knn_ad(metric, agg):
1919
if "binary" in metric:
20-
X_train, X_test = get_data_inside_ad(binarize=True)
20+
X_train, X_test = get_data_inside_ad(n_train=100, n_test=10, binarize=True)
2121
else:
22-
X_train, X_test = get_data_inside_ad()
22+
X_train, X_test = get_data_inside_ad(n_train=100, n_test=10)
2323

2424
ad_checker = KNNADChecker(k=3, agg=agg)
2525
ad_checker.fit(X_train)
@@ -39,9 +39,9 @@ def test_inside_knn_ad(metric, agg):
3939
@pytest.mark.parametrize("agg", ALLOWED_AGGREGATIONS)
4040
def test_outside_knn_ad(metric, agg):
4141
if "binary" in metric:
42-
X_train, X_test = get_data_outside_ad(binarize=True)
42+
X_train, X_test = get_data_outside_ad(n_train=100, n_test=10, binarize=True)
4343
else:
44-
X_train, X_test = get_data_outside_ad()
44+
X_train, X_test = get_data_outside_ad(n_train=100, n_test=10)
4545

4646
ad_checker = KNNADChecker(k=3, metric=metric, agg=agg)
4747
ad_checker.fit(X_train)
@@ -60,9 +60,9 @@ def test_outside_knn_ad(metric, agg):
6060
@pytest.mark.parametrize("agg", ALLOWED_AGGREGATIONS)
6161
def test_knn_different_k_values(metric, agg):
6262
if "binary" in metric:
63-
X_train, X_test = get_data_inside_ad(binarize=True)
63+
X_train, X_test = get_data_inside_ad(n_train=100, n_test=10, binarize=True)
6464
else:
65-
X_train, X_test = get_data_inside_ad()
65+
X_train, X_test = get_data_inside_ad(n_train=100, n_test=10)
6666

6767
# smaller k, stricter check
6868
ad_checker_k1 = KNNADChecker(k=1, metric=metric, agg=agg)
@@ -84,9 +84,9 @@ def test_knn_different_k_values(metric, agg):
8484
def test_knn_pass_y_train(metric, agg):
8585
# smoke test, should not throw errors
8686
if "binary" in metric:
87-
X_train, _ = get_data_inside_ad(binarize=True)
87+
X_train, _ = get_data_inside_ad(n_train=100, n_test=10, binarize=True)
8888
else:
89-
X_train, _ = get_data_inside_ad()
89+
X_train, _ = get_data_inside_ad(n_train=100, n_test=10)
9090

9191
y_train = np.zeros(len(X_train))
9292
ad_checker = KNNADChecker(k=3, metric=metric, agg=agg)
@@ -97,9 +97,9 @@ def test_knn_pass_y_train(metric, agg):
9797
@pytest.mark.parametrize("agg", ALLOWED_AGGREGATIONS)
9898
def test_knn_invalid_k(metric, agg):
9999
if "binary" in metric:
100-
X_train, _ = get_data_inside_ad(binarize=True)
100+
X_train, _ = get_data_inside_ad(n_train=100, n_test=10, binarize=True)
101101
else:
102-
X_train, _ = get_data_inside_ad()
102+
X_train, _ = get_data_inside_ad(n_train=100, n_test=10)
103103

104104
with pytest.raises(
105105
ValueError,
@@ -110,7 +110,7 @@ def test_knn_invalid_k(metric, agg):
110110

111111

112112
def test_knn_invalid_metric():
113-
X_train, _ = get_data_inside_ad()
113+
X_train, _ = get_data_inside_ad(n_train=100)
114114
ad_checker = KNNADChecker(k=3, metric="euclidean")
115115
with pytest.raises(KeyError):
116116
ad_checker.fit(X_train)

0 commit comments

Comments
 (0)