Skip to content

Commit 78da770

Browse files
committed
New kernelshap solver
1 parent db1b342 commit 78da770

File tree

5 files changed

+33
-8
lines changed

5 files changed

+33
-8
lines changed

DESCRIPTION

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ RoxygenNote: 7.3.2
2828
Imports:
2929
doFuture,
3030
foreach,
31-
MASS,
3231
stats,
3332
utils
3433
Suggests:

NEWS.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,23 @@ The list is passed to `[foreach::foreach(.options.future = ...)]`.
3131
### Internal changes
3232

3333
- Matrices holding on-off vectors are now consistently of type logical ([#167](https://github.com/ModelOriented/kernelshap/pull/167)).
34+
- `kernelshap()` solver: Replacing the Moore-Penrose pseudo-inverse by two direct solves, a trick of [Ian Covert](https://github.com/iancovert/shapley-regression/blob/master/shapreg/shapley.py),
35+
and ported to R in ([#171](https://github.com/ModelOriented/kernelshap/pull/171)).
3436

3537
### Changes in parallelism
3638

3739
We have switched from `%dopar%` to `doFuture` ([#170](https://github.com/ModelOriented/kernelshap/pull/170)) with the following impact:
3840

3941
- No need for calling `registerDoFuture()` anymore.
4042
- Random seeding is properly handled, and respects `seed`, thanks [#163](https://github.com/ModelOriented/kernelshap/issues/163) for reporting.
41-
- {doFuture} is listed under "imports", not as "suggested".
4243
- If missing packages or globals have to be specified, this now has to be done through `parallel_args = list(packages = ..., globals = ...)`
4344
instead of `parallel_args = list(.packages = ..., .globals = ...)`. The list is passed to `[foreach::foreach(.options.future = ...)]`.
4445

46+
### Dependencies
47+
48+
- {MASS}: Dropped from imports
49+
- {doFuture}: suggests -> imports
50+
4551
# kernelshap 0.8.0
4652

4753
### Major improvement

R/utils_kernelshap.R

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,15 @@ kernelshap_one <- function(
9595

9696
# Regression coefficients given sum(beta) = constraint
9797
# A: (p x p), b: (p x k), constraint: (1 x K)
98+
# Full credits: https://github.com/iancovert/shapley-regression/blob/master/shapreg/shapley.py
9899
solver <- function(A, b, constraint) {
99-
p <- ncol(A)
100-
Ainv <- MASS::ginv(A)
101-
dimnames(Ainv) <- dimnames(A)
102-
s <- (matrix(colSums(Ainv %*% b), nrow = 1L) - constraint) / sum(Ainv) # (1 x K)
103-
Ainv %*% (b - s[rep.int(1L, p), , drop = FALSE]) # (p x K)
100+
Ainv1 <- solve(A, matrix(1, nrow = nrow(A)))
101+
Ainvb <- solve(A, b)
102+
num <- rbind(colSums(Ainvb)) - constraint
103+
return(Ainvb - Ainv1 %*% num / sum(Ainv1))
104104
}
105105

106+
106107
# Draw m binary vectors z of length p with sum(z) distributed according
107108
# to Kernel SHAP weights -> (m x p) matrix.
108109
# The argument S can be used to restrict the range of sum(z).

packaging.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ use_description(
3939

4040
use_package("doFuture", "Imports")
4141
use_package("foreach", "Imports")
42-
use_package("MASS", "Imports")
4342
use_package("stats", "Imports")
4443
use_package("utils", "Imports")
4544

tests/testthat/test-kernelshap-utils.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,23 @@ test_that("input_partly_exact(p, deg) fails for bad p or deg", {
128128
expect_error(input_partly_exact(2L, deg = 0L, feature_names = LETTERS[1:p]))
129129
expect_error(input_partly_exact(5L, deg = 3L, feature_names = LETTERS[1:p]))
130130
})
131+
132+
test_that("new solver gives same results as original one", {
133+
solver_old <- function(A, b, constraint) {
134+
p <- ncol(A)
135+
Ainv <- solve(A) # was actually: Ainv <- MASS::ginv(A)
136+
dimnames(Ainv) <- dimnames(A)
137+
s <- (matrix(colSums(Ainv %*% b), nrow = 1L) - constraint) / sum(Ainv) # (1 x K)
138+
Ainv %*% (b - s[rep.int(1L, p), , drop = FALSE]) # (p x K)
139+
}
140+
141+
A <- matrix(seq(0.1, 0.20, length.out = 25), ncol = 5)
142+
diag(A) <- 0.5
143+
b <- cbind(1:5)
144+
constraint <- rbind(8)
145+
expect_equal(solver_old(A, b, constraint), solver(A, b, constraint))
146+
147+
b <- cbind(1:5, seq(2, 10, by = 2))
148+
constraint <- rbind(1:2)
149+
expect_equal(solver_old(A, b, constraint), solver(A, b, constraint))
150+
})

0 commit comments

Comments
 (0)