Skip to content

Commit 0b2f4f5

Browse files
authored
Merge pull request #85 from kapsner/feat-treeshap
Feature: support SurvSHAP computation with {treeshap}
2 parents 3864b87 + 1d275eb commit 0b2f4f5

File tree

6 files changed

+200
-19
lines changed

6 files changed

+200
-19
lines changed

DESCRIPTION

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Authors@R:
77
person("Mateusz", "Krzyziński", role = c("aut"), comment = c(ORCID = "0000-0001-6143-488X")),
88
person("Sophie", "Langbein", role = c("aut")),
99
person("Hubert", "Baniecki", role = c("aut"), comment = c(ORCID = "0000-0001-6661-5364")),
10+
person("Lorenz A.", "Kapsner", role = c("ctb"), comment = c(ORCID = "0000-0003-1866-860X")),
1011
person("Przemyslaw", "Biecek", role = c("aut"), comment = c(ORCID = "0000-0001-8423-1823"))
1112
)
1213
Description: Survival analysis models are commonly used in medicine and other areas. Many of them
@@ -25,6 +26,7 @@ Imports:
2526
DALEX (>= 2.2.1),
2627
ggplot2 (>= 3.4.0),
2728
kernelshap,
29+
treeshap,
2830
pec,
2931
survival,
3032
patchwork

R/surv_shap.R

Lines changed: 128 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#' @param output_type a character, either `"survival"` or `"chf"`. Determines which type of prediction should be used for explanations.
66
#' @param ... additional parameters, passed to internal functions
77
#' @param y_true a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting
8-
#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements) or `"exact_kernel"` for exact Kernel SHAP estimation
8+
#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements), `"exact_kernel"` for exact Kernel SHAP estimation, or `"treeshap"` for use of `treeshap` library (efficient implementation to compute SHAP values for tree-based models).
99
#' @param aggregation_method a character, either `"integral"`, `"integral_absolute"`, `"mean_absolute"`, `"max_absolute"`, or `"sum_of_squares"`
1010
#'
1111
#' @return A list, containing the calculated SurvSHAP(t) results in the `result` field
@@ -19,10 +19,15 @@ surv_shap <- function(explainer,
1919
output_type,
2020
...,
2121
y_true = NULL,
22-
calculation_method = "kernelshap",
23-
aggregation_method = "integral") {
22+
calculation_method = c("kernelshap", "exact_kernel", "treeshap"),
23+
aggregation_method = c("integral", "mean_absolute", "max_absolute", "sum_of_squares")
24+
) {
25+
calculation_method <- match.arg(calculation_method)
26+
aggregation_method <- match.arg(aggregation_method)
27+
28+
# make this code work for multiple observations
2429
stopifnot(
25-
"`y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse(
30+
"`y_true` must be either a matrix with one row per observation in `new_observation` or a vector of length == 2" = ifelse(
2631
!is.null(y_true),
2732
ifelse(
2833
is.matrix(y_true),
@@ -33,14 +38,40 @@ surv_shap <- function(explainer,
3338
)
3439
)
3540

41+
if (calculation_method == "kernelshap") {
42+
if (!requireNamespace("kernelshap", quietly = TRUE)) {
43+
stop(
44+
paste0(
45+
"Package \"kernelshap\" must be installed to use ",
46+
"'calculation_method = \"kernelshap\"'."
47+
),
48+
call. = FALSE
49+
)
50+
}
51+
}
52+
if (calculation_method == "treeshap") {
53+
if (!requireNamespace("treeshap", quietly = TRUE)) {
54+
stop(
55+
paste0(
56+
"Package \"treeshap\" must be installed to use ",
57+
"'calculation_method = \"treeshap\"'."
58+
),
59+
call. = FALSE
60+
)
61+
}
62+
}
63+
3664
test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE)
3765

3866
# make this code also work for 1-row matrix
3967
col_index <- which(colnames(new_observation) %in% colnames(explainer$data))
4068
if (is.matrix(new_observation) && nrow(new_observation) == 1) {
41-
new_observation <- as.matrix(t(new_observation[, col_index]))
69+
new_observation <- data.frame(as.matrix(t(new_observation[, col_index])))
4270
} else {
4371
new_observation <- new_observation[, col_index]
72+
if (!inherits(new_observation, "data.frame")) {
73+
new_observation <- data.frame(new_observation)
74+
}
4475
}
4576

4677
if (ncol(explainer$data) != ncol(new_observation)) {
@@ -59,14 +90,25 @@ surv_shap <- function(explainer,
5990
}
6091
}
6192

93+
if (calculation_method == "treeshap") {
94+
if (!inherits(explainer$model, "ranger")) {
95+
stop("Calculation method `treeshap` is currently only implemented for `ranger` survival models.")
96+
}
97+
}
98+
6299
res <- list()
63100
res$eval_times <- explainer$times
64101
# to display final object correctly, when is.matrix(new_observation) == TRUE
65102
res$variable_values <- as.data.frame(new_observation)
66103
res$result <- switch(calculation_method,
67-
"exact_kernel" = use_exact_shap(explainer, new_observation, output_type, ...),
68-
"kernelshap" = use_kernelshap(explainer, new_observation, output_type, ...),
69-
stop("Only `exact_kernel` and `kernelshap` calculation methods are implemented")
104+
"exact_kernel" = use_exact_shap(explainer, new_observation, output_type, ...),
105+
"kernelshap" = use_kernelshap(explainer, new_observation, output_type, ...),
106+
"treeshap" = use_treeshap(explainer, new_observation, ...),
107+
stop("Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented"))
108+
# quality-check here
109+
stopifnot(
110+
"Number of rows of SurvSHAP table are not identical with length(eval_times)" =
111+
nrow(res$result) == length(res$eval_times)
70112
)
71113

72114
if (!is.null(y_true)) res$y_true <- c(y_true_time = y_true_time, y_true_ind = y_true_ind)
@@ -86,7 +128,7 @@ surv_shap <- function(explainer,
86128
return(res)
87129
}
88130

89-
use_exact_shap <- function(explainer, new_observation, output_type, observation_aggregation_method, ...) {
131+
use_exact_shap <- function(explainer, new_observation, output_type, ...) {
90132
shap_values <- sapply(
91133
X = as.character(seq_len(nrow(new_observation))),
92134
FUN = function(i) {
@@ -123,11 +165,8 @@ shap_kernel <- function(explainer, new_observation, output_type, ...) {
123165
timestamps
124166
)
125167

126-
127-
128168
shap_values <- as.data.frame(shap_values, row.names = colnames(explainer$data))
129169
colnames(shap_values) <- paste("t=", timestamps, sep = "")
130-
131170
return(t(shap_values))
132171
}
133172

@@ -204,19 +243,31 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_
204243
times = explainer$times
205244
)
206245
}
246+
}
247+
248+
stopifnot(
249+
"new_observation must be a data.frame" = inherits(
250+
new_observation, "data.frame")
251+
)
207252

253+
# get explainer data to be able to make class checks and transformations
254+
explainer_data <- explainer$data
255+
# ensure that classes of explainer$data and new_observation are equal
256+
if (!inherits(explainer_data, "data.frame")) {
257+
explainer_data <- data.frame(explainer_data)
208258
}
209259

210260
shap_values <- sapply(
211261
X = as.character(seq_len(nrow(new_observation))),
212262
FUN = function(i) {
213263
tmp_res <- kernelshap::kernelshap(
214264
object = explainer$model,
215-
X = new_observation[as.integer(i), ],
216-
bg_X = explainer$data,
265+
X = new_observation[as.integer(i), ], # data.frame
266+
bg_X = explainer_data, # data.frame
217267
pred_fun = predfun,
218268
verbose = FALSE
219269
)
270+
# kernelshap-test: is.matrix(X) == is.matrix(bg_X) should evaluate to `TRUE`
220271
tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind)))
221272
colnames(tmp_shap_values) <- colnames(tmp_res$X)
222273
rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "")
@@ -229,6 +280,69 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_
229280
return(shap_values)
230281
}
231282

