@@ -213,3 +213,124 @@ test_that("kernelshap works for large p (hybrid case)", {
213213 expect_equal(s $ baseline , mean(y ))
214214 expect_equal(rowSums(s $ S ) + s $ baseline , unname(predict(fit , X [1L , ])))
215215})
216+
217+ test_that(" kernelshap and permshap work for models with high-order interactions" , {
218+ # Expected: Python output
219+ # import numpy as np
220+ # import shap 0.47.2
221+ #
222+ # X = np.array(
223+ # [
224+ # np.arange(1, 101) / 100,
225+ # np.log(np.arange(1, 101)),
226+ # np.sqrt(np.arange(1, 101)),
227+ # np.sin(np.arange(1, 101)),
228+ # (np.arange(1, 101) / 100) ** 2,
229+ # np.cos(np.arange(1, 101)),
230+ # ]
231+ # ).T
232+ #
233+ #
234+ # def predict(X):
235+ # return X[:, 0] * X[:, 1] * X[:, 2] * X[:, 3] + X[:, 4] + X[:, 5]
236+ #
237+ #
238+ # ks = shap.explainers.Kernel(predict, X, nsamples=10000)
239+ # es = shap.explainers.Exact(predict, X)
240+ #
241+ # print("Exact Kernel SHAP:\n", ks(X[0:2]).values)
242+ # print("Exact (Permutation) SHAP:\n", es(X[0:2]).values)
243+ #
244+ # # Exact Kernel SHAP:
245+ # # [[-1.19621609 -1.24184808 -0.9567848 3.87942037 -0.33825 0.54562519]
246+ # # [-1.64922699 -1.20770105 -1.18388581 4.54321217 -0.33795 -0.41082395]]
247+ # # Exact (Permutation) SHAP:
248+ # # [[-1.19621609 -1.24184808 -0.9567848 3.87942037 -0.33825 0.54562519]
249+ # # [-1.64922699 -1.20770105 -1.18388581 4.54321217 -0.33795 -0.41082395]]
250+
251+ expected <- rbind(
252+ c(- 1.19621609 , - 1.24184808 , - 0.9567848 , 3.87942037 , - 0.33825 , 0.54562519 ),
253+ c(- 1.64922699 , - 1.20770105 , - 1.18388581 , 4.54321217 , - 0.33795 , - 0.41082395 )
254+ )
255+
256+ n <- 100
257+
258+ X <- data.frame (
259+ x1 = seq(1 : n ) / 100 ,
260+ x2 = log(1 : n ),
261+ x3 = sqrt(1 : n ),
262+ x4 = sin(1 : n ),
263+ x5 = (seq(1 : n ) / 100 )^ 2 ,
264+ x6 = cos(1 : n )
265+ )
266+
267+ pf <- function (model , newdata ) {
268+ x <- newdata
269+ x [, 1 ] * x [, 2 ] * x [, 3 ] * x [, 4 ] + x [, 5 ] + x [, 6 ]
270+ }
271+ ks <- kernelshap(pf , head(X , 2 ), bg_X = X , pred_fun = pf , verbose = FALSE )
272+ expect_equal(unname(ks $ S ), expected )
273+
274+ ps <- permshap(pf , head(X , 2 ), bg_X = X , pred_fun = pf , verbose = FALSE )
275+ expect_equal(unname(ps $ S ), expected )
276+
277+ # Sampling versions of KernelSHAP is quite close
278+ set.seed(1 )
279+ ksh2 <- kernelshap(
280+ pf ,
281+ head(X , 1 ),
282+ bg_X = X ,
283+ pred_fun = pf ,
284+ hybrid_degree = 2 ,
285+ exact = FALSE ,
286+ m = 1000 ,
287+ max_iter = 100 , ,
288+ tol = 0.001 ,
289+ verbose = FALSE
290+ )
291+ expect_equal(
292+ c(ksh2 $ S ),
293+ c(- 1.194878 , - 1.24747 , - 0.9596389 , 3.883523 , - 0.3349787 , 0.5453894 ),
294+ tolerance = 1e-4
295+ )
296+
297+ set.seed(1 )
298+ ksh1 <- kernelshap(
299+ pf ,
300+ head(X , 1 ),
301+ bg_X = X ,
302+ pred_fun = pf ,
303+ hybrid_degree = 1 ,
304+ exact = FALSE ,
305+ m = 1000 ,
306+ max_iter = 1000 ,
307+ tol = 0.001 ,
308+ verbose = FALSE
309+ )
310+ expect_equal(
311+ c(ksh1 $ S ),
312+ c(- 1.199874 , - 1.23913 , - 0.9575389 , 3.884381 , - 0.3451588 , 0.5492675 ),
313+ tolerance = 1e-4
314+ )
315+
316+ set.seed(1 )
317+ ksh0 <- suppressWarnings(
318+ kernelshap(
319+ pf ,
320+ head(X , 1 ),
321+ bg_X = X ,
322+ pred_fun = pf ,
323+ hybrid_degree = 0 ,
324+ exact = FALSE ,
325+ m = 10000 ,
326+ max_iter = 10000 ,
327+ tol = 0.002 ,
328+ verbose = FALSE
329+ )
330+ )
331+ expect_equal(
332+ c(ksh0 $ S ),
333+ c(- 1.191228 , - 1.235814 , - 0.9362117 , 3.849839 , - 0.3371862 , 0.5425477 ),
334+ tolerance = 1e-4
335+ )
336+ })
0 commit comments