Skip to content

Commit 27d92bf

Browse files
committed
Add test to cover model with high interaction
1 parent 895dfa7 commit 27d92bf

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

tests/testthat/test-basic.R

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,124 @@ test_that("kernelshap works for large p (hybrid case)", {
213213
expect_equal(s$baseline, mean(y))
214214
expect_equal(rowSums(s$S) + s$baseline, unname(predict(fit, X[1L, ])))
215215
})
216+
217+
test_that("kernelshap and permshap work for models with high-order interactions", {
218+
# Expected: Python output
219+
# import numpy as np
220+
# import shap 0.47.2
221+
#
222+
# X = np.array(
223+
# [
224+
# np.arange(1, 101) / 100,
225+
# np.log(np.arange(1, 101)),
226+
# np.sqrt(np.arange(1, 101)),
227+
# np.sin(np.arange(1, 101)),
228+
# (np.arange(1, 101) / 100) ** 2,
229+
# np.cos(np.arange(1, 101)),
230+
# ]
231+
# ).T
232+
#
233+
#
234+
# def predict(X):
235+
# return X[:, 0] * X[:, 1] * X[:, 2] * X[:, 3] + X[:, 4] + X[:, 5]
236+
#
237+
#
238+
# ks = shap.explainers.Kernel(predict, X, nsamples=10000)
239+
# es = shap.explainers.Exact(predict, X)
240+
#
241+
# print("Exact Kernel SHAP:\n", ks(X[0:2]).values)
242+
# print("Exact (Permutation) SHAP:\n", es(X[0:2]).values)
243+
#
244+
# # Exact Kernel SHAP:
245+
# # [[-1.19621609 -1.24184808 -0.9567848 3.87942037 -0.33825 0.54562519]
246+
# # [-1.64922699 -1.20770105 -1.18388581 4.54321217 -0.33795 -0.41082395]]
247+
# # Exact (Permutation) SHAP:
248+
# # [[-1.19621609 -1.24184808 -0.9567848 3.87942037 -0.33825 0.54562519]
249+
# # [-1.64922699 -1.20770105 -1.18388581 4.54321217 -0.33795 -0.41082395]]
250+
251+
expected <- rbind(
252+
c(-1.19621609, -1.24184808, -0.9567848, 3.87942037, -0.33825, 0.54562519),
253+
c(-1.64922699, -1.20770105, -1.18388581, 4.54321217, -0.33795, -0.41082395)
254+
)
255+
256+
n <- 100
257+
258+
X <- data.frame(
259+
x1 = seq(1:n) / 100,
260+
x2 = log(1:n),
261+
x3 = sqrt(1:n),
262+
x4 = sin(1:n),
263+
x5 = (seq(1:n) / 100)^2,
264+
x6 = cos(1:n)
265+
)
266+
267+
pf <- function(model, newdata) {
268+
x <- newdata
269+
x[, 1] * x[, 2] * x[, 3] * x[, 4] + x[, 5] + x[, 6]
270+
}
271+
ks <- kernelshap(pf, head(X, 2), bg_X = X, pred_fun = pf, verbose = FALSE)
272+
expect_equal(unname(ks$S), expected)
273+
274+
ps <- permshap(pf, head(X, 2), bg_X = X, pred_fun = pf, verbose = FALSE)
275+
expect_equal(unname(ps$S), expected)
276+
277+
# Sampling versions of KernelSHAP is quite close
278+
set.seed(1)
279+
ksh2 <- kernelshap(
280+
pf,
281+
head(X, 1),
282+
bg_X = X,
283+
pred_fun = pf,
284+
hybrid_degree = 2,
285+
exact = FALSE,
286+
m = 1000,
287+
max_iter = 100, ,
288+
tol = 0.001,
289+
verbose = FALSE
290+
)
291+
expect_equal(
292+
c(ksh2$S),
293+
c(-1.194878, -1.24747, -0.9596389, 3.883523, -0.3349787, 0.5453894),
294+
tolerance = 1e-4
295+
)
296+
297+
set.seed(1)
298+
ksh1 <- kernelshap(
299+
pf,
300+
head(X, 1),
301+
bg_X = X,
302+
pred_fun = pf,
303+
hybrid_degree = 1,
304+
exact = FALSE,
305+
m = 1000,
306+
max_iter = 1000,
307+
tol = 0.001,
308+
verbose = FALSE
309+
)
310+
expect_equal(
311+
c(ksh1$S),
312+
c(-1.199874, -1.23913, -0.9575389, 3.884381, -0.3451588, 0.5492675),
313+
tolerance = 1e-4
314+
)
315+
316+
set.seed(1)
317+
ksh0 <- suppressWarnings(
318+
kernelshap(
319+
pf,
320+
head(X, 1),
321+
bg_X = X,
322+
pred_fun = pf,
323+
hybrid_degree = 0,
324+
exact = FALSE,
325+
m = 10000,
326+
max_iter = 10000,
327+
tol = 0.002,
328+
verbose = FALSE
329+
)
330+
)
331+
expect_equal(
332+
c(ksh0$S),
333+
c(-1.191228, -1.235814, -0.9362117, 3.849839, -0.3371862, 0.5425477),
334+
tolerance = 1e-4
335+
)
336+
})

0 commit comments

Comments
 (0)