@@ -107,7 +107,7 @@ solver <- function(A, b, constraint) {
107107# to Kernel SHAP weights -> (m x p) matrix.
108108# The argument S can be used to restrict the range of sum(z).
109109sample_Z <- function (p , m , feature_names , S = 1 : (p - 1L )) {
110- probs <- kernel_weights_per_coalition_size( p , S = S )
110+ probs <- kernel_weights( p , per_coalition_size = TRUE , S = S )
111111 N <- S [sample.int(length(S ), m , replace = TRUE , prob = probs )]
112112
113113 # Then, conditional on that number, set random positions of z to 1
@@ -158,7 +158,7 @@ input_sampling <- function(p, m, deg, feature_names) {
158158input_exact <- function (p , feature_names ) {
159159 Z <- exact_Z(p , feature_names = feature_names )
160160 Z <- Z [2L : (nrow(Z ) - 1L ), , drop = FALSE ]
161- kw <- kernel_weights(p ) # Kernel weights for all subsets
161+ kw <- kernel_weights(p , per_coalition_size = FALSE ) # Kernel weights for all subsets
162162 w <- kw [rowSums(Z )] # Corresponding weight for each row in Z
163163 w <- w / sum(w )
164164 list (Z = Z , w = w , A = crossprod(Z , w * Z ))
@@ -203,7 +203,7 @@ input_partly_exact <- function(p, deg, feature_names) {
203203 stop(" p must be >=2*deg" )
204204 }
205205
206- kw <- kernel_weights(p )
206+ kw <- kernel_weights(p , per_coalition_size = FALSE )
207207
208208 Z <- vector(" list" , deg )
209209 for (k in seq_len(deg )) {
@@ -216,15 +216,16 @@ input_partly_exact <- function(p, deg, feature_names) {
216216 list (Z = Z , w = w , A = crossprod(Z , w * Z ))
217217}
218218
219- # Kernel weights
220- kernel_weights <- function (p , S = seq_len(p - 1L )) {
221- probs <- (p - 1L ) / (choose(p , S ) * S * (p - S ))
222- return (probs / sum(probs ))
223- }
224-
225- # Kernel weights per coalition size
226- kernel_weights_per_coalition_size <- function (p , S = seq_len(p - 1L )) {
227- probs <- 1 / (S * (p - S ))
219+ # Kernel weight distribution
220+ #
221+ # `per_coalition_size = TRUE` is required, e.g., when one wants to sample random masks
222+ # according to the Kernel SHAP distribution: Pick a coalition size as per
223+ # these weights, then randomly place "on" positions. `FALSE` refer to weights
224+ # if all masks has been calculated and one wants to calculate their weights based
225+ # on the number of "on" positions.
226+ kernel_weights <- function (p , per_coalition_size , S = seq_len(p - 1L )) {
227+ const <- if (per_coalition_size ) 1 else choose(p , S )
228+ probs <- (p - 1 ) / (const * S * (p - S )) # could drop the numerator
228229 return (probs / sum(probs ))
229230}
230231
@@ -234,7 +235,7 @@ prop_exact <- function(p, deg) {
234235 if (deg == 0 ) {
235236 return (0 )
236237 }
237- w <- kernel_weights_per_coalition_size( p )
238+ w <- kernel_weights( p , per_coalition_size = TRUE )
238239 w_total <- 2 * sum(w [seq_len(deg )]) - w [deg ] * (p == 2 * deg )
239240 return (w_total )
240241}
0 commit comments