Skip to content

Commit 7752820

Browse files
authored
Merge branch 'feat-treeshap' into main
2 parents 5481675 + 160053d commit 7752820

File tree

6 files changed

+189
-34
lines changed

6 files changed

+189
-34
lines changed

DESCRIPTION

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
Package: survex
22
Title: Explainable Machine Learning in Survival Analysis
3-
Version: 1.1.3
3+
Version: 1.1.3.9001
44
Authors@R:
55
c(
66
person("Mikołaj", "Spytek", email = "[email protected]", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")),
77
person("Mateusz", "Krzyziński", role = c("aut"), comment = c(ORCID = "0000-0001-6143-488X")),
88
person("Sophie", "Langbein", role = c("aut")),
99
person("Hubert", "Baniecki", role = c("aut"), comment = c(ORCID = "0000-0001-6661-5364")),
10-
person("Przemyslaw", "Biecek", role = c("aut"), comment = c(ORCID = "0000-0001-8423-1823"))
10+
person("Przemyslaw", "Biecek", role = c("aut"), comment = c(ORCID = "0000-0001-8423-1823")),
11+
person("Lorenz A.", "Kapsner", role = c("ctb"), comment = c(ORCID = "0000-0003-1866-860X"))
1112
)
1213
Description: Survival analysis models are commonly used in medicine and other areas. Many of them
1314
are too complex to be interpreted by human. Exploration and explanation is needed, but
@@ -50,7 +51,7 @@ Suggests:
5051
withr,
5152
xgboost
5253
Remotes:
53-
github::kapsner/treeshap
54+
github::ModelOriented/treeshap
5455
Config/testthat/edition: 3
5556
VignetteBuilder: knitr
5657
URL: https://modeloriented.github.io/survex/

R/surv_shap.R

Lines changed: 124 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
#' @param output_type a character, either `"survival"` or `"chf"`. Determines which type of prediction should be used for explanations.
66
#' @param ... additional parameters, passed to internal functions
77
#' @param y_true a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting
8-
#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements) or `"exact_kernel"` for exact Kernel SHAP estimation
8+
#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements), `"exact_kernel"` for exact Kernel SHAP estimation,
9+
#' or `"treeshap"` for use of `treeshap` library (efficient implementation to compute SHAP values for tree-based models).
910
#' @param aggregation_method a character, either `"integral"`, `"integral_absolute"`, `"mean_absolute"`, `"max_absolute"`, or `"sum_of_squares"`
1011
#'
1112
#' @return A list, containing the calculated SurvSHAP(t) results in the `result` field
@@ -19,10 +20,15 @@ surv_shap <- function(explainer,
1920
output_type,
2021
...,
2122
y_true = NULL,
22-
calculation_method = "kernelshap",
23-
aggregation_method = "integral") {
23+
calculation_method = c("kernelshap", "exact_kernel", "treeshap"),
24+
aggregation_method = c("integral", "mean_absolute", "max_absolute", "sum_of_squares")
25+
) {
26+
calculation_method <- match.arg(calculation_method)
27+
aggregation_method <- match.arg(aggregation_method)
28+
29+
# make this code work for multiple observations
2430
stopifnot(
25-
"`y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse(
31+
"`y_true` must be either a matrix with one row per observation in `new_observation` or a vector of length == 2" = ifelse(
2632
!is.null(y_true),
2733
ifelse(
2834
is.matrix(y_true),
@@ -33,14 +39,40 @@ surv_shap <- function(explainer,
3339
)
3440
)
3541

42+
if (calculation_method == "kernelshap") {
43+
if (!requireNamespace("kernelshap", quietly = TRUE)) {
44+
stop(
45+
paste0(
46+
"Package \"kernelshap\" must be installed to use ",
47+
"'calculation_method = \"kernelshap\"'."
48+
),
49+
call. = FALSE
50+
)
51+
}
52+
}
53+
if (calculation_method == "treeshap") {
54+
if (!requireNamespace("treeshap", quietly = TRUE)) {
55+
stop(
56+
paste0(
57+
"Package \"treeshap\" must be installed to use ",
58+
"'calculation_method = \"treeshap\"'."
59+
),
60+
call. = FALSE
61+
)
62+
}
63+
}
64+
3665
test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE)
3766

3867
# make this code also work for 1-row matrix
3968
col_index <- which(colnames(new_observation) %in% colnames(explainer$data))
4069
if (is.matrix(new_observation) && nrow(new_observation) == 1) {
41-
new_observation <- as.matrix(t(new_observation[, col_index]))
70+
new_observation <- data.frame(as.matrix(t(new_observation[, col_index])))
4271
} else {
4372
new_observation <- new_observation[, col_index]
73+
if (!inherits(new_observation, "data.frame")) {
74+
new_observation <- data.frame(new_observation)
75+
}
4476
}
4577

4678
if (ncol(explainer$data) != ncol(new_observation)) {
@@ -59,14 +91,9 @@ surv_shap <- function(explainer,
5991
}
6092
}
6193

62-
# hack to use rf-model death times as explainer death times, as
63-
# treeshap::ranger_surv_fun.unify extracts survival times directly
64-
# from the ranger object for calculating the predictions
6594
if (calculation_method == "treeshap") {
66-
if (inherits(explainer$model, "ranger")) {
67-
explainer$times <- explainer$model$unique.death.times
68-
} else {
69-
stop("Calculation method `treeshap` is currently only implemented for `ranger`.")
95+
if (!inherits(explainer$model, "ranger")) {
96+
stop("Calculation method `treeshap` is currently only implemented for `ranger` survival models.")
7097
}
7198
}
7299

@@ -75,9 +102,14 @@ surv_shap <- function(explainer,
75102
# to display final object correctly, when is.matrix(new_observation) == TRUE
76103
res$variable_values <- as.data.frame(new_observation)
77104
res$result <- switch(calculation_method,
78-
"exact_kernel" = use_exact_shap(explainer, new_observation, output_type, ...),
79-
"kernelshap" = use_kernelshap(explainer, new_observation, output_type, ...),
80-
stop("Only `exact_kernel` and `kernelshap` calculation methods are implemented")
105+
"exact_kernel" = use_exact_shap(explainer, new_observation, output_type, ...),
106+
"kernelshap" = use_kernelshap(explainer, new_observation, output_type, ...),
107+
"treeshap" = use_treeshap(explainer, new_observation, ...),
108+
stop("Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented"))
109+
# quality-check here
110+
stopifnot(
111+
"Number of rows of SurvSHAP table are not identical with length(eval_times)" =
112+
nrow(res$result) == length(res$eval_times)
81113
)
82114

83115
if (!is.null(y_true)) res$y_true <- c(y_true_time = y_true_time, y_true_ind = y_true_ind)
@@ -97,7 +129,7 @@ surv_shap <- function(explainer,
97129
return(res)
98130
}
99131

100-
use_exact_shap <- function(explainer, new_observation, output_type, observation_aggregation_method, ...) {
132+
use_exact_shap <- function(explainer, new_observation, output_type, ...) {
101133
shap_values <- sapply(
102134
X = as.character(seq_len(nrow(new_observation))),
103135
FUN = function(i) {
@@ -134,11 +166,8 @@ shap_kernel <- function(explainer, new_observation, output_type, ...) {
134166
timestamps
135167
)
136168

137-
138-
139169
shap_values <- as.data.frame(shap_values, row.names = colnames(explainer$data))
140170
colnames(shap_values) <- paste("t=", timestamps, sep = "")
141-
142171
return(t(shap_values))
143172
}
144173

@@ -215,19 +244,31 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_
215244
times = explainer$times
216245
)
217246
}
247+
}
248+
249+
stopifnot(
250+
"new_observation must be a data.frame" = inherits(
251+
new_observation, "data.frame")
252+
)
218253

254+
# get explainer data to be able to make class checks and transformations
255+
explainer_data <- explainer$data
256+
# ensure that classes of explainer$data and new_observation are equal
257+
if (!inherits(explainer_data, "data.frame")) {
258+
explainer_data <- data.frame(explainer_data)
219259
}
220260

221261
shap_values <- sapply(
222262
X = as.character(seq_len(nrow(new_observation))),
223263
FUN = function(i) {
224264
tmp_res <- kernelshap::kernelshap(
225265
object = explainer$model,
226-
X = new_observation[as.integer(i), ],
227-
bg_X = explainer$data,
266+
X = new_observation[as.integer(i), ], # data.frame
267+
bg_X = explainer_data, # data.frame
228268
pred_fun = predfun,
229269
verbose = FALSE
230270
)
271+
# kernelshap-test: is.matrix(X) == is.matrix(bg_X) should evaluate to `TRUE`
231272
tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind)))
232273
colnames(tmp_shap_values) <- colnames(tmp_res$X)
233274
rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "")
@@ -240,6 +281,68 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_
240281
return(shap_values)
241282
}
242283

284+
use_treeshap <- function(explainer, new_observation, ...){
285+
286+
stopifnot(
287+
"new_observation must be a data.frame" = inherits(
288+
new_observation, "data.frame")
289+
)
290+
291+
# init unify_append_args
292+
unify_append_args <- list()
293+
294+
if (inherits(explainer$model, "ranger")) {
295+
# UNIFY_FUN to prepare code for easy Integration of other ml algorithms
296+
# that are supported by treeshap
297+
UNIFY_FUN <- treeshap::ranger_surv.unify
298+
unify_append_args <- list(type = "survival", times = explainer$times)
299+
} else {
300+
stop("Support for `treeshap` is currently only implemented for `ranger`.")
301+
}
302+
303+
unify_args <- list(
304+
rf_model = explainer$model,
305+
data = explainer$data
306+
)
307+
308+
if (length(unify_append_args) > 0) {
309+
unify_args <- c(unify_args, unify_append_args)
310+
}
311+
312+
tmp_unified <- do.call(UNIFY_FUN, unify_args)
313+
314+
shap_values <- sapply(
315+
X = as.character(seq_len(nrow(new_observation))),
316+
FUN = function(i) {
317+
tmp_res <- do.call(
318+
rbind,
319+
lapply(
320+
tmp_unified,
321+
function(m) {
322+
new_obs_mat <- as.matrix(new_observation[as.integer(i), ])
323+
# ensure that matrix has expected dimensions; as.integer is
324+
# necessary for valid comparison with "identical"
325+
stopifnot(identical(dim(new_obs_mat), as.integer(c(1L, ncol(new_observation)))))
326+
treeshap::treeshap(
327+
unified_model = m,
328+
x = new_obs_mat
329+
)$shaps
330+
}
331+
)
332+
)
333+
334+
tmp_shap_values <- data.frame(tmp_res)
335+
colnames(tmp_shap_values) <- colnames(tmp_res)
336+
rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "")
337+
tmp_shap_values
338+
},
339+
USE.NAMES = TRUE,
340+
simplify = FALSE
341+
)
342+
343+
return(shap_values)
344+
345+
}
243346

244347
#' @keywords internal
245348
aggregate_shap_multiple_observations <- function(shap_res_list, feature_names, aggregation_function) {

man/model_survshap.surv_explainer.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/surv_shap.Rd

Lines changed: 4 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-model_survshap.R

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
# create objects here so that they do not have to be created redundantly
23
veteran <- survival::veteran
34
rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
45
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE)
@@ -68,3 +69,25 @@ test_that("global survshap explanations with kernelshap work for coxph, using ex
6869
expect_equal(length(cph_global_survshap$eval_times), length(cph_exp$times))
6970
expect_true(all(names(cph_global_survshap$variable_values) == colnames(cph_exp$data)))
7071
})
72+
73+
# testing if matrix works as input
74+
rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
75+
rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)
76+
77+
test_that("global survshap explanations with treeshap work for ranger", {
78+
79+
new_obs <- model.matrix(~ -1 + ., veteran[1:40, setdiff(colnames(veteran), c("time", "status"))])
80+
ranger_global_survshap_tree <- model_survshap(
81+
rsf_ranger_exp_matrix,
82+
new_observation = new_obs,
83+
y_true = survival::Surv(veteran$time[1:40], veteran$status[1:40]),
84+
aggregation_method = "mean_absolute",
85+
calculation_method = "treeshap"
86+
)
87+
plot(ranger_global_survshap_tree)
88+
89+
expect_s3_class(ranger_global_survshap_tree, c("aggregated_surv_shap", "surv_shap"))
90+
expect_equal(length(ranger_global_survshap_tree$eval_times), length(rsf_ranger_exp_matrix$times))
91+
expect_true(all(names(ranger_global_survshap_tree$variable_values) == colnames(rsf_ranger_exp_matrix$data)))
92+
93+
})

