Skip to content

Commit ae82d9d

Browse files
committed
refactor: add conversion from/to array into InteractionValues object
1 parent 5e6ca07 commit ae82d9d

File tree

6 files changed

+172
-56
lines changed

6 files changed

+172
-56
lines changed

docs/source/notebooks/tabular_notebooks/data_valuation.ipynb

Lines changed: 83 additions & 44 deletions
Large diffs are not rendered by default.

src/shapiq/explainer/nn/knn.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,18 @@
77

88
import numpy as np
99

10+
from shapiq import InteractionValues
1011
from shapiq.explainer.nn.base import NNExplainerBase
1112

1213
from ._util import (
1314
assert_valid_index_and_order,
1415
warn_ignored_parameters,
1516
)
16-
from .iv_utils import interaction_values_from_array
1717

1818
if TYPE_CHECKING:
1919
import numpy.typing as npt
2020
from sklearn.neighbors import KNeighborsClassifier
2121

22-
from shapiq import InteractionValues
2322
from shapiq.explainer.custom_types import ExplainerIndices
2423

2524

@@ -79,4 +78,4 @@ def explain_function(self, x: npt.NDArray[np.floating]) -> InteractionValues:
7978
inv_sortperm = np.zeros_like(sortperm)
8079
inv_sortperm[sortperm] = np.arange(sortperm.shape[0])
8180

82-
return interaction_values_from_array(sv[inv_sortperm])
81+
return InteractionValues.from_first_order_array(sv[inv_sortperm], index="SV")

src/shapiq/explainer/nn/threshold_nn.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,14 @@
1313
import sklearn.neighbors
1414
from sklearn.neighbors import RadiusNeighborsClassifier
1515

16-
from shapiq import InteractionValues
1716
from shapiq.explainer.custom_types import ExplainerIndices
18-
17+
from shapiq import InteractionValues
1918

2019
from ._util import (
2120
assert_valid_index_and_order,
2221
warn_ignored_parameters,
2322
)
2423
from .base import NNExplainerBase
25-
from .iv_utils import interaction_values_from_array
2624

2725

2826
class ThresholdNNExplainer(NNExplainerBase):
@@ -104,4 +102,4 @@ def explain_function(self, x: npt.NDArray[np.floating]) -> InteractionValues:
104102

105103
sv = first_summand + second_summand
106104

107-
return interaction_values_from_array(sv)
105+
return InteractionValues.from_first_order_array(sv, index="SV")

src/shapiq/explainer/nn/weighted_knn.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,17 @@
88
from shapiq.explainer.nn.base import NNExplainerBase
99

1010
from ._util import assert_valid_index_and_order, warn_ignored_parameters
11-
from .iv_utils import interaction_values_from_array
1211

1312
if TYPE_CHECKING:
1413
import numpy.typing as npt
1514
from sklearn.neighbors import KNeighborsClassifier
1615

1716
from shapiq.explainer.custom_types import ExplainerIndices
18-
from shapiq.interaction_values import InteractionValues
19-
2017
import numpy as np
2118
from scipy.special import comb
2219

20+
from shapiq.interaction_values import InteractionValues
21+
2322

