Skip to content

Commit d128feb

Avoid replication of x
1 parent 0249061 commit d128feb

File tree

10 files changed: +72 / -64 lines changed


NEWS.md

Lines changed: 9 additions & 2 deletions
@@ -1,8 +1,15 @@
 # kernelshap 0.8.1
 
-### Performance improvements
+### API
 
-- `permshap(, exact = TRUE)` is slightly faster by pre-calculating more
+- The argument `feature_names` can now also be used with matrix input ([#166](https://github.com/ModelOriented/kernelshap/pull/166)).
+
+### Speed and memory improvements
+
+- `permshap()` and `kernelshap()` require about 10% less memory ([#166](https://github.com/ModelOriented/kernelshap/pull/166)).
+- `permshap()` and `kernelshap()` are faster for data.frame input,
+  and slightly slower for matrix input ([#166](https://github.com/ModelOriented/kernelshap/pull/166)).
+- Additionally, `permshap(, exact = TRUE)` is faster by pre-calculating more
   elements used across rows [#165](https://github.com/ModelOriented/kernelshap/pull/165)
 
 ### Documentation
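
To make the first changelog entry concrete, here is a minimal usage sketch. The model, data, and `pred_fun` below are illustrative assumptions; only the `kernelshap()` call pattern with `feature_names` on matrix input comes from the package.

```r
# Sketch only (assumed model/data): feature_names with matrix input.
library(kernelshap)

X <- data.matrix(iris[, 1:4])                    # numeric matrix input
fit <- lm(Sepal.Length ~ ., data = iris[, 1:4])
pred_fun <- function(m, X, ...) predict(m, as.data.frame(X), ...)

# Explain only a subset of columns; previously, feature_names had to equal
# colnames(X) whenever X was a matrix.
ks <- kernelshap(
  fit, X[1:20, ], bg_X = X, pred_fun = pred_fun,
  feature_names = c("Sepal.Width", "Petal.Length"),
  verbose = FALSE
)
```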

R/kernelshap.R

Lines changed: 1 addition & 8 deletions
@@ -82,8 +82,7 @@
 #' Additional (named) arguments are passed via `...`.
 #' The default, [stats::predict()], will work in most cases.
 #' @param feature_names Optional vector of column names in `X` used to calculate
-#'   SHAP values. By default, this equals `colnames(X)`. Not supported if `X`
-#'   is a matrix.
+#'   SHAP values. By default, this equals `colnames(X)`.
 #' @param bg_w Optional vector of case weights for each row of `bg_X`.
 #'   If `bg_X = NULL`, must be of same length as `X`. Set to `NULL` for no weights.
 #' @param bg_n If `bg_X = NULL`: Size of background data to be sampled from `X`.
@@ -240,12 +239,6 @@ kernelshap.default <- function(
     message(txt)
   }
 
-  # Drop unnecessary columns in bg_X. If X is matrix, also column order is relevant
-  # In what follows, predictions will never be applied directly to bg_X anymore
-  if (!identical(colnames(bg_X), feature_names)) {
-    bg_X <- bg_X[, feature_names, drop = FALSE]
-  }
-
   # Pre-calculations that are identical for each row to be explained
   if (exact || hybrid_degree >= 1L) {
     if (exact) {

R/permshap.R

Lines changed: 0 additions & 6 deletions
@@ -137,12 +137,6 @@ permshap.default <- function(
   v0 <- wcolMeans(bg_preds, w = bg_w)         # Average pred of bg data: 1 x K
   v1 <- align_pred(pred_fun(object, X, ...))  # Predictions on X: n x K
 
-  # Drop unnecessary columns in bg_X. If X is matrix, also column order is relevant
-  # Predictions will never be applied directly to bg_X anymore
-  if (!identical(colnames(bg_X), feature_names)) {
-    bg_X <- bg_X[, feature_names, drop = FALSE]
-  }
-
   # Pre-calculations that are identical for each row to be explained
   if (exact) {
     Z <- exact_Z(p, feature_names = feature_names)

R/utils.R

Lines changed: 17 additions & 18 deletions
@@ -63,41 +63,42 @@ exact_Z <- function(p, feature_names) {
   return(Z)
 }
 
-#' Masker
+#' Masked Predict
 #'
 #' Internal function.
 #' For each on-off vector (rows in `Z`), the (weighted) average prediction is returned.
-#' In Python, this function is called "masker", and Z is the "mask".
 #'
 #' @noRd
 #' @keywords internal
 #'
 #' @inheritParams kernelshap
-#' @param X Row to be explained stacked m*n_bg times.
+#' @param x Row to be explained.
 #' @param bg Background data stacked m times.
-#' @param Z A (m x p) matrix with on-off values.
+#' @param Z A (m x p) matrix with on-off values (logical or integer).
 #' @param w A vector with case weights (of the same length as the unstacked
 #'   background data).
 #' @returns A (m x K) matrix with vz values.
-get_vz <- function(X, bg, Z, object, pred_fun, w, ...) {
+get_vz <- function(x, bg, Z, object, pred_fun, w, ...) {
   m <- nrow(Z)
-  not_Z <- !Z
+  if (!is.logical(Z)) {
+    storage.mode(Z) <- "logical"
+  }
   n_bg <- nrow(bg) / m  # because bg was replicated m times
 
-  # Replicate not_Z, so that X, bg, not_Z are all of dimension (m*n_bg x p)
+  # Replicate Z, so that bg and Z are of dimension (m*n_bg x p)
   g <- rep_each(m, each = n_bg)
-  not_Z <- not_Z[g, , drop = FALSE]
+  Z_rep <- Z[g, , drop = FALSE]
 
-  if (is.matrix(X)) {
-    # Remember that columns of X and bg are perfectly aligned in this case
-    X[not_Z] <- bg[not_Z]
-  } else {
-    for (v in colnames(Z)) {
-      s <- not_Z[, v]
-      X[[v]][s] <- bg[[v]][s]
+  for (v in colnames(Z)) {
+    s <- Z_rep[, v]
+    if (is.matrix(x)) {
+      bg[s, v] <- x[, v]
+    } else {
+      bg[[v]][s] <- x[[v]]
     }
   }
-  preds <- align_pred(pred_fun(object, X, ...))
+
+  preds <- align_pred(pred_fun(object, bg, ...))
 
   # Aggregate (distinguishing fast 1-dim case)
   if (ncol(preds) == 1L) {
@@ -331,8 +332,6 @@ basic_checks <- function(X, feature_names, pred_fun) {
     dim(X) >= 1L,
     length(feature_names) >= 1L,
    all(feature_names %in% colnames(X)),
-    "If X is a matrix, feature_names must equal colnames(X)" =
-      !is.matrix(X) || identical(colnames(X), feature_names),
     is.function(pred_fun)
   )
   TRUE
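
The heart of the commit is the rewritten `get_vz()`: instead of stacking the explained row `x` into a full `(m*n_bg x p)` block and overwriting its switched-off entries with background values, the mask is now applied the other way around, writing `x` into the already replicated background wherever `Z` is on. A standalone toy sketch of that masking step (toy data, not the package internals):

```r
# Toy illustration of the new masking direction in get_vz():
# write the explained row x into the stacked background where Z is TRUE.
x  <- data.frame(a = 1, b = 10)                   # row to explain
bg <- data.frame(a = c(0, 0), b = c(5, 5))        # background, stacked per mask row
Z  <- matrix(c(TRUE, FALSE), nrow = 1L,
             dimnames = list(NULL, c("a", "b")))  # one on-off vector
Z_rep <- Z[rep(1L, nrow(bg)), , drop = FALSE]     # align mask with stacked bg
for (v in colnames(Z)) {
  s <- Z_rep[, v]
  bg[[v]][s] <- x[[v]]                            # "on" features take x's value
}
bg  # column a now carries x$a = 1; column b keeps the background values
```

Avoiding the extra `(m*n_bg x p)` copy of `x` is presumably what drives the memory saving mentioned in NEWS, at the cost of a per-column loop that is slightly slower for matrix input.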

R/utils_kernelshap.R

Lines changed: 19 additions & 13 deletions
@@ -1,8 +1,21 @@
 # Kernel SHAP algorithm for a single row x
 # If exact, a single call to predict() is necessary.
 # If sampling is involved, we need at least two additional calls to predict().
-kernelshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact, deg,
-                           m, tol, max_iter, v0, precalc, ...) {
+kernelshap_one <- function(
+    x,
+    v1,
+    object,
+    pred_fun,
+    feature_names,
+    bg_w,
+    exact,
+    deg,
+    m,
+    tol,
+    max_iter,
+    v0,
+    precalc,
+    ...) {
   p <- length(feature_names)
   K <- ncol(v1)
   K_names <- colnames(v1)
@@ -16,14 +29,8 @@ kernelshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact,
   v0_m_exact <- v0[rep.int(1L, m_exact), , drop = FALSE]  # (m_ex x K)
 
   # Most expensive part
-  vz <- get_vz(                                       # (m_ex x K)
-    X = rep_rows(x, rep.int(1L, nrow(bg_X_exact))),   # (m_ex*n_bg x p)
-    bg = bg_X_exact,                                  # (m_ex*n_bg x p)
-    Z = Z,                                            # (m_ex x p)
-    object = object,
-    pred_fun = pred_fun,
-    w = bg_w,
-    ...
+  vz <- get_vz(
+    x = x, bg = bg_X_exact, Z = Z, object = object, pred_fun = pred_fun, w = bg_w, ...
   )
   # Note: w is correctly replicated along columns of (vz - v0_m_exact)
   b_exact <- crossprod(Z, precalc[["w"]] * (vz - v0_m_exact))  # (p x K)
@@ -37,7 +44,6 @@ kernelshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact,
 
   # Iterative sampling part, always using A_exact and b_exact to fill up the weights
   bg_X_m <- precalc[["bg_X_m"]]  # (m*n_bg x p)
-  X <- rep_rows(x, rep.int(1L, nrow(bg_X_m)))  # (m*n_bg x p)
   v0_m <- v0[rep.int(1L, m), , drop = FALSE]  # (m x K)
   est_m <- array(
     data = 0, dim = c(max_iter, p, K), dimnames = list(NULL, feature_names, K_names)
@@ -60,7 +66,7 @@ kernelshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact,
 
     # Expensive  # (m x K)
     vz <- get_vz(
-      X = X, bg = bg_X_m, Z = Z, object = object, pred_fun = pred_fun, w = bg_w, ...
+      x = x, bg = bg_X_m, Z = Z, object = object, pred_fun = pred_fun, w = bg_w, ...
    )
 
     # The sum of weights of A_exact and input[["A"]] is 1, same for b
@@ -152,7 +158,7 @@ input_sampling <- function(p, m, deg, feature_names) {
 # - A: Exact matrix A = Z'wZ
 input_exact <- function(p, feature_names) {
   Z <- exact_Z(p, feature_names = feature_names)
-  Z <- Z[2L:nrow(Z) - 1L, , drop = FALSE]
+  Z <- Z[2L:(nrow(Z) - 1L), , drop = FALSE]
   # Each Kernel weight(j) is divided by the number of vectors z having sum(z) = j
   w <- kernel_weights(p) / choose(p, 1:(p - 1L))
   list(Z = Z, w = w[rowSums(Z)], A = exact_A(p, feature_names = feature_names))
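
The last hunk of this file is an operator-precedence fix: in R, `:` binds tighter than `-`, so `Z[2L:nrow(Z) - 1L, ]` selected rows 1 to `nrow(Z) - 1` instead of the intended rows 2 to `nrow(Z) - 1`. A tiny standalone illustration:

```r
# `:` has higher precedence than `-` in R:
n <- 5L
2L:n - 1L    # parsed as (2:5) - 1  ->  1 2 3 4  (still includes the first row)
2L:(n - 1L)  # parsed as  2:4       ->  2 3 4    (drops the first and last row)
```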

R/utils_permshap.R

Lines changed: 11 additions & 6 deletions
@@ -28,7 +28,6 @@ permshap_one <- function(
     max_iter,
     ...) {
   bg <- precalc$bg_X_rep
-  X <- rep_rows(x, rep.int(1L, times = nrow(bg)))
 
   p <- length(feature_names)
   K <- ncol(v1)
@@ -38,7 +37,7 @@ permshap_one <- function(
   if (exact) {
     Z <- precalc$Z  # ((m_ex+2) x K)
     vz <- get_vz(   # (m_ex x K)
-      X = X,
+      x = x,
       bg = bg,
       Z = Z[2L:(nrow(Z) - 1L), , drop = FALSE],  # (m_ex x p)
       object = object,
@@ -52,7 +51,7 @@ permshap_one <- function(
       pos <- precalc$positions[[j]]
       beta_n[j, ] <- wcolMeans(
         vz[pos$on, , drop = FALSE] - vz[pos$off, , drop = FALSE],
-        weights = precalc["shapley_w"][pos$on]
+        w = precalc$shapley_w[pos$on]
       )
     }
     return(list(beta = beta_n))
@@ -68,7 +67,7 @@ permshap_one <- function(
 
   # Pre-calculate part of Z with rowsum 1 or p - 1
   vz_balanced <- get_vz(  # (2p x K)
-    X = rep_rows(x, rep.int(1L, times = nrow(precalc$bg_X_balanced))),
+    x = x,
    bg = precalc$bg_X_balanced,
     Z = precalc$Z_balanced,
     object = object,
@@ -90,13 +89,19 @@ permshap_one <- function(
   if (!low_memory) {  # predictions for all chains at once
     Z <- do.call(rbind, Z)
     vz <- get_vz(
-      X = X, bg = bg, Z = Z, object = object, pred_fun = pred_fun, w = bg_w, ...
+      x = x, bg = bg, Z = Z, object = object, pred_fun = pred_fun, w = bg_w, ...
     )
   } else {  # predictions for each chain separately
     vz <- vector("list", length = p)
     for (j in seq_len(p)) {
       vz[[j]] <- get_vz(
-        X = X, bg = bg, Z = Z[[j]], object = object, pred_fun = pred_fun, w = bg_w, ...
+        x = x,
+        bg = bg,
+        Z = Z[[j]],
+        object = object,
+        pred_fun = pred_fun,
+        w = bg_w,
+        ...
       )
     }
     vz <- do.call(rbind, vz)
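
The `wcolMeans()` call fixed in the second hunk above had two issues: the argument is named `w` (as in the other `wcolMeans()` calls in this commit), and `precalc["shapley_w"]` with single brackets returns a one-element list rather than the numeric weight vector. A quick base-R reminder of that distinction:

```r
# Single vs. double extraction on a list:
precalc <- list(shapley_w = c(0.1, 0.4, 0.5))
precalc["shapley_w"]        # a list of length 1 -- subsetting it by pos$on is wrong
precalc$shapley_w[c(1, 3)]  # 0.1 0.5 -- the numeric weights actually intended
```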

backlog/compare_with_python.R

Lines changed: 13 additions & 1 deletion
@@ -31,6 +31,18 @@ system.time(
 )
 ks2
 
+bench::mark(kernelshap(fit, X_small, bg_X = bg_X, verbose=F))
+# 2.17s 1.64GB -> 1.72s 1.44GB
+
+bench::mark(kernelshap(fit, X_small, bg_X = bg_X, verbose=F, exact=F, hybrid_degree = 1))
+# 4.88s 2.79GB -> 4.38s 2.48GB
+
+bench::mark(permshap(fit, X_small, bg_X = bg_X, verbose=F))
+# 1.97s 1.64GB -> 1.8s 1.43GB
+
+bench::mark(permshap(fit, X_small, bg_X = bg_X, verbose=F, exact=F))
+# 3.04s 1.88GB -> 2.9s 1.64GB
+
 # SHAP values of first 2 observations:
 #          carat     clarity     color        cut
 # [1,] -2.050074 -0.28048747 0.1281222 0.01587382
@@ -54,7 +66,7 @@ fit <- lm(
 X_small <- diamonds[seq(1, nrow(diamonds), 53), setdiff(names(diamonds), "price")]
 
 # Exact KernelSHAP on X_small, using X_small as background data
-# (39s for exact, 15s for hybrid deg 2, 8s for hybrid deg 1, 16s for sampling)
+# (38s for exact, 11s for hybrid deg 2, 7s for hybrid deg 1, 11s for sampling)
 system.time(
   ks <- kernelshap(fit, X_small, bg_X = bg_X)
 )

man/kernelshap.Rd

Lines changed: 1 addition & 2 deletions
Some generated files are not rendered by default.

man/permshap.Rd

Lines changed: 1 addition & 2 deletions
Some generated files are not rendered by default.

tests/testthat/test-basic.R

Lines changed: 0 additions & 6 deletions
@@ -175,12 +175,6 @@ test_that("Matrix input is fine", {
     expect_no_error(  # additional cols in bg are ok
       algo(fit, X[J, x], pred_fun = pred_fun, bg_X = cbind(d = 1, X), verbose = FALSE)
     )
-    expect_error(  # feature_names are less flexible
-      algo(fit, X[J, ],
-        pred_fun = pred_fun, bg_X = X,
-        verbose = FALSE, feature_names = "Sepal.Width"
-      )
-    )
   }
 })