tests/testthat/test-predict_parts.R

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ test_that("survshap explanations work", {
66
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
77

88
cph_exp <- explain(cph, verbose = FALSE)
9-
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE)
9+
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)
1010
rsf_src_exp <- explain(rsf_src, verbose = FALSE)
1111

1212
parts_cph <- predict_parts(cph_exp, veteran[1, !colnames(veteran) %in% c("time", "status")], y_true = matrix(c(100, 1), ncol = 2), aggregation_method = "sum_of_squares")
@@ -19,19 +19,18 @@ test_that("survshap explanations work", {
1919
parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute")
2020
plot(parts_ranger)
2121

22-
# test ranger with treeshap (we need the data as matrix)
22+
# test ranger with kernelshap when using a matrix as input for data and new observation
2323
rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
24-
rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = Surv(veteran$time, veteran$status), verbose = FALSE)
24+
rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)
2525
new_obs <- model.matrix(~ -1 + ., veteran[2, !colnames(veteran) %in% c("time", "status")])
26-
parts_ranger_treeshap <- predict_parts(
26+
parts_ranger_kernelshap <- predict_parts(
2727
rsf_ranger_exp_matrix,
2828
new_observation = new_obs,
2929
y_true = c(100, 1),
3030
aggregation_method = "mean_absolute",
3131
calculation_method = "kernelshap"
3232
)
33-
plot(parts_ranger_treeshap)
34-
33+
plot(parts_ranger_kernelshap)
3534

3635
parts_src <- predict_parts(rsf_src_exp, veteran[3, !colnames(veteran) %in% c("time", "status")])
3736
plot(parts_src)
@@ -60,6 +59,29 @@ test_that("survshap explanations work", {
6059
expect_error(predict_parts(cph_exp, veteran[1, ], calculation_method = "nonexistent"))
6160
expect_error(predict_parts(cph_exp, veteran[1, c(1, 1, 1, 1, 1)], calculation_method = "nonexistent"))
6261

62+
})
63+
64+
test_that("local survshap explanations with treeshap work for ranger", {
65+
66+
veteran <- survival::veteran
67+
68+
rsf_ranger_matrix <- ranger::ranger(survival::Surv(time, status) ~ ., data = model.matrix(~ -1 + ., veteran), respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
69+
rsf_ranger_exp_matrix <- explain(rsf_ranger_matrix, data = model.matrix(~ -1 + ., veteran[, -c(3, 4)]), y = survival::Surv(veteran$time, veteran$status), verbose = FALSE)
70+
71+
72+
new_obs <- model.matrix(~ -1 + ., veteran[2, setdiff(colnames(veteran), c("time", "status"))])
73+
parts_ranger <- model_survshap(
74+
rsf_ranger_exp_matrix,
75+
new_obs,
76+
y_true = c(veteran$time[2], veteran$status[2]),
77+
aggregation_method = "mean_absolute",
78+
calculation_method = "treeshap"
79+
)
80+
plot(parts_ranger)
81+
82+
expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap"))
83+
expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp_matrix$times))
84+
expect_true(all(colnames(parts_ranger$result) == colnames(rsf_ranger_exp_matrix$data)))
6385

6486
})
6587

@@ -81,6 +103,10 @@ test_that("survshap explanations with output_type = 'chf' work", {
81103
plot(parts_cph, rug = "censors")
82104
plot(parts_cph, rug = "none")
83105

106+
# test global exact
107+
parts_cph_glob <- predict_parts(cph_exp, veteran[1:3, !colnames(veteran) %in% c("time", "status")], y_true = as.matrix(veteran[1:3, c("time", "status")]), calculation_method = "exact_kernel", aggregation_method = "max_absolute", output_type = "chf")
108+
plot(parts_cph_glob)
109+
84110
parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute", output_type = "chf")
85111
plot(parts_ranger)
86112

0 commit comments

Comments
 (0)