283+
use_treeshap <- function(explainer, new_observation, ...){
284+
285+
stopifnot(
286+
"new_observation must be a data.frame" = inherits(
287+
new_observation, "data.frame")
288+
)
289+
290+
# init unify_append_args
291+
unify_append_args <- list()
292+
293+
if (inherits(explainer$model, "ranger")) {
294+
# UNIFY_FUN to prepare code for easy Integration of other ml algorithms
295+
# that are supported by treeshap
296+
UNIFY_FUN <- treeshap::ranger_surv.unify
297+
unify_append_args <- list(type = "survival", times = explainer$times)
298+
} else {
299+
stop("Support for `treeshap` is currently only implemented for `ranger`.")
300+
}
301+
302+
unify_args <- list(
303+
rf_model = explainer$model,
304+
data = explainer$data
305+
)
306+
307+
if (length(unify_append_args) > 0) {
308+
unify_args <- c(unify_args, unify_append_args)
309+
}
310+
311+
tmp_unified <- do.call(UNIFY_FUN, unify_args)
312+
313+
shap_values <- sapply(
314+
X = as.character(seq_len(nrow(new_observation))),
315+
FUN = function(i) {
316+
tmp_res <- do.call(
317+
rbind,
318+
lapply(
319+
tmp_unified,
320+
function(m) {
321+
new_obs_mat <- new_observation[as.integer(i), ]
322+
# ensure that matrix has expected dimensions; as.integer is
323+
# necessary for valid comparison with "identical"
324+
stopifnot(identical(dim(new_obs_mat), as.integer(c(1L, ncol(new_observation)))))
325+
treeshap::treeshap(
326+
unified_model = m,
327+
x = new_obs_mat
328+
)$shaps
329+
}
330+
)
331+
)
332+
333+
tmp_shap_values <- data.frame(tmp_res)
334+
colnames(tmp_shap_values) <- colnames(tmp_res)
335+
rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "")
336+
tmp_shap_values
337+
},
338+
USE.NAMES = TRUE,
339+
simplify = FALSE
340+
)
341+
342+
return(shap_values)
343+
344+
}
345+
232346
#' @keywords internal
233347
aggregate_shap_multiple_observations <- function(shap_res_list, feature_names, aggregation_function) {
234348
if (length(shap_res_list) > 1) {

man/model_survshap.surv_explainer.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/surv_shap.Rd

Lines changed: 4 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-model_survshap.R

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
# create objects here so that they do not have to be created redundantly
23
veteran <- survival::veteran
34
rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
45
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE)
@@ -68,3 +69,25 @@ test_that("global survshap explanations with kernelshap work for coxph, using ex
6869
expect_equal(length(cph_global_survshap$eval_times), length(cph_exp$times))
6970
expect_true(all(names(cph_global_survshap$variable_values) == colnames(cph_exp$data)))
7071
})
72+
73+
# testing if matrix works as input
74+
rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
75+
rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)
76+
77+
test_that("global survshap explanations with treeshap work for ranger", {
78+
79+
new_obs <- model.matrix(~ -1 + ., veteran[1:40, setdiff(colnames(veteran), c("time", "status"))])
80+
ranger_global_survshap_tree <- model_survshap(
81+
rsf_ranger_exp_matrix,
82+
new_observation = new_obs,
83+
y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]),
84+
aggregation_method = "mean_absolute",
85+
calculation_method = "treeshap"
86+
)
87+
plot(ranger_global_survshap_tree)
88+
89+
expect_s3_class(ranger_global_survshap_tree, c("aggregated_surv_shap", "surv_shap"))
90+
expect_equal(length(ranger_global_survshap_tree$eval_times), length(rsf_ranger_exp_matrix$times))
91+
expect_true(all(names(ranger_global_survshap_tree$variable_values) == colnames(rsf_ranger_exp_matrix$data)))
92+
93+
})

