
Commit b9ff089

Merge pull request #168 from ModelOriented/fix-weighting-logic
FIX-bug-in-weighting-logic
2 parents abed201 + e37913a commit b9ff089

File tree

14 files changed

+292
-115
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: kernelshap
22
Title: Kernel SHAP
3-
Version: 0.8.1
3+
Version: 0.9.0
44
Authors@R: c(
55
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre"),
66
comment = c(ORCID = "0009-0007-2540-9629")),

NEWS.md

Lines changed: 13 additions & 1 deletion
@@ -1,4 +1,16 @@
1-
# kernelshap 0.8.1
1+
# kernelshap 0.9.0
2+
3+
### Bug fix
4+
5+
With input from Mario Wuethrich and Ian Covert and his [repo](https://github.com/iancovert/shapley-regression),
6+
we have fixed a bug in how `kernelshap()` calculates Kernel weights.
7+
8+
- The differences caused by this are typically very small.
9+
- Models with interactions of order up to two have been unaffected.
10+
- Exact Kernel SHAP now provides identical results to exact permutation SHAP.
11+
12+
Fixed in [#168](https://github.com/ModelOriented/kernelshap/pull/168), which also adds
12+
unit tests against Python's "shap".
214

315
### API
416

R/kernelshap.R

Lines changed: 9 additions & 21 deletions
@@ -8,17 +8,11 @@
88
#' Otherwise, an almost exact hybrid algorithm combining exact calculations and
99
#' iterative paired sampling is used, see Details.
1010
#'
11-
#' Note that (exact) Kernel SHAP is only an approximation of (exact) permutation SHAP.
12-
#' Thus, for up to eight features, we recommend [permshap()]. For more features,
13-
#' [permshap()] tends to be inefficient compared to the optimized hybrid strategy
14-
#' of Kernel SHAP.
15-
#'
1611
#' @details
1712
#' The pure iterative Kernel SHAP sampling as in Covert and Lee (2021) works like this:
1813
#'
19-
#' 1. A binary "on-off" vector \eqn{z} is drawn from \eqn{\{0, 1\}^p}
20-
#' such that its sum follows the SHAP Kernel weight distribution
21-
#' (normalized to the range \eqn{\{1, \dots, p-1\}}).
14+
#' 1. A binary "on-off" vector \eqn{z} is drawn from \eqn{\{0, 1\}^p} according to
15+
#' a special weighting logic.
2216
#' 2. For each \eqn{j} with \eqn{z_j = 1}, the \eqn{j}-th column of the
2317
#' original background data is replaced by the corresponding feature value \eqn{x_j}
2418
#' of the observation to be explained.
@@ -33,17 +27,13 @@
3327
#'
3428
#' This is repeated multiple times until convergence, see CL21 for details.
3529
#'
36-
#' A drawback of this strategy is that many (at least 75%) of the \eqn{z} vectors will
37-
#' have \eqn{\sum z \in \{1, p-1\}}, producing many duplicates. Similarly, at least 92%
38-
#' of the mass will be used for the \eqn{p(p+1)} possible vectors with
39-
#' \eqn{\sum z \in \{1, 2, p-2, p-1\}}.
40-
#' This inefficiency can be fixed by a hybrid strategy, combining exact calculations
41-
#' with sampling.
30+
#' To avoid the re-evaluation of identical coalition vectors, we have implemented
31+
#' a hybrid strategy, combining exact calculations with sampling.
4232
#'
4333
#' The hybrid algorithm has two steps:
4434
#' 1. Step 1 (exact part): There are \eqn{2p} different on-off vectors \eqn{z} with
45-
#' \eqn{\sum z \in \{1, p-1\}}, covering a large proportion of the Kernel SHAP
46-
#' distribution. The degree 1 hybrid will list those vectors and use them according
35+
#' \eqn{\sum z \in \{1, p-1\}}.
36+
#' The degree 1 hybrid will list those vectors and use them according
4737
#' to their weights in the upcoming calculations. Depending on \eqn{p}, we can also go
4838
#' a step further to a degree 2 hybrid by adding all \eqn{p(p-1)} vectors with
4939
#' \eqn{\sum z \in \{2, p-2\}} to the process etc. The necessary predictions are
@@ -96,12 +86,10 @@
9686
#' worse than the hybrid strategy and should therefore only be used for
9787
#' studying properties of the Kernel SHAP algorithm.
9888
#' - `1`: Uses all \eqn{2p} on-off vectors \eqn{z} with \eqn{\sum z \in \{1, p-1\}}
99-
#' for the exact part, which covers at least 75% of the mass of the Kernel weight
100-
#' distribution. The remaining mass is covered by random sampling.
89+
#' for the exact part. The remaining mass is covered by random sampling.
10190
#' - `2`: Uses all \eqn{p(p+1)} on-off vectors \eqn{z} with
102-
#' \eqn{\sum z \in \{1, 2, p-2, p-1\}}. This covers at least 92% of the mass of the
103-
#' Kernel weight distribution. The remaining mass is covered by sampling.
104-
#' Convergence usually happens in the minimal possible number of iterations of two.
91+
#' \eqn{\sum z \in \{1, 2, p-2, p-1\}}. The remaining mass is covered by sampling.
92+
#' Convergence usually happens very fast.
10593
#' - `k>2`: Uses all on-off vectors with
10694
#' \eqn{\sum z \in \{1, \dots, k, p-k, \dots, p-1\}}.
10795
#' @param m Even number of on-off vectors sampled during one iteration.
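The pure sampling step described in the Details above (draw a coalition size, place the "on" positions at random, then pair each vector with its complement) can be sketched in base R. This is an illustration only: `draw_z()` is a hypothetical helper, not the package's `sample_Z()`, and it assumes the coalition-size distribution proportional to 1 / (s * (p - s)) that `kernel_weights(p, per_coalition_size = TRUE)` produces in this commit.

```r
# Hypothetical helper (not part of kernelshap): draw m on-off vectors z.
draw_z <- function(p, m) {
  S <- seq_len(p - 1L)
  # Assumed size distribution: proportional to 1 / (s * (p - s))
  probs <- 1 / (S * (p - S))
  probs <- probs / sum(probs)

  # First draw the coalition size s = sum(z) for each of the m vectors
  sizes <- S[sample.int(length(S), m, replace = TRUE, prob = probs)]

  # Then, conditional on the size, switch on random positions
  Z <- matrix(FALSE, nrow = m, ncol = p)
  for (i in seq_len(m)) {
    Z[i, sample.int(p, sizes[i])] <- TRUE
  }
  Z
}

set.seed(1)
Z <- draw_z(p = 6, m = 4)
Z_paired <- rbind(Z, !Z)  # paired sampling: each z together with its complement
```

Pairing each vector with its complement keeps the sample symmetric in the coalition size, which mirrors the `rbind(Z, !Z)` step in `input_sampling()`.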

R/permshap.R

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@
3232
#' Furthermore, the 2p on-off vectors with sum <=1 or >=p-1 are evaluated only once,
3333
#' similar to the degree 1 hybrid in [kernelshap()] (but covering less weight).
3434
#'
35-
#' @param exact If `TRUE`, the algorithm will produce exact SHAP values
35+
#' @param exact If `TRUE`, the algorithm produces exact SHAP values
3636
#' with respect to the background data.
3737
#' The default is `TRUE` for up to eight features, and `FALSE` otherwise.
3838
#' @param low_memory If `FALSE` (default up to p = 15), the algorithm evaluates p

R/utils_kernelshap.R

Lines changed: 34 additions & 42 deletions
@@ -107,8 +107,7 @@ solver <- function(A, b, constraint) {
107107
# to Kernel SHAP weights -> (m x p) matrix.
108108
# The argument S can be used to restrict the range of sum(z).
109109
sample_Z <- function(p, m, feature_names, S = 1:(p - 1L)) {
110-
# First draw s = sum(z) according to Kernel weights (renormalized to sum 1)
111-
probs <- kernel_weights(p, S = S)
110+
probs <- kernel_weights(p, per_coalition_size = TRUE, S = S)
112111
N <- S[sample.int(length(S), m, replace = TRUE, prob = probs)]
113112

114113
# Then, conditional on that number, set random positions of z to 1
@@ -144,8 +143,8 @@ input_sampling <- function(p, m, deg, feature_names) {
144143
S <- (deg + 1L):(p - deg - 1L)
145144
Z <- sample_Z(p = p, m = m / 2, feature_names = feature_names, S = S)
146145
Z <- rbind(Z, !Z)
147-
w_total <- if (deg == 0L) 1 else 1 - 2 * sum(kernel_weights(p)[seq_len(deg)])
148-
w <- w_total / m
146+
w <- if (deg == 0L) 1 else 1 - prop_exact(p, deg = deg)
147+
w <- w / m
149148
list(Z = Z, w = rep.int(w, m), A = crossprod(Z) * w)
150149
}
151150

@@ -159,33 +158,10 @@ input_sampling <- function(p, m, deg, feature_names) {
159158
input_exact <- function(p, feature_names) {
160159
Z <- exact_Z(p, feature_names = feature_names)
161160
Z <- Z[2L:(nrow(Z) - 1L), , drop = FALSE]
162-
# Each Kernel weight(j) is divided by the number of vectors z having sum(z) = j
163-
w <- kernel_weights(p) / choose(p, 1:(p - 1L))
164-
list(Z = Z, w = w[rowSums(Z)], A = exact_A(p, feature_names = feature_names))
165-
}
166-
167-
#' Exact Matrix A
168-
#'
169-
#' Internal function that calculates exact A.
170-
#' Notice the difference to the off-diagonals in the Supplement of
171-
#' Covert and Lee (2021). Credits to David Watson for figuring out the correct formula,
172-
#' see our discussions in https://github.com/ModelOriented/kernelshap/issues/22
173-
#'
174-
#' @noRd
175-
#' @keywords internal
176-
#'
177-
#' @param p Number of features.
178-
#' @param feature_names Feature names.
179-
#' @returns A (p x p) matrix.
180-
exact_A <- function(p, feature_names) {
181-
S <- 1:(p - 1L)
182-
c_pr <- S * (S - 1) / p / (p - 1)
183-
off_diag <- sum(kernel_weights(p) * c_pr)
184-
A <- matrix(
185-
data = off_diag, nrow = p, ncol = p, dimnames = list(feature_names, feature_names)
186-
)
187-
diag(A) <- 0.5
188-
A
161+
kw <- kernel_weights(p, per_coalition_size = FALSE) # Kernel weights for all subsets
162+
w <- kw[rowSums(Z)] # Corresponding weight for each row in Z
163+
w <- w / sum(w)
164+
list(Z = Z, w = w, A = crossprod(Z, w * Z))
189165
}
190166

191167
# List all length p vectors z with sum(z) in {k, p - k}
@@ -227,23 +203,39 @@ input_partly_exact <- function(p, deg, feature_names) {
227203
stop("p must be >=2*deg")
228204
}
229205

230-
kw <- kernel_weights(p)
231-
Z <- w <- vector("list", deg)
206+
kw <- kernel_weights(p, per_coalition_size = FALSE)
232207

208+
Z <- vector("list", deg)
233209
for (k in seq_len(deg)) {
234210
Z[[k]] <- partly_exact_Z(p, k = k, feature_names = feature_names)
235-
n <- nrow(Z[[k]])
236-
w_tot <- kw[k] * (2 - (p == 2L * k))
237-
w[[k]] <- rep.int(w_tot / n, n)
238211
}
239-
w <- unlist(w, recursive = FALSE, use.names = FALSE)
240212
Z <- do.call(rbind, Z)
241-
213+
w <- kw[rowSums(Z)]
214+
w_target <- prop_exact(p, deg = deg) # How much of total weight to spend here
215+
w <- w / sum(w) * w_target
242216
list(Z = Z, w = w, A = crossprod(Z, w * Z))
243217
}
244218

245-
# Kernel weights normalized to a non-empty subset S of {1, ..., p-1}
246-
kernel_weights <- function(p, S = seq_len(p - 1L)) {
247-
probs <- (p - 1L) / (choose(p, S) * S * (p - S))
248-
probs / sum(probs)
219+
# Kernel weight distribution
220+
#
221+
# `per_coalition_size = TRUE` is required, e.g., when one wants to sample random masks
222+
# according to the Kernel SHAP distribution: Pick a coalition size as per
223+
# these weights, then randomly place "on" positions. `FALSE` refers to the weights
224+
# needed when all masks have been enumerated and each mask is weighted based
225+
# on its number of "on" positions.
226+
kernel_weights <- function(p, per_coalition_size, S = seq_len(p - 1L)) {
227+
const <- if (per_coalition_size) 1 else choose(p, S)
228+
probs <- (p - 1) / (const * S * (p - S)) # could drop the numerator
229+
return(probs / sum(probs))
230+
}
231+
232+
# How much Kernel SHAP weight do coalitions of size
233+
# {1, ..., deg, p-deg, ..., p-1} have?
234+
prop_exact <- function(p, deg) {
235+
if (deg == 0) {
236+
return(0)
237+
}
238+
w <- kernel_weights(p, per_coalition_size = TRUE)
239+
w_total <- 2 * sum(w[seq_len(deg)]) - w[deg] * (p == 2 * deg)
240+
return(w_total)
249241
}
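To make the two weighting modes concrete, the following sketch copies `kernel_weights()` and `prop_exact()` from the diff above (with explicit `return()` calls dropped) and checks how they relate. The checks themselves are illustrative and not part of the commit; the choice `p = 5` and the names `w_size`, `w_mask`, `rescaled`, `Z`, `w` and `A` are assumptions made for the example.

```r
# Copied from the diff above
kernel_weights <- function(p, per_coalition_size, S = seq_len(p - 1L)) {
  const <- if (per_coalition_size) 1 else choose(p, S)
  probs <- (p - 1) / (const * S * (p - S))
  probs / sum(probs)
}

prop_exact <- function(p, deg) {
  if (deg == 0) {
    return(0)
  }
  w <- kernel_weights(p, per_coalition_size = TRUE)
  2 * sum(w[seq_len(deg)]) - w[deg] * (p == 2 * deg)
}

p <- 5
S <- seq_len(p - 1L)
w_size <- kernel_weights(p, per_coalition_size = TRUE)   # one weight per size
w_mask <- kernel_weights(p, per_coalition_size = FALSE)  # weight of a single mask of each size

# Multiplying the per-mask weight by the number of masks of that size
# recovers the size distribution (after renormalization).
rescaled <- w_mask * choose(p, S)
all.equal(rescaled / sum(rescaled), w_size)

# Share of the weight handled exactly by a degree 1 and a degree 2 hybrid
prop_exact(p, deg = 1)
prop_exact(p, deg = 2)

# Exact weights and matrix A for all 2^p - 2 non-trivial masks, as in input_exact()
Z <- as.matrix(expand.grid(rep(list(0:1), p)))
Z <- Z[rowSums(Z) > 0 & rowSums(Z) < p, , drop = FALSE]
w <- w_mask[rowSums(Z)]
w <- w / sum(w)
A <- crossprod(Z, w * Z)
diag(A)  # each feature is "on" with total weight 1/2, by symmetry of z and !z
```

The last line ties back to the removed `exact_A()`, which set the diagonal to 0.5 directly; computing `A` from the weighted masks yields the same diagonal without a separate closed-form formula.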

README.md

Lines changed: 10 additions & 13 deletions
@@ -15,20 +15,18 @@
1515

1616
The package contains three functions to crunch SHAP values:
1717

18-
- **`permshap()`**: Permutation SHAP algorithm of [1]. Recommended for models with up to 8 features, or if you don't trust Kernel SHAP. Both exact and sampling versions are available.
19-
- **`kernelshap()`**: Kernel SHAP algorithm of [2] and [3]. Recommended for models with more than 8 features. Both exact and (pseudo-exact) sampling versions are available.
18+
- **`permshap()`**: Permutation SHAP algorithm of [1]. Both exact and sampling versions are available.
19+
- **`kernelshap()`**: Kernel SHAP algorithm of [2] and [3]. Both exact and (pseudo-exact) sampling versions are available.
2020
- **`additive_shap()`**: For *additive models* fitted via `lm()`, `glm()`, `mgcv::gam()`, `mgcv::bam()`, `gam::gam()`, `survival::coxph()`, or `survival::survreg()`. Exponentially faster than the model-agnostic options above, and recommended if possible.
2121

22-
To explain your model, select an explanation dataset `X` (up to 1000 rows from the training data, feature columns only) and apply the recommended function. Use {shapviz} to visualize the resulting SHAP values.
22+
To explain your model, select an explanation dataset `X` (up to 1000 rows from the training data, feature columns only). Use {shapviz} to visualize the resulting SHAP values.
2323

2424
**Remarks to `permshap()` and `kernelshap()`**
2525

2626
- Both algorithms need representative background data `bg_X` to calculate marginal means (up to 500 rows from the training data). In cases with a natural "off" value (like MNIST digits), this can also be a single row with all values set to the off value. If unspecified, 200 rows are randomly sampled from `X`.
27-
- Exact Kernel SHAP is an approximation to exact permutation SHAP. Since exact calculations are usually sufficiently fast for up to eight features, we recommend `permshap()` in this case. With more features, `kernelshap()` switches to a comparably fast, almost exact algorithm with faster convergence than the sampling version of permutation SHAP.
28-
That is why we recommend `kernelshap()` in this case.
29-
- For models with interactions of order up to two, SHAP values of permutation SHAP and Kernel SHAP agree,
30-
and the implemented sampling versions provide the same results as the exact versions.
31-
In the presence of interactions of order three or higher, this is no longer the case.
27+
- Exact Kernel SHAP gives the same results as exact permutation SHAP. Both algorithms are fast up to 8 features.
28+
With more features, `kernelshap()` switches to an almost exact algorithm with faster convergence than the sampling version of permutation SHAP.
29+
- For models with interactions of order up to two, the sampling versions provide the same results as the exact versions.
3230
- For additive models, `permshap()` and `kernelshap()` give the same results as `additive_shap()`
3331
as long as the full training data is used as background data.
3432

@@ -89,13 +87,12 @@ ps
8987
[1,] 1.1913247 0.09005467 -0.13430720 0.000682593
9088
[2,] -0.4931989 -0.11724773 0.09868921 0.028563613
9189

92-
# Kernel SHAP gives very slightly different values because the model contains
93-
# interactions of order > 2:
90+
# Indeed, Kernel SHAP gives the same:
9491
ks <- kernelshap(fit, X, bg_X = bg_X)
9592
ks
96-
# log_carat clarity color cut
97-
# [1,] 1.1911791 0.0900462 -0.13531648 0.001845958
98-
# [2,] -0.4927482 -0.1168517 0.09815062 0.028255442
93+
log_carat clarity color cut
94+
[1,] 1.1913247 0.09005467 -0.13430720 0.000682593
95+
[2,] -0.4931989 -0.11724773 0.09868921 0.028563613
9996

10097
# 4) Analyze with {shapviz}
10198
ps <- shapviz(ps)

backlog/compare_with_python2.R

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
1+
library(kernelshap)
2+
3+
n <- 100
4+
5+
X <- data.frame(
6+
x1 = seq(1:n) / 100,
7+
x2 = log(1:n),
8+
x3 = sqrt(1:n),
9+
x4 = sin(1:n),
10+
x5 = (seq(1:n) / 100)^2,
11+
x6 = cos(1:n)
12+
)
13+
head(X)
14+
15+
pf <- function(model, newdata) {
16+
x <- newdata
17+
x[, 1] * x[, 2] * x[, 3] * x[, 4] + x[, 5] + x[, 6]
18+
}
19+
ks <- kernelshap(pf, head(X), bg_X = X, pred_fun = pf)
20+
ks # -1.196216 -1.241848 -0.9567848 3.879420 -0.33825 0.5456252
21+
es <- permshap(pf, head(X), bg_X = X, pred_fun = pf)
22+
es # -1.196216 -1.241848 -0.9567848 3.879420 -0.33825 0.5456252
23+
24+
set.seed(10)
25+
kss <- kernelshap(
26+
pf,
27+
head(X, 1),
28+
bg_X = X,
29+
pred_fun = pf,
30+
hybrid_degree = 0,
31+
exact = FALSE,
32+
m = 9000,
33+
max_iter = 100,
34+
tol = 0.0005
35+
)
36+
kss # -1.198078 -1.246508 -0.9580638 3.877532 -0.3241824 0.541247
37+
38+
set.seed(2)
39+
ksh <- kernelshap(
40+
pf,
41+
head(X, 1),
42+
bg_X = X,
43+
pred_fun = pf,
44+
hybrid_degree = 1,
45+
exact = FALSE,
46+
max_iter = 10000,
47+
tol = 0.0005
48+
)
49+
ksh # -1.191981 -1.240656 -0.9516264 3.86776 -0.3342143 0.5426642
50+
51+
set.seed(1)
52+
ksh2 <- kernelshap(
53+
pf,
54+
head(X, 1),
55+
bg_X = X,
56+
pred_fun = pf,
57+
hybrid_degree = 2,
58+
exact = FALSE,
59+
m = 10000,
60+
max_iter = 10000,
61+
tol = 0.0001
62+
)
63+
ksh2 # -1.195976 -1.241107 -0.9565121 3.878891 -0.3384621 0.5451118
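A quick plausibility check on the values recorded in the comments of this script: SHAP values for the same observation and background data must add up to the same total (prediction minus average prediction), whatever the algorithm. The sketch below uses only the numbers printed in the comments above, taking the first `ksh2` entry as negative, which the check supports. The vector names are chosen for the example and are not objects created by the script.

```r
# Values copied from the comments in the script (first row in each case)
shap_exact <- c(-1.196216, -1.241848, -0.9567848, 3.879420, -0.33825, 0.5456252)
shap_kss   <- c(-1.198078, -1.246508, -0.9580638, 3.877532, -0.3241824, 0.541247)
shap_ksh   <- c(-1.191981, -1.240656, -0.9516264, 3.86776, -0.3342143, 0.5426642)
shap_ksh2  <- c(-1.195976, -1.241107, -0.9565121, 3.878891, -0.3384621, 0.5451118)

# Efficiency property: all four totals should agree up to rounding of the printout
totals <- sapply(
  list(exact = shap_exact, kss = shap_kss, ksh = shap_ksh, ksh2 = shap_ksh2),
  sum
)
totals
max(abs(totals - totals["exact"]))
```

With a positive first entry for `ksh2`, its total would differ from the others by roughly 2.4, so the sign in that comment matters.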

cran-comments.md

Lines changed: 12 additions & 9 deletions
@@ -1,9 +1,13 @@
1-
# kernelshap 0.8.0
1+
# kernelshap 0.9.0
22

3-
Dear CRAN team
3+
We have found a bug in the weighting logic of Kernel SHAP.
44

5-
The package (finally!) contains a sampling version of permutation SHAP. In contrast
6-
to other implementations, it iterates until convergence, and standard errors are provided.
5+
This update comes with a fix which has been tested against two other implementations.
6+
7+
I am aware that the last release of {kernelshap} was not too long ago, but I still would love to see
8+
this fixed before the (well-deserved) summer break.
9+
10+
Thanks a lot!
711

812
## Checks
913

@@ -18,9 +22,8 @@ R Under development (unstable) (2025-07-05 r88387 ucrt)
1822

1923
### Revdep OK
2024

21-
survex 1.2.0 ── E: 0 | W: 0 | N: 0
22-
XAItest 1.0.1 ── E: 1 | W: 0 | N: 0
23-
SEMdeep 1.0.0 ── E: 1 | W: 1 | N: 0
25+
survex 1.2.0 ── E: 0 | W: 0 | N: 0
26+
XAItest 1.0.1 ── E: 1 | W: 0 | N: 0
27+
SEMdeep 1.0.0 ── E: 1 | W: 1 | N: 0
2428

25-
OK: 3
26-
BROKEN: 0
29+
OK: 3
