@@ -107,8 +107,7 @@ solver <- function(A, b, constraint) {
107107# to Kernel SHAP weights -> (m x p) matrix.
108108# The argument S can be used to restrict the range of sum(z).
109109sample_Z <- function (p , m , feature_names , S = 1 : (p - 1L )) {
110- # First draw s = sum(z) according to Kernel weights (renormalized to sum 1)
111- probs <- kernel_weights(p , S = S )
110+ probs <- kernel_weights_per_coalition_size(p , S = S )
112111 N <- S [sample.int(length(S ), m , replace = TRUE , prob = probs )]
113112
114113 # Then, conditional on that number, set random positions of z to 1
@@ -144,8 +143,8 @@ input_sampling <- function(p, m, deg, feature_names) {
144143 S <- (deg + 1L ): (p - deg - 1L )
145144 Z <- sample_Z(p = p , m = m / 2 , feature_names = feature_names , S = S )
146145 Z <- rbind(Z , ! Z )
147- w_total <- if (deg == 0L ) 1 else 1 - 2 * sum(kernel_weights( p )[seq_len( deg )] )
148- w <- w_total / m
146+ w <- if (deg == 0L ) 1 else 1 - prop_exact( p , deg = deg )
147+ w <- w / m
149148 list (Z = Z , w = rep.int(w , m ), A = crossprod(Z ) * w )
150149}
151150
@@ -159,33 +158,10 @@ input_sampling <- function(p, m, deg, feature_names) {
159158input_exact <- function (p , feature_names ) {
160159 Z <- exact_Z(p , feature_names = feature_names )
161160 Z <- Z [2L : (nrow(Z ) - 1L ), , drop = FALSE ]
162- # Each Kernel weight(j) is divided by the number of vectors z having sum(z) = j
163- w <- kernel_weights(p ) / choose(p , 1 : (p - 1L ))
164- list (Z = Z , w = w [rowSums(Z )], A = exact_A(p , feature_names = feature_names ))
165- }
166-
167- # ' Exact Matrix A
168- # '
169- # ' Internal function that calculates exact A.
170- # ' Notice the difference to the off-diagnonals in the Supplement of
171- # ' Covert and Lee (2021). Credits to David Watson for figuring out the correct formula,
172- # ' see our discussions in https://github.com/ModelOriented/kernelshap/issues/22
173- # '
174- # ' @noRd
175- # ' @keywords internal
176- # '
177- # ' @param p Number of features.
178- # ' @param feature_names Feature names.
179- # ' @returns A (p x p) matrix.
180- exact_A <- function (p , feature_names ) {
181- S <- 1 : (p - 1L )
182- c_pr <- S * (S - 1 ) / p / (p - 1 )
183- off_diag <- sum(kernel_weights(p ) * c_pr )
184- A <- matrix (
185- data = off_diag , nrow = p , ncol = p , dimnames = list (feature_names , feature_names )
186- )
187- diag(A ) <- 0.5
188- A
161+ kw <- kernel_weights(p ) # Kernel weights for all subsets
162+ w <- kw [rowSums(Z )] # Corresponding weight for each row in Z
163+ w <- w / sum(w )
164+ list (Z = Z , w = w , A = crossprod(Z , w * Z ))
189165}
190166
191167# List all length p vectors z with sum(z) in {k, p - k}
@@ -228,22 +204,37 @@ input_partly_exact <- function(p, deg, feature_names) {
228204 }
229205
230206 kw <- kernel_weights(p )
231- Z <- w <- vector(" list" , deg )
232207
208+ Z <- vector(" list" , deg )
233209 for (k in seq_len(deg )) {
234210 Z [[k ]] <- partly_exact_Z(p , k = k , feature_names = feature_names )
235- n <- nrow(Z [[k ]])
236- w_tot <- kw [k ] * (2 - (p == 2L * k ))
237- w [[k ]] <- rep.int(w_tot / n , n )
238211 }
239- w <- unlist(w , recursive = FALSE , use.names = FALSE )
240212 Z <- do.call(rbind , Z )
241-
213+ w <- kw [rowSums(Z )]
214+ w_target <- prop_exact(p , deg = deg ) # How much of total weight to spend here
215+ w <- w / sum(w ) * w_target
242216 list (Z = Z , w = w , A = crossprod(Z , w * Z ))
243217}
244218
245- # Kernel weights normalized to a non-empty subset S of {1, ..., p-1}
219+ # Kernel weights
246220kernel_weights <- function (p , S = seq_len(p - 1L )) {
247221 probs <- (p - 1L ) / (choose(p , S ) * S * (p - S ))
248- probs / sum(probs )
222+ return (probs / sum(probs ))
223+ }
224+
225+ # Kernel weights per coalition size
226+ kernel_weights_per_coalition_size <- function (p , S = seq_len(p - 1L )) {
227+ probs <- 1 / (S * (p - S ))
228+ return (probs / sum(probs ))
229+ }
230+
231+ # How much Kernel SHAP weights do coalitions of size
232+ # {1, ..., deg, ..., p-deg-1 ..., p-1} have?
233+ prop_exact <- function (p , deg ) {
234+ if (deg == 0 ) {
235+ return (0 )
236+ }
237+ w <- kernel_weights_per_coalition_size(p )
238+ w_total <- 2 * sum(w [seq_len(deg )]) - w [deg ] * (p == 2 * deg )
239+ return (w_total )
249240}
0 commit comments