Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions skfp/applicability_domain/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,12 @@ def score_samples(self, X: np.ndarray) -> np.ndarray:

return self._get_agg_dists(k_nearest)

def _get_agg_dists(self, k_nearest) -> np.ndarray[float]:
def _get_agg_dists(self, k_nearest) -> np.ndarray:
if self.agg == "mean":
agg_dists = np.mean(k_nearest, axis=1)
elif self.agg == "max":
agg_dists = np.max(k_nearest, axis=1)
elif self.agg == "min":
else: # "min"
agg_dists = np.min(k_nearest, axis=1)

return agg_dists
22 changes: 11 additions & 11 deletions tests/applicability_domain/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
@pytest.mark.parametrize("agg", ALLOWED_AGGREGATIONS)
def test_inside_knn_ad(metric, agg):
if "binary" in metric:
X_train, X_test = get_data_inside_ad(binarize=True)
X_train, X_test = get_data_inside_ad(n_train=100, n_test=10, binarize=True)
else:
X_train, X_test = get_data_inside_ad()
X_train, X_test = get_data_inside_ad(n_train=100, n_test=10)

ad_checker = KNNADChecker(k=3, agg=agg)
ad_checker.fit(X_train)
Expand All @@ -39,9 +39,9 @@ def test_inside_knn_ad(metric, agg):
@pytest.mark.parametrize("agg", ALLOWED_AGGREGATIONS)
def test_outside_knn_ad(metric, agg):
if "binary" in metric:
X_train, X_test = get_data_outside_ad(binarize=True)
X_train, X_test = get_data_outside_ad(n_train=100, n_test=10, binarize=True)
else:
X_train, X_test = get_data_outside_ad()
X_train, X_test = get_data_outside_ad(n_train=100, n_test=10)

ad_checker = KNNADChecker(k=3, metric=metric, agg=agg)
ad_checker.fit(X_train)
Expand All @@ -60,9 +60,9 @@ def test_outside_knn_ad(metric, agg):
@pytest.mark.parametrize("agg", ALLOWED_AGGREGATIONS)
def test_knn_different_k_values(metric, agg):
if "binary" in metric:
X_train, X_test = get_data_inside_ad(binarize=True)
X_train, X_test = get_data_inside_ad(n_train=100, n_test=10, binarize=True)
else:
X_train, X_test = get_data_inside_ad()
X_train, X_test = get_data_inside_ad(n_train=100, n_test=10)

# smaller k, stricter check
ad_checker_k1 = KNNADChecker(k=1, metric=metric, agg=agg)
Expand All @@ -84,9 +84,9 @@ def test_knn_different_k_values(metric, agg):
def test_knn_pass_y_train(metric, agg):
# smoke test, should not throw errors
if "binary" in metric:
X_train, _ = get_data_inside_ad(binarize=True)
X_train, _ = get_data_inside_ad(n_train=100, n_test=10, binarize=True)
else:
X_train, _ = get_data_inside_ad()
X_train, _ = get_data_inside_ad(n_train=100, n_test=10)

y_train = np.zeros(len(X_train))
ad_checker = KNNADChecker(k=3, metric=metric, agg=agg)
Expand All @@ -97,9 +97,9 @@ def test_knn_pass_y_train(metric, agg):
@pytest.mark.parametrize("agg", ALLOWED_AGGREGATIONS)
def test_knn_invalid_k(metric, agg):
if "binary" in metric:
X_train, _ = get_data_inside_ad(binarize=True)
X_train, _ = get_data_inside_ad(n_train=100, n_test=10, binarize=True)
else:
X_train, _ = get_data_inside_ad()
X_train, _ = get_data_inside_ad(n_train=100, n_test=10)

with pytest.raises(
ValueError,
Expand All @@ -110,7 +110,7 @@ def test_knn_invalid_k(metric, agg):


def test_knn_invalid_metric():
X_train, _ = get_data_inside_ad()
X_train, _ = get_data_inside_ad(n_train=100)
ad_checker = KNNADChecker(k=3, metric="euclidean")
with pytest.raises(KeyError):
ad_checker.fit(X_train)
Loading