Commit db1b342

Merge pull request #170 from ModelOriented/seed-parallel
Add seed and change parallelism
2 parents b9ff089 + 60ca1a3 commit db1b342

12 files changed, with 148 additions and 79 deletions.

DESCRIPTION

Lines changed: 1 addition & 1 deletion
@@ -26,12 +26,12 @@ Encoding: UTF-8
 Roxygen: list(markdown = TRUE)
 RoxygenNote: 7.3.2
 Imports:
+    doFuture,
     foreach,
     MASS,
     stats,
     utils
 Suggests:
-    doFuture,
     testthat (>= 3.0.0)
 Config/testthat/edition: 3
 URL: https://github.com/ModelOriented/kernelshap

NAMESPACE

Lines changed: 1 addition & 1 deletion
@@ -10,4 +10,4 @@ export(additive_shap)
 export(is.kernelshap)
 export(kernelshap)
 export(permshap)
-importFrom(foreach,"%dopar%")
+importFrom(doFuture,"%dofuture%")
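
For readers unfamiliar with the operator swapped in above, here is a minimal sketch of `%dofuture%` used outside of {kernelshap}. It assumes that {doFuture} (version 1.0.0 or later) is installed and that a `multisession` backend is available. The helper `draw_once()` is hypothetical and not part of the package.

```r
library(doFuture)  # also attaches {foreach} and {future}

# Choose a backend via plan(); with %dofuture%, no registerDoFuture() call is needed.
plan(multisession, workers = 2)

# Hypothetical helper: draw one random number per iteration in parallel.
draw_once <- function(seed) {
  set.seed(seed)
  foreach(
    i = 1:4,
    .combine = c,
    .options.future = list(seed = TRUE)  # request parallel-safe RNG streams
  ) %dofuture% {
    rnorm(1)
  }
}

a <- draw_once(1)
b <- draw_once(1)
identical(a, b)  # compares two runs started from the same seed

plan(sequential)  # release the workers
```

Options such as `packages` and `globals` would go into the same `.options.future` list, which is where this commit routes `parallel_args`.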

NEWS.md

Lines changed: 15 additions & 8 deletions
@@ -15,26 +15,33 @@ unit tests against Python's "shap".
 ### API

 - The argument `feature_names` can now also be used with matrix input ([#166](https://github.com/ModelOriented/kernelshap/pull/166)).
+- `kernelshap()` and `permshap()` have received a `seed = NULL` argument ([#170](https://github.com/ModelOriented/kernelshap/pull/170)).
+- Parallel mode: If missing packages or globals have to be specified, this now has to be done through `parallel_args = list(packages = ..., globals = ...)`
+instead of `parallel_args = list(.packages = ..., .globals = ...)`, see section on parallelism below.
+The list is passed to `[foreach::foreach(.options.future = ...)]`.

 ### Speed and memory improvements

 - `permshap()` and `kernelshap()` require about 10% less memory ([#166](https://github.com/ModelOriented/kernelshap/pull/166)).
 - `permshap()` and `kernelshap()` are faster for data.frame input,
 and slightly slower for matrix input ([#166](https://github.com/ModelOriented/kernelshap/pull/166)).
 - Additionally, `permshap(, exact = TRUE)` is faster by pre-calculating more
-elements used across rows [#165](https://github.com/ModelOriented/kernelshap/pull/165)
+elements used across rows ([#165](https://github.com/ModelOriented/kernelshap/pull/165)).

-### Documentation
-
-- `kernelshap()` and `permshap()` currently yield a warning on random seed handling in
-parallel mode, thanks [#163](https://github.com/ModelOriented/kernelshap/issues/163)
-for reporting. We have added a note in the function documentation that this warning
-can be ignored.
-
 ### Internal changes

 - Matrices holding on-off vectors are now consistently of type logical ([#167](https://github.com/ModelOriented/kernelshap/pull/167)).

+### Changes in parallelism
+
+We have switched from `%dopar%` to `doFuture` ([#170](https://github.com/ModelOriented/kernelshap/pull/170)) with the following impact:
+
+- No need for calling `registerDoFuture()` anymore.
+- Random seeding is properly handled, and respects `seed`, thanks [#163](https://github.com/ModelOriented/kernelshap/issues/163) for reporting.
+- {doFuture} is listed under "imports", not as "suggested".
+- If missing packages or globals have to be specified, this now has to be done through `parallel_args = list(packages = ..., globals = ...)`
+instead of `parallel_args = list(.packages = ..., .globals = ...)`. The list is passed to `[foreach::foreach(.options.future = ...)]`.
+
 # kernelshap 0.8.0

 ### Major improvement

R/kernelshap.R

Lines changed: 21 additions & 11 deletions
@@ -54,7 +54,7 @@
 #' should not be higher than 10 for exact calculations.
 #' For similar reasons, degree 2 hybrids should not use \eqn{p} larger than 40.
 #'
-#' @importFrom foreach %dopar%
+#' @importFrom doFuture %dofuture%
 #'
 #' @param object Fitted model object.
 #' @param X \eqn{(n \times p)} matrix or `data.frame` with rows to be explained.
@@ -105,16 +105,18 @@
 #'   For `permshap()`, the default is 0.01, while for `kernelshap()` it is set to 0.005.
 #' @param max_iter If the stopping criterion (see `tol`) is not reached after
 #'   `max_iter` iterations, the algorithm stops. Ignored if `exact = TRUE`.
-#' @param parallel If `TRUE`, use parallel [foreach::foreach()] to loop over rows
-#'   to be explained. Must register backend beforehand, e.g., via 'doFuture' package,
-#'   see README for an example. Parallelization automatically disables the progress bar.
-#' @param parallel_args Named list of arguments passed to [foreach::foreach()].
-#'   Ideally, this is `NULL` (default). Only relevant if `parallel = TRUE`.
+#' @param parallel If `TRUE`, use [foreach::foreach()] and `%dofuture%` to loop over rows
+#'   to be explained. Must register backend beforehand, e.g., `plan(multisession)`,
+#'   see README for an example. Currently disables the progress bar.
+#' @param parallel_args Named list of arguments passed to
+#'   `foreach::foreach(.options.future = ...)`, ideally `NULL` (default).
+#'   Only relevant if `parallel = TRUE`.
 #'   Example on Windows: if `object` is a GAM fitted with package 'mgcv',
-#'   then one might need to set `parallel_args = list(.packages = "mgcv")`.
-#'   The warning "unexpectedly generated random numbers" can be ignored because
-#'   sharing seeds across rows of `X` it is not a problem.
+#'   then one might need to set `parallel_args = list(packages = "mgcv")`.
+#'   Similarly, if the model has been fitted with `ranger()`, then it might be necessary
+#'   to pass `parallel_args = list(packages = "ranger")`.
 #' @param verbose Set to `FALSE` to suppress messages and the progress bar.
+#' @param seed Optional integer random seed. Note that it changes the global seed.
 #' @param survival Should cumulative hazards ("chf", default) or survival
 #'   probabilities ("prob") per time be predicted? Only in `ranger()` survival models.
 #' @param ... Additional arguments passed to `pred_fun(object, X, ...)`.
@@ -195,6 +197,7 @@ kernelshap.default <- function(
     parallel = FALSE,
     parallel_args = NULL,
     verbose = TRUE,
+    seed = NULL,
     ...) {
   p <- length(feature_names)
   basic_checks(X = X, feature_names = feature_names, pred_fun = pred_fun)
@@ -209,6 +212,10 @@ kernelshap.default <- function(
   bg_n <- nrow(bg_X)
   n <- nrow(X)

+  if (!is.null(seed)) {
+    set.seed(seed)
+  }
+
   # Calculate v1 and v0
   bg_preds <- align_pred(pred_fun(object, bg_X, ...))
   v0 <- wcolMeans(bg_preds, bg_w) # Average pred of bg data: 1 x K
@@ -254,8 +261,9 @@ kernelshap.default <- function(

   # Apply Kernel SHAP to each row of X
   if (isTRUE(parallel)) {
-    parallel_args <- c(list(i = seq_len(n)), parallel_args)
-    res <- do.call(foreach::foreach, parallel_args) %dopar% kernelshap_one(
+    future_args <- c(list(seed = TRUE), parallel_args)
+    parallel_args <- c(list(i = seq_len(n)), list(.options.future = future_args))
+    res <- do.call(foreach::foreach, parallel_args) %dofuture% kernelshap_one(
       x = X[i, , drop = FALSE],
       v1 = v1[i, , drop = FALSE],
       object = object,
@@ -350,6 +358,7 @@ kernelshap.ranger <- function(
     parallel = FALSE,
     parallel_args = NULL,
     verbose = TRUE,
+    seed = NULL,
     survival = c("chf", "prob"),
     ...) {
   if (is.null(pred_fun)) {
@@ -372,6 +381,7 @@ kernelshap.ranger <- function(
     parallel = parallel,
     parallel_args = parallel_args,
     verbose = verbose,
+    seed = seed,
     ...
   )
}
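
The argument wrapping in the `kernelshap.default()` hunk above can be reproduced in base R. The sketch below uses nothing beyond base R. `build_foreach_args()` is a hypothetical helper name chosen for illustration and does not exist in the package.

```r
# Hypothetical helper mirroring the argument assembly shown in the diff:
# user-supplied parallel_args end up inside `.options.future`, after `seed = TRUE`.
build_foreach_args <- function(n, parallel_args = NULL) {
  future_args <- c(list(seed = TRUE), parallel_args)
  c(list(i = seq_len(n)), list(.options.future = future_args))
}

# Example: the user asks for {mgcv} to be attached on the workers.
args <- build_foreach_args(3, parallel_args = list(packages = "mgcv"))
str(args)

# With the default parallel_args = NULL, only `seed = TRUE` is passed on.
str(build_foreach_args(3))
```

Because the user list is appended after `seed = TRUE`, a user-supplied `seed` entry would show up as a second element with the same name. How that case is resolved downstream is not visible in this diff.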

R/permshap.R

Lines changed: 10 additions & 2 deletions
@@ -108,6 +108,7 @@ permshap.default <- function(
     parallel = FALSE,
     parallel_args = NULL,
     verbose = TRUE,
+    seed = NULL,
     ...) {
   p <- length(feature_names)
   if (p <= 1L) {
@@ -132,6 +133,10 @@ permshap.default <- function(
   bg_n <- nrow(bg_X)
   n <- nrow(X)

+  if (!is.null(seed)) {
+    set.seed(seed)
+  }
+
   # Baseline and predictions on explanation data
   bg_preds <- align_pred(pred_fun(object, bg_X, ...))
   v0 <- wcolMeans(bg_preds, w = bg_w) # Average pred of bg data: 1 x K
@@ -166,8 +171,9 @@ permshap.default <- function(

   # Apply permutation SHAP to each row of X
   if (isTRUE(parallel)) {
-    parallel_args <- c(list(i = seq_len(n)), parallel_args)
-    res <- do.call(foreach::foreach, parallel_args) %dopar% permshap_one(
+    future_args <- c(list(seed = TRUE), parallel_args)
+    parallel_args <- c(list(i = seq_len(n)), list(.options.future = future_args))
+    res <- do.call(foreach::foreach, parallel_args) %dofuture% permshap_one(
       x = X[i, , drop = FALSE],
       v1 = v1[i, , drop = FALSE],
       object = object,
@@ -257,6 +263,7 @@ permshap.ranger <- function(
     parallel = FALSE,
     parallel_args = NULL,
     verbose = TRUE,
+    seed = NULL,
     survival = c("chf", "prob"),
     ...) {
   if (is.null(pred_fun)) {
@@ -278,6 +285,7 @@ permshap.ranger <- function(
     parallel = parallel,
     parallel_args = parallel_args,
     verbose = verbose,
+    seed = seed,
     ...
   )
}
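
The `seed` handling added to both functions is a plain `set.seed()` call, which is why the documentation notes that it changes the global seed. Below is a minimal base R sketch of that pattern. `sample_with_seed()` is a hypothetical stand-in for the package functions, not part of {kernelshap}.

```r
# Hypothetical stand-in for a function with a `seed = NULL` argument.
sample_with_seed <- function(n, seed = NULL) {
  if (!is.null(seed)) {
    set.seed(seed)  # side effect: resets the caller's global RNG state
  }
  runif(n)
}

# Same seed, same draws.
x1 <- sample_with_seed(3, seed = 42)
x2 <- sample_with_seed(3, seed = 42)
identical(x1, x2)

# The global RNG state after the call is determined by `seed`,
# regardless of how the caller seeded the session beforehand.
set.seed(1)
invisible(sample_with_seed(3, seed = 42))
after_a <- runif(1)

set.seed(2)
invisible(sample_with_seed(3, seed = 42))
after_b <- runif(1)

identical(after_a, after_b)
```

Callers who need their previous RNG state can save `.Random.seed` before the call and restore it afterwards.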

README.md

Lines changed: 6 additions & 4 deletions
@@ -110,15 +110,17 @@ The {kernelshap} package can deal with almost any situation. We will show some o

 ### Parallel computing

-Parallel computing for `permshap()` and `kernelshap()` is supported via {foreach}. Note that this does not work for all models.
+Parallel computing for `permshap()` and `kernelshap()` is supported via {doFuture} and {foreach}.
+Note that this does not work for all models.

-On Windows, sometimes not all packages or global objects are passed to the parallel sessions. Often, this can be fixed via `parallel_args`, see this example:
+On Windows, sometimes not all packages or global objects are passed to the parallel sessions.
+Often, this can be fixed via `parallel_args` using the arguments "packages" and "globals" passed
+to `foreach(.options.future = ...)`, see this example:

 ```r
 library(doFuture)
 library(mgcv)

-registerDoFuture()
 plan(multisession, workers = 4) # Windows
 # plan(multicore, workers = 4) # Linux, macOS, Solaris

@@ -127,7 +129,7 @@ fit <- gam(log_price ~ s(log_carat) + clarity * color + cut, data = diamonds)

 system.time( # 4 seconds in parallel
   ps <- permshap(
-    fit, X, bg_X = bg_X, parallel = TRUE, parallel_args = list(.packages = "mgcv")
+    fit, X, bg_X = bg_X, parallel = TRUE, parallel_args = list(packages = "mgcv")
   )
 )
 ps

cran-comments.md

Lines changed: 1 addition & 2 deletions
@@ -17,8 +17,7 @@ Thanks a lot!

 ### `check_win_devel()`

-Status: 1 NOTE
-R Under development (unstable) (2025-07-05 r88387 ucrt)
+Status: OK

 ### Revdep OK

man/kernelshap.Rd

Lines changed: 13 additions & 8 deletions
Generated file; diff not rendered by default.

man/permshap.Rd

Lines changed: 13 additions & 8 deletions
Generated file; diff not rendered by default.

packaging.R

Lines changed: 1 addition & 2 deletions
@@ -37,13 +37,12 @@ use_description(
   roxygen = TRUE
 )

+use_package("doFuture", "Imports")
 use_package("foreach", "Imports")
 use_package("MASS", "Imports")
 use_package("stats", "Imports")
 use_package("utils", "Imports")

-use_package("doFuture", "Suggests")
-
 use_gpl_license(2)

 # Your files that do not belong to the package itself (others are added by "use_* function")
