Skip to content

Commit e43f614

Browse files
committed
enable to pass additional parameters to treeshap function
1 parent eb6f82d commit e43f614

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

R/model_survshap.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ model_survshap.surv_explainer <- function(explainer,
102102
N = N,
103103
y_true = y_true,
104104
calculation_method = calculation_method,
105-
aggregation_method = aggregation_method
105+
aggregation_method = aggregation_method,
106+
...
106107
)
107108

108109
attr(shap_values, "label") <- explainer$label

R/surv_shap.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,14 +252,13 @@ use_kernelshap <- function(explainer, new_observation, output_type, N, ...) {
252252
new_observation, "data.frame")
253253
)
254254

255+
if (is.null(N)) N <- nrow(explainer$data)
255256
background_data <- explainer$data[sample(1:nrow(explainer$data), N),]
256257
# ensure that classes of explainer$data and new_observation are equal
257258
if (!inherits(background_data, "data.frame")) {
258259
background_data <- data.frame(background_data)
259260
}
260261

261-
if (is.null(N)) N <- nrow(explainer$data)
262-
263262
shap_values <- sapply(
264263
X = as.character(seq_len(nrow(new_observation))),
265264
FUN = function(i) {
@@ -327,7 +326,8 @@ use_treeshap <- function(explainer, new_observation, ...){
327326
stopifnot(identical(dim(new_obs_mat), as.integer(c(1L, ncol(new_observation)))))
328327
treeshap::treeshap(
329328
unified_model = m,
330-
x = new_obs_mat
329+
x = new_obs_mat,
330+
...
331331
)$shaps
332332
}
333333
)

0 commit comments

Comments
 (0)