Skip to content

Commit 686bfb1

Browse files
committed
Fix wrong weighting logic
1 parent abed201 commit 686bfb1

File tree

2 files changed

+40
-40
lines changed

2 files changed

+40
-40
lines changed

NEWS.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
# kernelshap 0.8.1
1+
# kernelshap 0.9.0
2+
3+
### Major bug fix
4+
5+
`kernelshap()` used a wrong weighting logic, leading to values slightly off. This has
6+
been fixed with the help of Prof Mario Wuethrich of ETHZ and
7+
[Ian Covert's wonderful Github repo](https://github.com/iancovert/shapley-regression).
8+
Now, exact Kernel SHAP returns identical values than exact permutation SHAP.
9+
10+
Fixed in [#167](https://github.com/ModelOriented/kernelshap/pull/167).
211

312
### API
413

R/utils_kernelshap.R

Lines changed: 30 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,7 @@ solver <- function(A, b, constraint) {
107107
# to Kernel SHAP weights -> (m x p) matrix.
108108
# The argument S can be used to restrict the range of sum(z).
109109
sample_Z <- function(p, m, feature_names, S = 1:(p - 1L)) {
110-
# First draw s = sum(z) according to Kernel weights (renormalized to sum 1)
111-
probs <- kernel_weights(p, S = S)
110+
probs <- kernel_weights_per_coalition_size(p, S = S)
112111
N <- S[sample.int(length(S), m, replace = TRUE, prob = probs)]
113112

114113
# Then, conditional on that number, set random positions of z to 1
@@ -144,8 +143,8 @@ input_sampling <- function(p, m, deg, feature_names) {
144143
S <- (deg + 1L):(p - deg - 1L)
145144
Z <- sample_Z(p = p, m = m / 2, feature_names = feature_names, S = S)
146145
Z <- rbind(Z, !Z)
147-
w_total <- if (deg == 0L) 1 else 1 - 2 * sum(kernel_weights(p)[seq_len(deg)])
148-
w <- w_total / m
146+
w <- if (deg == 0L) 1 else 1 - prop_exact(p, deg = deg)
147+
w <- w / m
149148
list(Z = Z, w = rep.int(w, m), A = crossprod(Z) * w)
150149
}
151150

@@ -159,33 +158,10 @@ input_sampling <- function(p, m, deg, feature_names) {
159158
input_exact <- function(p, feature_names) {
160159
Z <- exact_Z(p, feature_names = feature_names)
161160
Z <- Z[2L:(nrow(Z) - 1L), , drop = FALSE]
162-
# Each Kernel weight(j) is divided by the number of vectors z having sum(z) = j
163-
w <- kernel_weights(p) / choose(p, 1:(p - 1L))
164-
list(Z = Z, w = w[rowSums(Z)], A = exact_A(p, feature_names = feature_names))
165-
}
166-
167-
#' Exact Matrix A
168-
#'
169-
#' Internal function that calculates exact A.
170-
#' Notice the difference to the off-diagnonals in the Supplement of
171-
#' Covert and Lee (2021). Credits to David Watson for figuring out the correct formula,
172-
#' see our discussions in https://github.com/ModelOriented/kernelshap/issues/22
173-
#'
174-
#' @noRd
175-
#' @keywords internal
176-
#'
177-
#' @param p Number of features.
178-
#' @param feature_names Feature names.
179-
#' @returns A (p x p) matrix.
180-
exact_A <- function(p, feature_names) {
181-
S <- 1:(p - 1L)
182-
c_pr <- S * (S - 1) / p / (p - 1)
183-
off_diag <- sum(kernel_weights(p) * c_pr)
184-
A <- matrix(
185-
data = off_diag, nrow = p, ncol = p, dimnames = list(feature_names, feature_names)
186-
)
187-
diag(A) <- 0.5
188-
A
161+
kw <- kernel_weights(p) # Kernel weights for all subsets
162+
w <- kw[rowSums(Z)] # Corresponding weight for each row in Z
163+
w <- w / sum(w)
164+
list(Z = Z, w = w, A = crossprod(Z, w * Z))
189165
}
190166

191167
# List all length p vectors z with sum(z) in {k, p - k}
@@ -228,22 +204,37 @@ input_partly_exact <- function(p, deg, feature_names) {
228204
}
229205

230206
kw <- kernel_weights(p)
231-
Z <- w <- vector("list", deg)
232207

208+
Z <- vector("list", deg)
233209
for (k in seq_len(deg)) {
234210
Z[[k]] <- partly_exact_Z(p, k = k, feature_names = feature_names)
235-
n <- nrow(Z[[k]])
236-
w_tot <- kw[k] * (2 - (p == 2L * k))
237-
w[[k]] <- rep.int(w_tot / n, n)
238211
}
239-
w <- unlist(w, recursive = FALSE, use.names = FALSE)
240212
Z <- do.call(rbind, Z)
241-
213+
w <- kw[rowSums(Z)]
214+
w_target <- prop_exact(p, deg = deg) # How much of total weight to spend here
215+
w <- w / sum(w) * w_target
242216
list(Z = Z, w = w, A = crossprod(Z, w * Z))
243217
}
244218

245-
# Kernel weights normalized to a non-empty subset S of {1, ..., p-1}
219+
# Kernel weights
246220
kernel_weights <- function(p, S = seq_len(p - 1L)) {
247221
probs <- (p - 1L) / (choose(p, S) * S * (p - S))
248-
probs / sum(probs)
222+
return(probs / sum(probs))
223+
}
224+
225+
# Kernel weights per coalition size
226+
kernel_weights_per_coalition_size <- function(p, S = seq_len(p - 1L)) {
227+
probs <- 1 / (S * (p - S))
228+
return(probs / sum(probs))
229+
}
230+
231+
# How much Kernel SHAP weights do coalitions of size
232+
# {1, ..., deg, ..., p-deg-1 ..., p-1} have?
233+
prop_exact <- function(p, deg) {
234+
if (deg == 0) {
235+
return(0)
236+
}
237+
w <- kernel_weights_per_coalition_size(p)
238+
w_total <- 2 * sum(w[seq_len(deg)]) - w[deg] * (p == 2 * deg)
239+
return(w_total)
249240
}

0 commit comments

Comments
 (0)