Skip to content

Commit 00d6a65

Browse files
committed
feat: aggregate survshap across multiple observations
1 parent f17674a commit 00d6a65

File tree

5 files changed

+106
-13
lines changed

5 files changed

+106
-13
lines changed

DESCRIPTION

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: survex
22
Title: Explainable Machine Learning in Survival Analysis
3-
Version: 1.0.0.9000
3+
Version: 1.0.0.9001
44
Authors@R:
55
c(
66
person("Mikołaj", "Spytek", email = "[email protected]", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")),
@@ -18,15 +18,16 @@ Description: Survival analysis models are commonly used in medicine and other ar
1818
License: GPL (>= 3)
1919
Encoding: UTF-8
2020
Roxygen: list(markdown = TRUE)
21-
RoxygenNote: 7.2.1
21+
RoxygenNote: 7.2.3
2222
Depends: R (>= 3.5.0)
2323
Imports:
2424
DALEX (>= 2.2.1),
2525
ggplot2,
2626
kernelshap,
2727
pec,
2828
survival,
29-
patchwork
29+
patchwork,
30+
data.table
3031
Suggests:
3132
censored,
3233
covr,

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ export(survival_to_cumulative_hazard)
6868
export(theme_default_survex)
6969
export(theme_vertical_default_survex)
7070
export(transform_to_stepfunction)
71+
import(data.table)
7172
import(ggplot2)
7273
import(patchwork)
7374
import(survival)

R/surv_shap.R

Lines changed: 79 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,32 @@ surv_shap <- function(explainer,
2424
B = 25,
2525
exact = FALSE
2626
) {
27+
# make this code work for multiple observations
28+
stopifnot(ifelse(!is.null(y_true),
29+
ifelse(is.matrix(y_true),
30+
nrow(new_observation) == nrow(y_true),
31+
is.null(dim(y_true)) && length(y_true) == 2L),
32+
TRUE))
33+
2734
test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE)
28-
new_observation <- new_observation[, colnames(new_observation) %in% colnames(explainer$data)]
35+
36+
# make this code also work for 1-row matrix
37+
col_index <- which(colnames(new_observation) %in% colnames(explainer$data))
38+
if (is.matrix(new_observation) && nrow(new_observation) == 1) {
39+
new_observation <- as.matrix(t(new_observation[, col_index]))
40+
} else {
41+
new_observation <- new_observation[, col_index]
42+
}
43+
2944
if (ncol(explainer$data) != ncol(new_observation)) stop("New observation and data have different number of columns (variables)")
3045

3146
if (!is.null(y_true)) {
3247
if (is.matrix(y_true)) {
33-
y_true_ind <- y_true[1, 2]
34-
y_true_time <- y_true[1, 1]
48+
# above, we have already checked that nrows of observations are
49+
# identical to nrows of y_true; thus we do not need to index
50+
# the first row here
51+
y_true_ind <- y_true[, 2]
52+
y_true_time <- y_true[, 1]
3553
} else {
3654
y_true_ind <- y_true[2]
3755
y_true_time <- y_true[1]
@@ -40,7 +58,8 @@ surv_shap <- function(explainer,
4058

4159
res <- list()
4260
res$eval_times <- explainer$times
43-
res$variable_values <- new_observation
61+
# to display final object correctly, when is.matrix(new_observation) == TRUE
62+
res$variable_values <- as.data.frame(new_observation)
4463

4564
res$result <- switch(calculation_method,
4665
"exact_kernel" = shap_kernel(explainer, new_observation, ...),
@@ -148,14 +167,64 @@ aggregate_surv_shap <- function(survshap, method) {
148167
use_kernelshap <- function(explainer, new_observation, ...){
149168

150169
predfun <- function(model, newdata){
151-
explainer$predict_survival_function(model, newdata, times=explainer$times)
170+
explainer$predict_survival_function(
171+
model,
172+
newdata,
173+
times = explainer$times
174+
)
152175
}
153176

154-
tmp_res <- kernelshap::kernelshap(explainer$model, new_observation, bg_X = explainer$data,
155-
pred_fun = predfun, verbose=FALSE)
177+
tmp_res_list <- sapply(
178+
X = as.character(seq_len(nrow(new_observation))),
179+
FUN = function(i) {
180+
tmp_res <- kernelshap::kernelshap(
181+
object = explainer$model,
182+
X = new_observation[as.integer(i), ],
183+
bg_X = explainer$data,
184+
pred_fun = predfun,
185+
verbose = FALSE
186+
)
187+
tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind)))
188+
colnames(tmp_shap_values) <- colnames(tmp_res$X)
189+
rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "")
190+
data.table::as.data.table(tmp_shap_values, keep.rownames = TRUE)
191+
},
192+
USE.NAMES = TRUE,
193+
simplify = FALSE
194+
)
195+
196+
shap_values <- aggregate_shap_multiple_observations(
197+
shap_res_list = tmp_res_list,
198+
feature_names = colnames(new_observation)
199+
)
200+
201+
return(shap_values)
202+
}
203+
204+
205+
aggregate_shap_multiple_observations <- function(shap_res_list, feature_names) {
156206

157-
shap_values <- data.frame(t(sapply(tmp_res$S, cbind)))
158-
colnames(shap_values) <- colnames(tmp_res$X)
159-
rownames(shap_values) <- paste("t=", explainer$times, sep = "")
207+
if (length(shap_res_list) > 1) {
208+
209+
full_survshap_results <- data.table::rbindlist(
210+
l = shap_res_list,
211+
use.names = TRUE,
212+
idcol = TRUE
213+
)
214+
215+
# compute arithmetic mean for each time-point and feature across
216+
# multiple observations
217+
tmp_res <- full_survshap_results[
218+
, lapply(.SD, mean), by = "rn", .SDcols = feature_names
219+
]
220+
} else {
221+
# no aggregation required
222+
tmp_res <- shap_res_list[[1]]
223+
}
224+
shap_values <- tmp_res[, .SD, .SDcols = setdiff(colnames(tmp_res), "rn")]
225+
# transform to data.frame to make everything compatible with
226+
# previous code
227+
shap_values <- data.frame(shap_values)
228+
rownames(shap_values) <- tmp_res$rn
160229
return(shap_values)
161230
}

R/zzz.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
#' @import data.table
2+
NULL

tests/testthat/test-predict_parts.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,26 @@ test_that("survshap explanations work", {
4848

4949
})
5050

51+
test_that("global survshap explanations with kernelshap work for ranger", {
52+
veteran <- survival::veteran
53+
54+
rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
55+
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE)
56+
57+
parts_ranger <- predict_parts(
58+
rsf_ranger_exp,
59+
veteran[1:40, !colnames(veteran) %in% c("time", "status")],
60+
y_true = Surv(veteran$time[1:40], veteran$status[1:40]),
61+
aggregation_method = "mean_absolute",
62+
calculation_method = "kernelshap"
63+
)
64+
65+
expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap"))
66+
expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp$times))
67+
expect_true(all(colnames(parts_ranger$result) == colnames(rsf_ranger_exp$data)))
68+
69+
})
70+
5171

5272
test_that("survlime explanations work", {
5373

0 commit comments

Comments
 (0)