11# Kernel SHAP algorithm for a single row x
22# If exact, a single call to predict() is necessary.
33# If sampling is involved, we need at least two additional calls to predict().
4- kernelshap_one <- function (x , v1 , object , pred_fun , feature_names , bg_w , exact , deg ,
5- m , tol , max_iter , v0 , precalc , ... ) {
4+ kernelshap_one <- function (
5+ x ,
6+ v1 ,
7+ object ,
8+ pred_fun ,
9+ feature_names ,
10+ bg_w ,
11+ exact ,
12+ deg ,
13+ m ,
14+ tol ,
15+ max_iter ,
16+ v0 ,
17+ precalc ,
18+ ... ) {
619 p <- length(feature_names )
720 K <- ncol(v1 )
821 K_names <- colnames(v1 )
@@ -16,14 +29,8 @@ kernelshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact,
1629 v0_m_exact <- v0 [rep.int(1L , m_exact ), , drop = FALSE ] # (m_ex x K)
1730
1831 # Most expensive part
19- vz <- get_vz( # (m_ex x K)
20- X = rep_rows(x , rep.int(1L , nrow(bg_X_exact ))), # (m_ex*n_bg x p)
21- bg = bg_X_exact , # (m_ex*n_bg x p)
22- Z = Z , # (m_ex x p)
23- object = object ,
24- pred_fun = pred_fun ,
25- w = bg_w ,
26- ...
32+ vz <- get_vz(
33+ x = x , bg = bg_X_exact , Z = Z , object = object , pred_fun = pred_fun , w = bg_w , ...
2734 )
2835 # Note: w is correctly replicated along columns of (vz - v0_m_exact)
2936 b_exact <- crossprod(Z , precalc [[" w" ]] * (vz - v0_m_exact )) # (p x K)
@@ -37,7 +44,6 @@ kernelshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact,
3744
3845 # Iterative sampling part, always using A_exact and b_exact to fill up the weights
3946 bg_X_m <- precalc [[" bg_X_m" ]] # (m*n_bg x p)
40- X <- rep_rows(x , rep.int(1L , nrow(bg_X_m ))) # (m*n_bg x p)
4147 v0_m <- v0 [rep.int(1L , m ), , drop = FALSE ] # (m x K)
4248 est_m <- array (
4349 data = 0 , dim = c(max_iter , p , K ), dimnames = list (NULL , feature_names , K_names )
@@ -60,7 +66,7 @@ kernelshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact,
6066
6167 # Expensive # (m x K)
6268 vz <- get_vz(
63- X = X , bg = bg_X_m , Z = Z , object = object , pred_fun = pred_fun , w = bg_w , ...
69+ x = x , bg = bg_X_m , Z = Z , object = object , pred_fun = pred_fun , w = bg_w , ...
6470 )
6571
6672 # The sum of weights of A_exact and input[["A"]] is 1, same for b
@@ -151,7 +157,8 @@ input_sampling <- function(p, m, deg, feature_names) {
151157# the SHAP kernel distribution
152158# - A: Exact matrix A = Z'wZ
153159input_exact <- function (p , feature_names ) {
154- Z <- exact_Z(p , feature_names = feature_names , keep_extremes = FALSE )
160+ Z <- exact_Z(p , feature_names = feature_names )
161+ Z <- Z [2L : (nrow(Z ) - 1L ), , drop = FALSE ]
155162 # Each Kernel weight(j) is divided by the number of vectors z having sum(z) = j
156163 w <- kernel_weights(p ) / choose(p , 1 : (p - 1L ))
157164 list (Z = Z , w = w [rowSums(Z )], A = exact_A(p , feature_names = feature_names ))
0 commit comments