2423
class WeightedKNNExplainer(NNExplainerBase):
2524
r"""Explainer for weighted KNN models.
@@ -82,7 +81,9 @@ def explain_function(self, x: npt.NDArray[np.floating]) -> InteractionValues:
8281

8382
n_classes = len(self.y_train_classes)
8483
if n_classes == 1:
85-
return interaction_values_from_array(np.full(n_players, 1 / n_players))
84+
return InteractionValues.from_first_order_array(
85+
np.full(n_players, 1 / n_players), index="SV"
86+
)
8687

8788
sortperm, weights = self._get_prepared_weights(x)
8889

@@ -91,7 +92,7 @@ def explain_function(self, x: npt.NDArray[np.floating]) -> InteractionValues:
9192
sv += self._explain_binary(other_class_index, sortperm, weights)
9293
sv /= n_classes - 1
9394

94-
return interaction_values_from_array(sv)
95+
return InteractionValues.from_first_order_array(sv, index="SV")
9596

9697
def _explain_binary(
9798
self,

src/shapiq/game_theory/indices.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@
119119

120120
AllIndices: tuple[IndexType, ...] = tuple(get_args(IndexType))
121121

122+
FIRST_ORDER_INDICES = ["SV", "BV", "JointSV"]
123+
122124

123125
def is_index_valid(index: str, *, raise_error: bool = False) -> bool:
124126
"""Checks if the given index is a valid interaction index.
@@ -300,3 +302,24 @@ def is_empty_value_the_baseline(index: str) -> bool:
300302
301303
"""
302304
return index not in ["SII", "FBII", "BII", "BV"]
305+
306+
307+
def is_first_order(index: str, *, raise_error: bool = False) -> bool:
308+
"""Check if the index represents only first-order interactions.
309+
310+
Args:
311+
index: The interaction index.
312+
raise_error: If ``True``, raises a ``ValueError`` if the index is not first-order
313+
314+
Returns:
315+
``True`` if the index is a first-order index, ``False`` otherwise.
316+
"""
317+
first_order = index in FIRST_ORDER_INDICES
318+
if not first_order and raise_error:
319+
msg = (
320+
f"Expected first-order index but got '{index}'. First-order indices are: "
321+
+ ", ".join(FIRST_ORDER_INDICES)
322+
)
323+
raise ValueError(msg)
324+
325+
return first_order

src/shapiq/interaction_values.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
AllIndices,
1717
get_index_from_computation_index,
1818
is_empty_value_the_baseline,
19+
is_first_order,
1920
is_index_aggregated,
2021
is_index_valid,
2122
)
@@ -924,6 +925,61 @@ def to_dict(self) -> dict:
924925
"baseline_value": self.baseline_value,
925926
}
926927

928+
@classmethod
929+
def from_first_order_array(
930+
cls, first_order_values: np.ndarray, index: str, baseline_value: float = 0
931+
) -> InteractionValues:
932+
"""Convert an array of first-order values to an :class:`shapiq.InteractionValues` object.
933+
934+
Args:
935+
first_order_values: An array containing the value of the ith training point at index i.
936+
index: The game theoretic index of the resulting :class:`shapiq.InteractionValues` object. Must be a
937+
first-order index.
938+
baseline_value: Baseline value, defaults to ``0``.
939+
940+
Returns:
941+
An :class:`~shapiq.InteractionValues` object containing the provided values.
942+
943+
Raises:
944+
ValueError: If the provided ``index`` is not a first-order index.
945+
"""
946+
is_first_order(index, raise_error=True)
947+
948+
n_players = first_order_values.shape[0]
949+
interaction_lookup: dict[tuple[int, ...], int] = {(i,): i for i in range(n_players)}
950+
951+
return InteractionValues(
952+
first_order_values,
953+
index=index,
954+
min_order=0,
955+
max_order=1,
956+
n_players=n_players,
957+
baseline_value=baseline_value,
958+
interaction_lookup=interaction_lookup,
959+
)
960+
961+
def to_first_order_array(self) -> np.ndarray:
962+
"""Convert to an array of first-order values.
963+
964+
Returns:
965+
An array of shape ``(self.n_players,)`` containing at index ``i`` the first-order value of player ``i``.
966+
967+
Raises:
968+
RuntimeError: If the method was called on an :class:`~shapiq.InteractionValues` object with max order
969+
not equal to ``1``.
970+
"""
971+
if self.max_order != 1:
972+
msg = f"Max order must be 1 but was {self.max_order}"
973+
raise ValueError(msg)
974+
975+
out = np.zeros((self.n_players,))
976+
for coalition, lookup_idx in self.interaction_lookup.items():
977+
if coalition == ():
978+
continue
979+
out[coalition[0]] = self.values[lookup_idx]
980+
981+
return out
982+
927983
def aggregate(
928984
self,
929985
others: Sequence[InteractionValues],

0 commit comments

Comments
 (0)