Skip to content

Commit d0001d0

Browse files
committed
Integrate the two types of weights
1 parent 7390c2b commit d0001d0

File tree

3 files changed

+35
-18
lines changed

3 files changed

+35
-18
lines changed

NEWS.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
### Major bug fix
44

55
`kernelshap()` used incorrect weighting logic, leading to slightly inaccurate values. This has
6-
been fixed with the help of Prof Mario Wuethrich of ETHZ and
7-
[Ian Covert's wonderful Github repo](https://github.com/iancovert/shapley-regression).
8-
Now, exact Kernel SHAP returns identical values as exact permutation SHAP.
6+
been fixed with the help of Prof Mario Wuethrich of ETHZ and Ian Covert and his
7+
[wonderful Github repo](https://github.com/iancovert/shapley-regression).
8+
Now, exact Kernel SHAP returns the same values as exact permutation SHAP.
9+
All variants of `kernelshap()` (exact, sampling, hybrid) were affected by this.
10+
For models with interactions up to order two, the bug had no consequences - which is why
11+
it went unnoticed.
912

1013
Fixed in [#168](https://github.com/ModelOriented/kernelshap/pull/168).
1114

R/utils_kernelshap.R

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ solver <- function(A, b, constraint) {
107107
# to Kernel SHAP weights -> (m x p) matrix.
108108
# The argument S can be used to restrict the range of sum(z).
109109
sample_Z <- function(p, m, feature_names, S = 1:(p - 1L)) {
110-
probs <- kernel_weights_per_coalition_size(p, S = S)
110+
probs <- kernel_weights(p, per_coalition_size = TRUE, S = S)
111111
N <- S[sample.int(length(S), m, replace = TRUE, prob = probs)]
112112

113113
# Then, conditional on that number, set random positions of z to 1
@@ -158,7 +158,7 @@ input_sampling <- function(p, m, deg, feature_names) {
158158
input_exact <- function(p, feature_names) {
159159
Z <- exact_Z(p, feature_names = feature_names)
160160
Z <- Z[2L:(nrow(Z) - 1L), , drop = FALSE]
161-
kw <- kernel_weights(p) # Kernel weights for all subsets
161+
kw <- kernel_weights(p, per_coalition_size = FALSE) # Kernel weights for all subsets
162162
w <- kw[rowSums(Z)] # Corresponding weight for each row in Z
163163
w <- w / sum(w)
164164
list(Z = Z, w = w, A = crossprod(Z, w * Z))
@@ -203,7 +203,7 @@ input_partly_exact <- function(p, deg, feature_names) {
203203
stop("p must be >=2*deg")
204204
}
205205

206-
kw <- kernel_weights(p)
206+
kw <- kernel_weights(p, per_coalition_size = FALSE)
207207

208208
Z <- vector("list", deg)
209209
for (k in seq_len(deg)) {
@@ -216,15 +216,16 @@ input_partly_exact <- function(p, deg, feature_names) {
216216
list(Z = Z, w = w, A = crossprod(Z, w * Z))
217217
}
218218

219-
# Kernel weights
220-
kernel_weights <- function(p, S = seq_len(p - 1L)) {
221-
probs <- (p - 1L) / (choose(p, S) * S * (p - S))
222-
return(probs / sum(probs))
223-
}
224-
225-
# Kernel weights per coalition size
226-
kernel_weights_per_coalition_size <- function(p, S = seq_len(p - 1L)) {
227-
probs <- 1 / (S * (p - S))
219+
# Kernel weight distribution
220+
#
221+
# `per_coalition_size = TRUE` is required, e.g., when one wants to sample random masks
222+
# according to the Kernel SHAP distribution: Pick a coalition size as per
223+
# these weights, then randomly place "on" positions. `FALSE` refers to weights
224+
# if all masks have been calculated and one wants to calculate their weights based
225+
# on the number of "on" positions.
226+
kernel_weights <- function(p, per_coalition_size, S = seq_len(p - 1L)) {
227+
const <- if (per_coalition_size) 1 else choose(p, S)
228+
probs <- (p - 1) / (const * S * (p - S)) # could drop the numerator
228229
return(probs / sum(probs))
229230
}
230231

@@ -234,7 +235,7 @@ prop_exact <- function(p, deg) {
234235
if (deg == 0) {
235236
return(0)
236237
}
237-
w <- kernel_weights_per_coalition_size(p)
238+
w <- kernel_weights(p, per_coalition_size = TRUE)
238239
w_total <- 2 * sum(w[seq_len(deg)]) - w[deg] * (p == 2 * deg)
239240
return(w_total)
240241
}

tests/testthat/test-kernelshap-utils.R

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
test_that("sum of kernel weights is 1", {
22
for (p in 2:10) {
3-
expect_equal(sum(kernel_weights(p)), 1.0)
3+
expect_equal(sum(kernel_weights(p, per_coalition_size = FALSE)), 1.0)
4+
expect_equal(sum(kernel_weights(p, per_coalition_size = TRUE)), 1.0)
45
}
56
})
67

78
test_that("Sum of kernel weights is 1, even for subset of domain", {
8-
expect_equal(sum(kernel_weights(10L, S = 2:5)), 1.0)
9+
expect_equal(sum(kernel_weights(10L, S = 2:5, per_coalition_size = FALSE)), 1.0)
10+
expect_equal(sum(kernel_weights(10L, S = 2:5, per_coalition_size = TRUE)), 1.0)
911
})
1012

1113
p <- 10L
@@ -105,6 +107,17 @@ test_that("hybrid weights sum to 1 for different p and degree 2", {
105107
}
106108
})
107109

110+
test_that("sampling input A is comparable from exact input", {
111+
set.seed(1)
112+
113+
for (p in 2:6) {
114+
feature_names <- LETTERS[1:p]
115+
pa <- input_exact(p, feature_names)
116+
sa <- input_sampling(p, m = 100000L, deg = 0, feature_names = feature_names)
117+
expect_true(all(abs(pa$A - sa$A) < 0.01))
118+
}
119+
})
120+
108121
test_that("partly_exact_Z(p, k) fails for bad p or k", {
109122
expect_error(partly_exact_Z(0L, k = 1L, feature_names = LETTERS[1:p]))
110123
expect_error(partly_exact_Z(5L, k = 3L, feature_names = LETTERS[1:p]))

0 commit comments

Comments
 (0)