66 commits
46864d4
feat: copy nn explainers from shapiq_student
Zaphoood Dec 16, 2025
50f6dea
feat: integrate nn explainers into dynamic explainer dispatch
Zaphoood Dec 16, 2025
f118f5f
feat: remove unused property `mode` from nn explainers
Zaphoood Dec 16, 2025
ae25472
feat: clean up NN explainer base class
Zaphoood Dec 16, 2025
b22cb71
refactor: turn BruteForceKNNExplainer into benchmark game
Zaphoood Dec 16, 2025
c6ccb68
refactor: factor out base class for nn explainer benchmarks
Zaphoood Dec 16, 2025
f4dfaf1
refactor: delete unused brute force knn explainer
Zaphoood Dec 16, 2025
df87ad3
refactor: turn BruteForceWKNNExplainer into benchmark game
Zaphoood Dec 16, 2025
cfe489c
feat: merge WeightedKNNExplainer with its base class
Zaphoood Dec 16, 2025
c262160
docs: improve wording in WeightedKNNExplainer docstring
Zaphoood Dec 16, 2025
48dc376
refactor: clean up nn benchmark games
Zaphoood Dec 16, 2025
72af94c
refactor: delete unused brute force tnn explainer
Zaphoood Dec 16, 2025
0807aca
refactor: delete unused lookup game
Zaphoood Dec 16, 2025
b47a618
refactor: merge CommonKNNExplainer into its subclasses
Zaphoood Dec 16, 2025
7f19abe
feat: check that index and order are valid for nn explainers
Zaphoood Dec 16, 2025
4ef1ee3
refactor: remove custom exception classes
Zaphoood Dec 16, 2025
beeec10
refactor: clean up nn explainers' constructors
Zaphoood Dec 16, 2025
c9c0554
fix: set normalization value in NN benchmark game base
Zaphoood Jan 6, 2026
f972489
fix: sort coalition in KNN benchmark game
Zaphoood Jan 6, 2026
e8a2771
tests: add unit test for KNN explainer
Zaphoood Jan 6, 2026
99acebe
tests: add unit test for WKNN explainer
Zaphoood Jan 6, 2026
62ef64e
refactor: rename WKNN test function
Zaphoood Jan 6, 2026
c02d111
feat: add TNN benchmark game
Zaphoood Jan 6, 2026
5f2db46
fix: utility of coalition with no points in radius of threshold nn cl…
Zaphoood Jan 6, 2026
f9a11c9
tests: add unit test for TNN explainer
Zaphoood Jan 6, 2026
0f8fb6a
tests: test NN explainers with all train points instead of just 1
Zaphoood Jan 6, 2026
1ef70b0
fix: handle case N < k in KNN benchmark and add test
Zaphoood Jan 6, 2026
6bb3cd5
feat: handle access to sklearn private members gracefully
Zaphoood Jan 6, 2026
10595f6
fix: remove kwarg `class_index` from KNN's explain_function
Zaphoood Jan 6, 2026
966a160
refactor: factor out InteractionValues from/to array helpers
Zaphoood Jan 6, 2026
159b53f
feat: add notebook for NN explainers
Zaphoood Jan 6, 2026
37dc007
chore: remove obsolete TODO
Zaphoood Jan 6, 2026
c164735
tests: use randomly generated test points for testing NN explainers
Zaphoood Jan 6, 2026
e4521d3
feat: use footnote citations in docstrings
Zaphoood Jan 7, 2026
c61dbb8
feat: improve NN explainer base class docstring
Zaphoood Jan 7, 2026
ec26df9
fix: unify spelling of 'nearest neighbor' without hyphen
Zaphoood Jan 7, 2026
c1e9af9
feat: improve wording in notebook
Zaphoood Jan 7, 2026
de69010
document new explainers in changelog
Zaphoood Jan 7, 2026
fd0aca2
improve wording in changelog
Zaphoood Jan 7, 2026
3b8b420
fix: unterminated f-string literal
Zaphoood Jan 7, 2026
4c7388b
fix: execute data valuation notebook
Zaphoood Jan 7, 2026
4560c63
fix: name tests correctly to make them discoverable
Zaphoood Jan 7, 2026
e470e92
refactor: place nn explainer benchmark games alongside efficient expl…
Zaphoood Jan 14, 2026
9556d2c
refactor: parametrize nn explainer unit tests
Zaphoood Jan 14, 2026
9bf4d34
tests: tests error handling of nn explainer base class
Zaphoood Jan 14, 2026
47431fa
feat: test knn explainer error handling
Zaphoood Jan 19, 2026
1c7ee1b
tests: increase radius of TNN classifier
Zaphoood Jan 19, 2026
0250f72
tests: test error handling of tnn classifier
Zaphoood Jan 19, 2026
3db23d1
chore: delete duplicate function
Zaphoood Jan 19, 2026
9e45dd5
feat: add more data checks to nn explainer games
Zaphoood Jan 19, 2026
f59333e
tests: test nn explainer game base
Zaphoood Jan 19, 2026
0d72436
tests: test index/max_order verification util
Zaphoood Jan 20, 2026
c78092e
tests: test warning for ignored parameters
Zaphoood Jan 20, 2026
895c344
tests: add case for automatic dispatch to nn explainers
Zaphoood Jan 20, 2026
2030858
feat: add stress test for knn explainer to notebook
Zaphoood Jan 20, 2026
5e6ca07
refactor: rename binary weighted knn explainer game
Zaphoood Jan 20, 2026
e2daf3e
refactor: add conversion from/to array into InteractionValues object
Zaphoood Jan 20, 2026
061792a
fix: value of empty coalition of tnn game
Zaphoood Jan 20, 2026
fec0561
feat: replace binom fraction with product
Zaphoood Jan 21, 2026
0fe60c0
feat: optimize tnn explainer
Zaphoood Jan 21, 2026
fb637db
refactor: rename knn data valuation notebook
Zaphoood Jan 21, 2026
5cd22d5
fix: re-run knn notebook
Zaphoood Jan 21, 2026
834a255
feat: demonstrate performance of wknn, tnn explainers in notebook
Zaphoood Jan 21, 2026
94662da
chore: remove first order index check when creating iv from array
Zaphoood Jan 21, 2026
a542d4d
tests: add unit tests from iv from/to array
Zaphoood Jan 21, 2026
d919349
refactor: require setting `class_index` for nn explainers
Zaphoood Jan 21, 2026
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,10 @@
# Changelog

### Introducing Explainers for Nearest Neighbor Models

Adds three new explainers, namely `KNNExplainer`, `WeightedKNNExplainer` and `ThresholdNNExplainer`, which efficiently compute explanations for nearest neighbor models from the [scikit-learn](https://scikit-learn.org/stable/) library.
One application of these explainers is Data Valuation, i.e. the task of evaluating the usefulness of training data points for training models.
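
To make the data valuation use case concrete, here is a minimal brute-force sketch in plain Python. It is not the efficient algorithm the new explainers implement, and every name in it (`knn_utility`, `shapley_values`, the toy data) is hypothetical. It assumes the utility of a subset of training points is the fraction of the `k` nearest subset members whose label matches the test label, and it averages marginal contributions over all orderings.

```python
from itertools import permutations


def knn_utility(coalition, train, test_point, k):
    """Fraction of the k nearest coalition members whose label matches the test label."""
    x_test, y_test = test_point
    # Sort the coalition by distance to the test point and keep at most k members.
    nearest = sorted(coalition, key=lambda i: abs(train[i][0] - x_test))[:k]
    return sum(train[i][1] == y_test for i in nearest) / k


def shapley_values(train, test_point, k):
    """Exact Shapley values by enumerating every ordering (only feasible for tiny n)."""
    n = len(train)
    values = [0.0] * n
    orders = list(permutations(range(n)))
    for order in orders:
        members = []
        previous = 0.0  # utility of the empty coalition
        for player in order:
            members.append(player)
            current = knn_utility(members, train, test_point, k)
            values[player] += current - previous
            previous = current
    return [v / len(orders) for v in values]


# Toy 1D data set: (feature, label). Hypothetical values chosen for illustration.
train = [(0.1, 1), (0.4, 0), (0.9, 1)]
test_point = (0.0, 1)
values = shapley_values(train, test_point, k=1)
```

In this toy setup the nearest, correctly labeled point receives the largest value, while the mislabeled middle point receives a negative one. The enumeration grows factorially with the number of training points, which is why closed-form approaches such as those in the cited papers matter in practice.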

## v1.4.1 (2025-11-10)

### Bugfix
506 changes: 506 additions & 0 deletions docs/source/notebooks/tabular_notebooks/knn_data_valuation.ipynb

Large diffs are not rendered by default.

74 changes: 74 additions & 0 deletions docs/source/notebooks/tabular_notebooks/plot_helpers.py
@@ -0,0 +1,74 @@
"""Helper functions for plotting to be used by the notebooks."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt
    from matplotlib.axes import Axes

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


def plot_datasets(
    ax: Axes,
    X_train: npt.NDArray[np.floating],
    y_train: npt.NDArray[np.integer],
    X_test: npt.NDArray[np.floating] | None = None,
    y_test: npt.NDArray[np.integer] | None = None,
    title: str | None = None,
) -> None:
    """Plots train and test datasets in the same figure."""
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    if title is not None:
        ax.set_title(title)
    ax.scatter(
        X_train[:, 0],
        X_train[:, 1],
        c=[colors[i] for i in y_train],
        label="Training Points",
        marker="o",
    )
    if X_test is not None and y_test is not None:
        ax.scatter(
            X_test[:, 0],
            X_test[:, 1],
            c=[colors[i] for i in y_test],
            label="Test Points",
            marker="x",
        )

    handles = [
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            markerfacecolor=colors[i],
            markersize=10,
            label=f"Class {i} (Train)",
        )
        for i in set(y_train)
    ]
    if y_test is not None:
        handles += [
            Line2D(
                [0],
                [0],
                marker="x",
                linewidth=0,
                color=colors[i],
                markerfacecolor=colors[i],
                markersize=10,
                label=f"Class {i} (Test)",
            )
            for i in set(y_test)
        ]
    ax.legend(handles=handles, loc="upper right", title="Data Points")

    ax.set_xlabel("Feature 1")
    ax.set_ylabel("Feature 2")
25 changes: 25 additions & 0 deletions docs/source/references.bib
@@ -267,3 +267,28 @@ @inproceedings{Yu.2022
booktitle = {Advances in Neural Information Processing Systems 35: Annual Conference on Neural Information Processing Systems 2022 ({NeurIPS} 2022)},
url = {http://papers.nips.cc/paper\_files/paper/2022/hash/a5a3b1ef79520b7cd122d888673a3ebc-Abstract-Conference.html}
}
@article{Jia.2019,
title = {Efficient task-specific data valuation for nearest neighbor algorithms},
author = {Jia, Ruoxi and Dao, David and Wang, Boxin and Hubis, Frances Ann and Gurel, Nezihe Merve and Li, Bo and Zhang, Ce and Spanos, Costas J and Song, Dawn},
journal = {arXiv preprint arXiv:1908.08619},
year = {2019},
url = {https://doi.org/10.48550/arXiv.1908.08619}
}
@article{Wang.2023,
title = {A privacy-friendly approach to data valuation},
author = {Wang, Jiachen Tianhao and Zhu, Yuqing and Wang, Yu-Xiang and Jia, Ruoxi and Mittal, Prateek},
journal = {Advances in Neural Information Processing Systems},
volume = {36},
pages = {60429--60467},
year = {2023},
url = {https://arxiv.org/abs/2308.15709}
}
@inproceedings{Wang.2024,
title = {Efficient data shapley for weighted nearest neighbor algorithms},
author = {Wang, Jiachen T and Mittal, Prateek and Jia, Ruoxi},
booktitle = {International Conference on Artificial Intelligence and Statistics},
pages = {2557--2565},
year = {2024},
organization = {PMLR},
url = {https://doi.org/10.48550/arXiv.2401.11103}
}
11 changes: 11 additions & 0 deletions src/shapiq/explainer/nn/__init__.py
@@ -0,0 +1,11 @@
"""Explainers for nearest neighbor models."""

from .knn import KNNExplainer
from .threshold_nn import ThresholdNNExplainer
from .weighted_knn import WeightedKNNExplainer

__all__ = [
    "KNNExplainer",
    "ThresholdNNExplainer",
    "WeightedKNNExplainer",
]
47 changes: 47 additions & 0 deletions src/shapiq/explainer/nn/_util.py
@@ -0,0 +1,47 @@
"""Utility functions for the nearest neighbor explainers."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from collections.abc import Iterable, Mapping

    from shapiq.explainer.custom_types import ExplainerIndices

logger = logging.getLogger()


def warn_ignored_parameters(
    local_vars: Mapping[str, Any], ignored_parameter_names: Iterable[str], class_name: str
) -> None:
    for param in ignored_parameter_names:
        if local_vars[param] is not None:
            logger.warning(
                "A non-None value was passed as parameter `%s` to the constructor of %s, which will be ignored.",
                param,
                class_name,
            )


def assert_valid_index_and_order(index: ExplainerIndices, max_order: int) -> None:
    """Check that the explainer index and max_order are valid for NN models, raise otherwise.

    The only valid indices are ``'SV'`` and ``'k-SII'``; the only valid max. order is ``1``.

    Args:
        index: The explainer index to validate.
        max_order: The max. order to validate.

    Raises:
        ValueError: If either of the parameters does not satisfy the requirements.
    """
    valid_indices: list[ExplainerIndices] = ["SV", "k-SII"]
    if index not in valid_indices:
        msg = f"Explainer index '{index}' is invalid for nearest neighbor models. Valid indices are: {', '.join(valid_indices)}"
        raise ValueError(msg)

    if max_order != 1:
        msg = f"Explanation order of {max_order} is invalid; the only valid order is 1."
        raise ValueError(msg)
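
The warning helper above relies on positional `%s` substitution in the standard `logging` module, so the order of the trailing arguments decides which name lands in which slot. The following sketch, using only the standard library, shows that behavior. The logger name, the `ListHandler` class, the shortened message and the `warn_ignored` function are stand-ins invented for this illustration and are not part of shapiq.

```python
import logging


class ListHandler(logging.Handler):
    """Collects formatted log messages in a list so they can be inspected."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        # getMessage() applies the %-style arguments to the format string.
        self.messages.append(record.getMessage())


logger = logging.getLogger("nn_explainer_demo")  # hypothetical logger name
logger.setLevel(logging.WARNING)
logger.propagate = False
handler = ListHandler()
logger.addHandler(handler)


def warn_ignored(local_vars, ignored_names, class_name):
    """Warn once per ignored parameter that was given a non-None value."""
    for param in ignored_names:
        if local_vars[param] is not None:
            # First %s receives the parameter name, second %s the class name.
            logger.warning("Parameter `%s` passed to %s will be ignored.", param, class_name)


warn_ignored({"data": [1, 2], "approximator": None}, ["data", "approximator"], "KNNExplainer")
```

Only `data` triggers a warning here because `approximator` is `None`. Swapping the two trailing arguments would still log without error but would report the class name as the parameter.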
100 changes: 100 additions & 0 deletions src/shapiq/explainer/nn/base.py
@@ -0,0 +1,100 @@
"""Implementation of the base class for nearest neighbor explainers."""

from __future__ import annotations

from typing import TYPE_CHECKING, cast

import numpy as np
from sklearn.utils.validation import check_is_fitted

from shapiq import Explainer

if TYPE_CHECKING:
    import numpy.typing as npt
    from sklearn.neighbors import KNeighborsClassifier, RadiusNeighborsClassifier


class NNExplainerBase(Explainer):
    """Base class for nearest neighbor Explainers."""

    X_train: npt.NDArray[np.floating]
    """Training data features extracted from the model."""

    y_train_indices: npt.NDArray[np.integer]
    """Training data labels as indices into the classes array."""

    y_train_classes: npt.NDArray[np.object_]
    """Class labels from the model's training data."""

    def __init__(
        self,
        model: KNeighborsClassifier | RadiusNeighborsClassifier,
        class_index: int | None = None,
    ) -> None:
        """Initializes the class.

        Args:
            model: The NN model to explain. Must be an instance of ``sklearn.neighbors.KNeighborsClassifier`` or ``sklearn.neighbors.RadiusNeighborsClassifier``.
                The model must not use multi-output classification, i.e. the ``y`` value provided to ``model.fit(X, y)`` must be a 1D vector.
            class_index: The class index of the model to explain. Note that, as opposed to most Explainers, this must not be ``None``!

        Raises:
            sklearn.exceptions.NotFittedError: The constructor was called with a model that hasn't been fitted.
        """
        super().__init__(model, data=None, class_index=class_index, index="SV", max_order=1)
        check_is_fitted(model)

        if class_index is None:
            msg = (
                "Nearest neighbor explainers require setting class_index explicitly. Please pass a value to "
                "class_index in the constructor."
            )
            raise ValueError(msg)
        self.class_index = class_index

        X_train = _sklearn_model_get_private_attribute(model, "_fit_X")
        if not isinstance(X_train, np.ndarray):
            msg = f"Expected model's training features (model._fit_X) to be numpy array but got {type(X_train)}"
            raise TypeError(msg)
        if not (
            np.issubdtype(X_train.dtype, np.floating) or np.issubdtype(X_train.dtype, np.integer)
        ):
            msg = f"Expected dtype of model's training features (model._fit_X) to be a subtype of np.floating or np.integer, but got {X_train.dtype}"
            raise TypeError(msg)
        if np.issubdtype(X_train.dtype, np.integer):
            X_train = X_train.astype(np.float32)
        self.X_train = X_train

        y_train_indices = _sklearn_model_get_private_attribute(model, "_y")
        if not isinstance(y_train_indices, np.ndarray):
            msg = f"Expected model's training class indices (model._y) to be numpy array but got {type(y_train_indices)}"
            raise TypeError(msg)
        if not np.issubdtype(y_train_indices.dtype, np.integer):
            msg = f"Expected dtype of model's training class indices (model._y) to be a subtype of np.integer, but got {y_train_indices.dtype}"
            raise TypeError(msg)
        if y_train_indices.ndim != 1:
            msg = "Multi-output nearest neighbor classifiers are not supported. Make sure to pass the training labels as a 1D vector when calling `model.fit()`."
            raise ValueError(msg)
        self.y_train_indices = y_train_indices

        if not isinstance(model.classes_, np.ndarray):
            msg = f"Expected model's training classes (model.classes_) to be numpy array but got {type(model.classes_)}"
            raise TypeError(msg)
        self.y_train_classes = cast("npt.NDArray[np.object_]", model.classes_)


def _sklearn_model_get_private_attribute(
    model: KNeighborsClassifier | RadiusNeighborsClassifier, attribute: str
) -> object:
    if not attribute.startswith("_"):
        msg = f"Name of private attribute must start with underscore, but got '{attribute}'"
        raise ValueError(msg)

    try:
        return getattr(model, attribute)
    except AttributeError as e:
        msg = (
            f"Failed to access private attribute '{attribute}' of sklearn model. This may be caused by a change to the "
            "implementation of the sklearn library. Please report this problem at https://github.com/mmschlk/shapiq/issues"
        )
        raise AttributeError(msg) from e
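
The base class stores training labels as integer indices into a classes array, and `class_index` selects one entry of that array. The plain-Python sketch below illustrates this relationship on made-up labels. It assumes the classes are the sorted unique labels, which is how scikit-learn's `classes_` attribute is documented to be ordered. None of the variable names come from the PR.

```python
# Hypothetical raw training labels, as a user would pass them to model.fit(X, y).
labels = ["dog", "cat", "dog", "bird"]

# Sorted unique labels play the role of the classes array.
classes = sorted(set(labels))

# Each training label is replaced by its position in the classes array.
label_indices = [classes.index(label) for label in labels]

# class_index picks the class to explain; here the class "dog".
class_index = classes.index("dog")

# Boolean mask marking the training points that belong to the explained class.
matches = [idx == class_index for idx in label_indices]
```

Requiring an explicit `class_index` avoids guessing which class the explanation refers to, since the index only has meaning relative to the ordering of the classes array.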
1 change: 1 addition & 0 deletions src/shapiq/explainer/nn/games/__init__.py
@@ -0,0 +1 @@
"""Contains benchmark games for nearest neighbor explainers."""
39 changes: 39 additions & 0 deletions src/shapiq/explainer/nn/games/_util.py
@@ -0,0 +1,39 @@
from __future__ import annotations

from typing import TYPE_CHECKING, cast

import numpy as np

if TYPE_CHECKING:
    import numpy.typing as npt


def keep_first_n(mask: npt.NDArray[np.bool], n: int) -> npt.NDArray[np.bool]:
    """Sets all entries of the input array to False except the first ``n`` entries with value ``True``.

    This returns the input array itself (not a copy) if ``np.sum(mask) < n``; otherwise a new array is returned.

    Args:
        mask: The mask in question.
        n: The maximum number of true entries.
    """
    if n == 0:
        return np.zeros_like(mask)

    n_true = 0
    for i, val in enumerate(mask):
        n_true += int(val)
        if n_true == n:
            out = np.zeros_like(mask)
            out[: i + 1] = mask[: i + 1]
            return out

    return mask


def _greater_or_close(a: np.floating, b: np.floating) -> np.bool:
    """Returns ``a >= b`` but allows for floating point error.

    That is, if ``a < b`` but ``np.isclose(a, b)``, ``True`` will be returned.
    """
    return cast("np.bool", a >= b) or np.isclose(a, b)
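
The two helpers above can be tried out without NumPy. The sketch below re-implements them on plain Python lists and floats purely for illustration. It uses `math.isclose` with tolerances chosen to resemble NumPy's defaults, which is an assumption and not an exact equivalent of `np.isclose`.

```python
import math


def keep_first_n(mask, n):
    """Keep only the first n True entries of a list of booleans."""
    if n == 0:
        return [False] * len(mask)

    n_true = 0
    for i, val in enumerate(mask):
        n_true += int(val)
        if n_true == n:
            # Copy everything up to the n-th True entry, clear the rest.
            return mask[: i + 1] + [False] * (len(mask) - i - 1)

    # Fewer than n True entries: the input is returned unchanged (same object).
    return mask


def greater_or_close(a, b):
    """Return a >= b, treating values within floating point tolerance as equal."""
    return a >= b or math.isclose(a, b, rel_tol=1e-5, abs_tol=1e-8)


mask = [True, False, True, True, False, True]
trimmed = keep_first_n(mask, 2)

short = [True, False]
unchanged = keep_first_n(short, 3)

# 0.3 is slightly smaller than 0.1 + 0.2 in binary floating point.
tolerant = greater_or_close(0.3, 0.1 + 0.2)
strict = 0.3 >= 0.1 + 0.2
```

The tolerant comparison matters whenever a value computed along two different paths is compared against itself, where a strict `>=` can fail by one rounding error.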