@@ -108,7 +108,7 @@ solver <- function(A, b, constraint) {
108108# to Kernel SHAP weights -> (m x p) matrix.
109109# The argument S can be used to restrict the range of sum(z).
110110sample_Z <- function (p , m , feature_names , S = 1 : (p - 1L )) {
111- probs <- kernel_weights( p , per_coalition_size = TRUE , S = S )
111+ probs <- kernel_weights_per_coalition_size( p , S = S )
112112 N <- S [sample.int(length(S ), m , replace = TRUE , prob = probs )]
113113
114114 # Then, conditional on that number, set random positions of z to 1
@@ -159,7 +159,7 @@ input_sampling <- function(p, m, deg, feature_names) {
159159input_exact <- function (p , feature_names ) {
160160 Z <- exact_Z(p , feature_names = feature_names )
161161 Z <- Z [2L : (nrow(Z ) - 1L ), , drop = FALSE ]
162- kw <- kernel_weights(p , per_coalition_size = FALSE ) # Kernel weights for all subsets
162+ kw <- kernel_weights(p ) # Kernel weights for all subsets
163163 w <- kw [rowSums(Z )] # Corresponding weight for each row in Z
164164 w <- w / sum(w )
165165 list (Z = Z , w = w , A = crossprod(Z , w * Z ))
@@ -204,7 +204,7 @@ input_partly_exact <- function(p, deg, feature_names) {
204204 stop(" p must be >=2*deg" )
205205 }
206206
207- kw <- kernel_weights(p , per_coalition_size = FALSE )
207+ kw <- kernel_weights(p )
208208
209209 Z <- vector(" list" , deg )
210210 for (k in seq_len(deg )) {
@@ -217,16 +217,17 @@ input_partly_exact <- function(p, deg, feature_names) {
217217 list (Z = Z , w = w , A = crossprod(Z , w * Z ))
218218}
219219
220- # Kernel weight distribution
221- #
222- # `per_coalition_size = TRUE` is required, e.g., when one wants to sample random masks
223- # according to the Kernel SHAP distribution: Pick a coalition size as per
224- # these weights, then randomly place "on" positions. `FALSE` refer to weights
225- # if all masks has been calculated and one wants to calculate their weights based
226- # on the number of "on" positions.
227- kernel_weights <- function (p , per_coalition_size , S = seq_len(p - 1L )) {
228- const <- if (per_coalition_size ) 1 else choose(p , S )
229- probs <- (p - 1 ) / (const * S * (p - S )) # could drop the numerator
220+ # Kernel weight distribution. Gives the weight of each coalition vector of sum k
221+ kernel_weights <- function (p ) {
222+ S <- seq_len(p - 1L )
223+ probs <- 1 / (choose(p , S ) * S * (p - S ))
224+ return (probs / sum(probs ))
225+ }
226+
227+ # Kernel weights per coalition size. Sums the kernel_weights over the number of
228+ # coalitions with same sum.
229+ kernel_weights_per_coalition_size <- function (p , S = seq_len(p - 1L )) {
230+ probs <- 1 / (S * (p - S ))
230231 return (probs / sum(probs ))
231232}
232233
@@ -236,7 +237,7 @@ prop_exact <- function(p, deg) {
236237 if (deg == 0 ) {
237238 return (0 )
238239 }
239- w <- kernel_weights( p , per_coalition_size = TRUE )
240+ w <- kernel_weights_per_coalition_size( p )
240241 w_total <- 2 * sum(w [seq_len(deg )]) - w [deg ] * (p == 2 * deg )
241242 return (w_total )
242243}
0 commit comments