Skip to content

Commit 47b9c22

Browse files
committed
Logical on-off matrices
1 parent edbd7db commit 47b9c22

File tree

5 files changed

+20
-18
lines changed

5 files changed

+20
-18
lines changed

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
parallel mode, thanks [#163](https://github.com/ModelOriented/kernelshap/issues/163)
1919
for reporting. We have added a note in the function documentation that this warning
2020
can be ignored.
21+
22+
### Internal changes
23+
24+
- Matrices holding on-off vectors are now consistently of type logical ([#167](https://github.com/ModelOriented/kernelshap/pull/167)).
2125

2226
# kernelshap 0.8.0
2327

R/utils.R

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,14 @@ wcolMeans <- function(x, w = NULL, ...) {
4949
#'
5050
#' @param p Number of features.
5151
#' @param feature_names Feature names.
52-
#' @returns An integer matrix of all on-off vectors of length `p`.
52+
#' @returns An logical matrix of all on-off vectors of length `p`.
5353
exact_Z <- function(p, feature_names) {
5454
if (p < 2L) {
5555
stop("p must be at least 2 if exact = TRUE.")
5656
}
5757
m <- 2^p
5858
M <- seq_len(m) - 1L
59-
encoded <- as.integer(intToBits(M))
59+
encoded <- as.logical(intToBits(M))
6060
dim(encoded) <- c(32L, m)
6161
Z <- t(encoded[p:1L, , drop = FALSE])
6262
colnames(Z) <- feature_names
@@ -74,15 +74,12 @@ exact_Z <- function(p, feature_names) {
7474
#' @inheritParams kernelshap
7575
#' @param x Row to be explained.
7676
#' @param bg Background data stacked m times.
77-
#' @param Z A (m x p) matrix with on-off values (logical or integer).
77+
#' @param Z A logical (m x p) matrix with on-off values.
7878
#' @param w A vector with case weights (of the same length as the unstacked
7979
#' background data).
8080
#' @returns A (m x K) matrix with vz values.
8181
get_vz <- function(x, bg, Z, object, pred_fun, w, ...) {
8282
m <- nrow(Z)
83-
if (!is.logical(Z)) {
84-
storage.mode(Z) <- "logical"
85-
}
8683
n_bg <- nrow(bg) / m # because bg was replicated m times
8784

8885
# Replicate Z, so that bg and Z are of dimension (m*n_bg x p)

R/utils_kernelshap.R

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ sample_Z <- function(p, m, feature_names, S = 1:(p - 1L)) {
121121
# t(out)
122122

123123
# Vectorized by Mathias Ambuehl
124-
out <- rep(rep.int(0:1, m), as.vector(rbind(p - N, N)))
124+
out <- rep(rep.int(c(FALSE, TRUE), m), as.vector(rbind(p - N, N)))
125125
dim(out) <- c(p, m)
126126
ord <- order(col(out), sample.int(m * p))
127127
out[] <- out[ord]
@@ -143,7 +143,7 @@ input_sampling <- function(p, m, deg, feature_names) {
143143
}
144144
S <- (deg + 1L):(p - deg - 1L)
145145
Z <- sample_Z(p = p, m = m / 2, feature_names = feature_names, S = S)
146-
Z <- rbind(Z, 1 - Z)
146+
Z <- rbind(Z, !Z)
147147
w_total <- if (deg == 0L) 1 else 1 - 2 * sum(kernel_weights(p)[seq_len(deg)])
148148
w <- w_total / m
149149
list(Z = Z, w = rep.int(w, m), A = crossprod(Z) * w)
@@ -198,17 +198,18 @@ partly_exact_Z <- function(p, k, feature_names) {
198198
}
199199
if (k == 1L) {
200200
Z <- diag(p)
201+
storage.mode(Z) <- "logical"
201202
} else {
202203
Z <- t(
203204
utils::combn(seq_len(p), k, FUN = function(z) {
204-
x <- numeric(p)
205-
x[z] <- 1
205+
x <- logical(p)
206+
x[z] <- TRUE
206207
x
207208
})
208209
)
209210
}
210211
if (p != 2L * k) {
211-
Z <- rbind(Z, 1 - Z)
212+
Z <- rbind(Z, !Z)
212213
}
213214
colnames(Z) <- feature_names
214215
Z

R/utils_permshap.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,15 +240,15 @@ init_vzj <- function(p, v0, v1) {
240240
#' @noRd
241241
#' @keywords internal
242242
#'
243-
#' @param mask The output of `exact_Z(p, feature_names)`.
243+
#' @param mask Logical matrix. The output of `exact_Z(p, feature_names)`.
244244
#' @returns List with p elements, each containing an `on` and `off` vector.
245245
positions_for_exact <- function(mask) {
246246
p <- ncol(mask)
247247
codes <- seq_len(nrow(mask)) # Row index = binary code of the row
248248

249249
positions <- vector("list", p)
250250
for (j in seq_len(p)) {
251-
on <- codes[as.logical(mask[, j])]
251+
on <- codes[mask[, j]]
252252
off <- on - 2^(p - j) # trick to turn "bit" off
253253
positions[[j]] <- list(on = on, off = off)
254254
}

backlog/compare_with_python.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ bg_X <- diamonds[seq(1, nrow(diamonds), 450), ]
1414
# Subset of 1018 diamonds to explain
1515
X_small <- diamonds[seq(1, nrow(diamonds), 53), c("carat", ord)]
1616

17-
# Exact KernelSHAP (2s)
17+
# Exact KernelSHAP (1.8s)
1818
system.time(
1919
ks <- kernelshap(fit, X_small, bg_X = bg_X)
2020
)
@@ -25,23 +25,23 @@ ks
2525
# [1,] -2.050074 -0.28048747 0.1281222 0.01587382
2626
# [2,] -2.085838 0.04050415 0.1283010 0.03731644
2727

28-
# Pure sampling version takes a bit longer (7 seconds)
28+
# Pure sampling version takes a bit longer (6.6 seconds)
2929
system.time(
3030
ks2 <- kernelshap(fit, X_small, bg_X = bg_X, exact = FALSE, hybrid_degree = 0)
3131
)
3232
ks2
3333

3434
bench::mark(kernelshap(fit, X_small, bg_X = bg_X, verbose=F))
35-
# 2.17s 1.64GB -> 1.72s 1.44GB
35+
# 2.17s 1.64GB -> 1.79s 1.43GB
3636

3737
bench::mark(kernelshap(fit, X_small, bg_X = bg_X, verbose=F, exact=F, hybrid_degree = 1))
3838
# 4.88s 2.79GB -> 4.38s 2.48GB
3939

4040
bench::mark(permshap(fit, X_small, bg_X = bg_X, verbose=F))
41-
# 1.97s 1.64GB -> 1.8s 1.43GB
41+
# 1.97s 1.64GB -> 1.9s 1.43GB
4242

4343
bench::mark(permshap(fit, X_small, bg_X = bg_X, verbose=F, exact=F))
44-
# 3.04s 1.88GB -> 2.9s 1.64GB
44+
# 3.04s 1.88GB -> 2.8s 1.64GB
4545

4646
# SHAP values of first 2 observations:
4747
# carat clarity color cut

0 commit comments

Comments
 (0)