Skip to content

Commit cca9530

Browse files
authored
Merge branch 'treeshap' into fixes
2 parents 03f8319 + 0b2f4f5 commit cca9530

File tree

6 files changed

+202
-16
lines changed

6 files changed

+202
-16
lines changed

DESCRIPTION

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Authors@R:
77
person("Mateusz", "Krzyziński", role = c("aut"), comment = c(ORCID = "0000-0001-6143-488X")),
88
person("Sophie", "Langbein", role = c("aut")),
99
person("Hubert", "Baniecki", role = c("aut"), comment = c(ORCID = "0000-0001-6661-5364")),
10+
person("Lorenz A.", "Kapsner", role = c("ctb"), comment = c(ORCID = "0000-0003-1866-860X")),
1011
person("Przemyslaw", "Biecek", role = c("aut"), comment = c(ORCID = "0000-0001-8423-1823"))
1112
)
1213
Description: Survival analysis models are commonly used in medicine and other areas. Many of them
@@ -25,6 +26,7 @@ Imports:
2526
DALEX (>= 2.2.1),
2627
ggplot2 (>= 3.4.0),
2728
kernelshap,
29+
treeshap,
2830
pec,
2931
survival,
3032
patchwork

R/surv_shap.R

Lines changed: 130 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#' @param ... additional parameters, passed to internal functions
77
#' @param N a positive integer, number of observations used as the background data
88
#' @param y_true a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting
9-
#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements) or `"exact_kernel"` for exact Kernel SHAP estimation
9+
#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements), `"exact_kernel"` for exact Kernel SHAP estimation, or `"treeshap"` for use of `treeshap` library (efficient implementation to compute SHAP values for tree-based models).
1010
#' @param aggregation_method a character, either `"integral"`, `"integral_absolute"`, `"mean_absolute"`, `"max_absolute"`, or `"sum_of_squares"`
1111
#'
1212
#' @return A list, containing the calculated SurvSHAP(t) results in the `result` field
@@ -21,10 +21,15 @@ surv_shap <- function(explainer,
2121
...,
2222
N = NULL,
2323
y_true = NULL,
24-
calculation_method = "kernelshap",
25-
aggregation_method = "integral") {
24+
calculation_method = c("kernelshap", "exact_kernel", "treeshap"),
25+
aggregation_method = c("integral", "mean_absolute", "max_absolute", "sum_of_squares")
26+
) {
27+
calculation_method <- match.arg(calculation_method)
28+
aggregation_method <- match.arg(aggregation_method)
29+
30+
# make this code work for multiple observations
2631
stopifnot(
27-
"`y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse(
32+
"`y_true` must be either a matrix with one row per observation in `new_observation` or a vector of length == 2" = ifelse(
2833
!is.null(y_true),
2934
ifelse(
3035
is.matrix(y_true),
@@ -34,13 +39,40 @@ surv_shap <- function(explainer,
3439
TRUE
3540
)
3641
)
42+
43+
if (calculation_method == "kernelshap") {
44+
if (!requireNamespace("kernelshap", quietly = TRUE)) {
45+
stop(
46+
paste0(
47+
"Package \"kernelshap\" must be installed to use ",
48+
"'calculation_method = \"kernelshap\"'."
49+
),
50+
call. = FALSE
51+
)
52+
}
53+
}
54+
if (calculation_method == "treeshap") {
55+
if (!requireNamespace("treeshap", quietly = TRUE)) {
56+
stop(
57+
paste0(
58+
"Package \"treeshap\" must be installed to use ",
59+
"'calculation_method = \"treeshap\"'."
60+
),
61+
call. = FALSE
62+
)
63+
}
64+
}
65+
3766
test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE)
3867
# make this code also work for 1-row matrix
3968
col_index <- which(colnames(new_observation) %in% colnames(explainer$data))
4069
if (is.matrix(new_observation) && nrow(new_observation) == 1) {
41-
new_observation <- as.matrix(t(new_observation[, col_index]))
70+
new_observation <- data.frame(as.matrix(t(new_observation[, col_index])))
4271
} else {
4372
new_observation <- new_observation[, col_index]
73+
if (!inherits(new_observation, "data.frame")) {
74+
new_observation <- data.frame(new_observation)
75+
}
4476
}
4577

4678
if (ncol(explainer$data) != ncol(new_observation)) {
@@ -59,14 +91,25 @@ surv_shap <- function(explainer,
5991
}
6092
}
6193

94+
if (calculation_method == "treeshap") {
95+
if (!inherits(explainer$model, "ranger")) {
96+
stop("Calculation method `treeshap` is currently only implemented for `ranger` survival models.")
97+
}
98+
}
99+
62100
res <- list()
63101
res$eval_times <- explainer$times
64102
# to display final object correctly, when is.matrix(new_observation) == TRUE
65103
res$variable_values <- as.data.frame(new_observation)
66104
res$result <- switch(calculation_method,
67-
"exact_kernel" = use_exact_shap(explainer, new_observation, output_type, N, ...),
68-
"kernelshap" = use_kernelshap(explainer, new_observation, output_type, N, ...),
69-
stop("Only `exact_kernel` and `kernelshap` calculation methods are implemented")
105+
"exact_kernel" = use_exact_shap(explainer, new_observation, output_type, ...),
106+
"kernelshap" = use_kernelshap(explainer, new_observation, output_type, ...),
107+
"treeshap" = use_treeshap(explainer, new_observation, ...),
108+
stop("Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented"))
109+
# quality-check here
110+
stopifnot(
111+
"Number of rows of SurvSHAP table are not identical with length(eval_times)" =
112+
nrow(res$result) == length(res$eval_times)
70113
)
71114

72115
if (!is.null(y_true)) res$y_true <- c(y_true_time = y_true_time, y_true_ind = y_true_ind)
@@ -86,6 +129,7 @@ surv_shap <- function(explainer,
86129
return(res)
87130
}
88131

132+
89133
use_exact_shap <- function(explainer, new_observation, output_type, N, ...) {
90134
shap_values <- sapply(
91135
X = as.character(seq_len(nrow(new_observation))),
@@ -125,7 +169,6 @@ shap_kernel <- function(explainer, new_observation, output_type, N, ...) {
125169

126170
shap_values <- as.data.frame(shap_values, row.names = colnames(explainer$data))
127171
colnames(shap_values) <- paste("t=", timestamps, sep = "")
128-
129172
return(t(shap_values))
130173
}
131174

@@ -204,6 +247,18 @@ use_kernelshap <- function(explainer, new_observation, output_type, N, ...) {
204247
}
205248
}
206249

250+
stopifnot(
251+
"new_observation must be a data.frame" = inherits(
252+
new_observation, "data.frame")
253+
)
254+
255+
# get explainer data to be able to make class checks and transformations
256+
explainer_data <- explainer$data
257+
# ensure that classes of explainer$data and new_observation are equal
258+
if (!inherits(explainer_data, "data.frame")) {
259+
explainer_data <- data.frame(explainer_data)
260+
}
261+
207262
if (is.null(N)) N <- nrow(explainer$data)
208263
background_data <- explainer$data[sample(1:nrow(explainer$data), N),]
209264

@@ -212,11 +267,12 @@ use_kernelshap <- function(explainer, new_observation, output_type, N, ...) {
212267
FUN = function(i) {
213268
tmp_res <- kernelshap::kernelshap(
214269
object = explainer$model,
215-
X = new_observation[as.integer(i), ],
216-
bg_X = background_data,
270+
X = new_observation[as.integer(i), ], # data.frame
271+
bg_X = explainer_data, # data.frame
217272
pred_fun = predfun,
218273
verbose = FALSE
219274
)
275+
# kernelshap-test: is.matrix(X) == is.matrix(bg_X) should evaluate to `TRUE`
220276
tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind)))
221277
colnames(tmp_shap_values) <- colnames(tmp_res$X)
222278
rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "")
@@ -229,6 +285,69 @@ use_kernelshap <- function(explainer, new_observation, output_type, N, ...) {
229285
return(shap_values)
230286
}
231287

288+
use_treeshap <- function(explainer, new_observation, ...){
289+
290+
stopifnot(
291+
"new_observation must be a data.frame" = inherits(
292+
new_observation, "data.frame")
293+
)
294+
295+
# init unify_append_args
296+
unify_append_args <- list()
297+
298+
if (inherits(explainer$model, "ranger")) {
299+
# UNIFY_FUN to prepare code for easy Integration of other ml algorithms
300+
# that are supported by treeshap
301+
UNIFY_FUN <- treeshap::ranger_surv.unify
302+
unify_append_args <- list(type = "survival", times = explainer$times)
303+
} else {
304+
stop("Support for `treeshap` is currently only implemented for `ranger`.")
305+
}
306+
307+
unify_args <- list(
308+
rf_model = explainer$model,
309+
data = explainer$data
310+
)
311+
312+
if (length(unify_append_args) > 0) {
313+
unify_args <- c(unify_args, unify_append_args)
314+
}
315+
316+
tmp_unified <- do.call(UNIFY_FUN, unify_args)
317+
318+
shap_values <- sapply(
319+
X = as.character(seq_len(nrow(new_observation))),
320+
FUN = function(i) {
321+
tmp_res <- do.call(
322+
rbind,
323+
lapply(
324+
tmp_unified,
325+
function(m) {
326+
new_obs_mat <- new_observation[as.integer(i), ]
327+
# ensure that matrix has expected dimensions; as.integer is
328+
# necessary for valid comparison with "identical"
329+
stopifnot(identical(dim(new_obs_mat), as.integer(c(1L, ncol(new_observation)))))
330+
treeshap::treeshap(
331+
unified_model = m,
332+
x = new_obs_mat
333+
)$shaps
334+
}
335+
)
336+
)
337+
338+
tmp_shap_values <- data.frame(tmp_res)
339+
colnames(tmp_shap_values) <- colnames(tmp_res)
340+
rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "")
341+
tmp_shap_values
342+
},
343+
USE.NAMES = TRUE,
344+
simplify = FALSE
345+
)
346+
347+
return(shap_values)
348+
349+
}
350+
232351
#' @keywords internal
233352
aggregate_shap_multiple_observations <- function(shap_res_list, feature_names, aggregation_function) {
234353
if (length(shap_res_list) > 1) {

man/model_survshap.surv_explainer.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/surv_shap.Rd

Lines changed: 4 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-model_survshap.R

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
# create objects here so that they do not have to be created redundantly
23
veteran <- survival::veteran
34
rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
45
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE)
@@ -68,3 +69,25 @@ test_that("global survshap explanations with kernelshap work for coxph, using ex
6869
expect_equal(length(cph_global_survshap$eval_times), length(cph_exp$times))
6970
expect_true(all(names(cph_global_survshap$variable_values) == colnames(cph_exp$data)))
7071
})
72+
73+
# testing if matrix works as input
74+
rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
75+
rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)
76+
77+
test_that("global survshap explanations with treeshap work for ranger", {
78+
79+
new_obs <- model.matrix(~ -1 + ., veteran[1:40, setdiff(colnames(veteran), c("time", "status"))])
80+
ranger_global_survshap_tree <- model_survshap(
81+
rsf_ranger_exp_matrix,
82+
new_observation = new_obs,
83+
y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]),
84+
aggregation_method = "mean_absolute",
85+
calculation_method = "treeshap"
86+
)
87+
plot(ranger_global_survshap_tree)
88+
89+
expect_s3_class(ranger_global_survshap_tree, c("aggregated_surv_shap", "surv_shap"))
90+
expect_equal(length(ranger_global_survshap_tree$eval_times), length(rsf_ranger_exp_matrix$times))
91+
expect_true(all(names(ranger_global_survshap_tree$variable_values) == colnames(rsf_ranger_exp_matrix$data)))
92+
93+
})