tests/testthat/test-predict_parts.R

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ test_that("survshap explanations work", {
66
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
77

88
cph_exp <- explain(cph, verbose = FALSE)
9-
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE)
9+
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)
1010
rsf_src_exp <- explain(rsf_src, verbose = FALSE)
1111

1212
parts_cph <- predict_parts(cph_exp, veteran[1, !colnames(veteran) %in% c("time", "status")], y_true = matrix(c(100, 1), ncol = 2), aggregation_method = "sum_of_squares")
@@ -19,6 +19,19 @@ test_that("survshap explanations work", {
1919
parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute")
2020
plot(parts_ranger)
2121

22+
# test ranger with kernelshap when using a matrix as input for data and new observation
23+
rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
24+
rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)
25+
new_obs <- model.matrix(~ -1 + ., veteran[2, !colnames(veteran) %in% c("time", "status")])
26+
parts_ranger_kernelshap <- predict_parts(
27+
rsf_ranger_exp_matrix,
28+
new_observation = new_obs,
29+
y_true = c(100, 1),
30+
aggregation_method = "mean_absolute",
31+
calculation_method = "kernelshap"
32+
)
33+
plot(parts_ranger_kernelshap)
34+
2235
parts_src <- predict_parts(rsf_src_exp, veteran[3, !colnames(veteran) %in% c("time", "status")])
2336
plot(parts_src)
2437

@@ -46,6 +59,29 @@ test_that("survshap explanations work", {
4659
expect_error(predict_parts(cph_exp, veteran[1, ], calculation_method = "nonexistent"))
4760
expect_error(predict_parts(cph_exp, veteran[1, c(1, 1, 1, 1, 1)], calculation_method = "nonexistent"))
4861

62+
})
63+
64+
test_that("local survshap explanations with treeshap work for ranger", {
65+
66+
veteran <- survival::veteran
67+
68+
rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
69+
rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)
70+
71+
72+
new_obs <- data.frame(model.matrix(~ -1 + ., veteran[2, setdiff(colnames(veteran), c("time", "status"))]))
73+
parts_ranger <- model_survshap(
74+
rsf_ranger_exp_matrix,
75+
new_obs,
76+
y_true = c(veteran$time[2], veteran$status[2]),
77+
aggregation_method = "mean_absolute",
78+
calculation_method = "treeshap"
79+
)
80+
plot(parts_ranger)
81+
82+
expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap"))
83+
expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp_matrix$times))
84+
expect_true(all(colnames(parts_ranger$result) == colnames(rsf_ranger_exp_matrix$data)))
4985

5086
})
5187

@@ -67,6 +103,10 @@ test_that("survshap explanations with output_type = 'chf' work", {
67103
plot(parts_cph, rug = "censors")
68104
plot(parts_cph, rug = "none")
69105

106+
# test global exact
107+
parts_cph_glob <- predict_parts(cph_exp, veteran[1:3, !colnames(veteran) %in% c("time", "status")], y_true = as.matrix(veteran[1:3, c("time", "status")]), calculation_method = "exact_kernel", aggregation_method = "max_absolute", output_type = "chf")
108+
plot(parts_cph_glob)
109+
70110
parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute", output_type = "chf")
71111
plot(parts_ranger)
72112

0 commit comments

Comments
 (0)