tests/testthat/test-predict_parts.R

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ test_that("survshap explanations work", {
66
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
77

88
cph_exp <- explain(cph, verbose = FALSE)
9-
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE)
9+
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)
1010
rsf_src_exp <- explain(rsf_src, verbose = FALSE)
1111

1212
parts_cph <- predict_parts(cph_exp, veteran[1, !colnames(veteran) %in% c("time", "status")], y_true = matrix(c(100, 1), ncol = 2), aggregation_method = "sum_of_squares")
@@ -19,6 +19,19 @@ test_that("survshap explanations work", {
1919
parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute")
2020
plot(parts_ranger)
2121

22+
# test ranger with kernelshap when using a matrix as input for data and new observation
23+
rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
24+
rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)
25+
new_obs <- model.matrix(~ -1 + ., veteran[2, !colnames(veteran) %in% c("time", "status")])
26+
parts_ranger_kernelshap <- predict_parts(
27+
rsf_ranger_exp_matrix,
28+
new_observation = new_obs,
29+
y_true = c(100, 1),
30+
aggregation_method = "mean_absolute",
31+
calculation_method = "kernelshap"
32+
)
33+
plot(parts_ranger_kernelshap)
34+
2235
parts_src <- predict_parts(rsf_src_exp, veteran[3, !colnames(veteran) %in% c("time", "status")])
2336
plot(parts_src)
2437

@@ -46,6 +59,29 @@ test_that("survshap explanations work", {
4659
expect_error(predict_parts(cph_exp, veteran[1, ], calculation_method = "nonexistent"))
4760
expect_error(predict_parts(cph_exp, veteran[1, c(1, 1, 1, 1, 1)], calculation_method = "nonexistent"))
4861

62+
})
63+
64+
test_that("local survshap explanations with treeshap work for ranger", {
65+
66+
veteran <- survival::veteran
67+
68+
rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
69+
rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)
70+
71+
72+
new_obs <- data.frame(model.matrix(~ -1 + ., veteran[2, setdiff(colnames(veteran), c("time", "status"))]))
73+
parts_ranger <- model_survshap(
74+
rsf_ranger_exp_matrix,
75+
new_obs,
76+
y_true = c(veteran$time[2], veteran$status[2]),
77+
aggregation_method = "mean_absolute",
78+
calculation_method = "treeshap"
79+
)
80+
plot(parts_ranger)
81+
82+
expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap"))
83+
expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp_matrix$times))
84+
expect_true(all(colnames(parts_ranger$result) == colnames(rsf_ranger_exp_matrix$data)))
4985

5086
})
5187

@@ -67,6 +103,10 @@ test_that("survshap explanations with output_type = 'chf' work", {
67103
plot(parts_cph, rug = "censors")
68104
plot(parts_cph, rug = "none")
69105

106+
# test global exact
107+
parts_cph_glob <- predict_parts(cph_exp, veteran[1:3, !colnames(veteran) %in% c("time", "status")], y_true = as.matrix(veteran[1:3, c("time", "status")]), calculation_method = "exact_kernel", aggregation_method = "max_absolute", output_type = "chf")
108+
plot(parts_cph_glob)
109+
70110
parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute", output_type = "chf")
71111
plot(parts_ranger)
72112

0 commit comments

Comments
 (0)