Skip to content

Commit bb47009

Browse files
committed
chore: merged multirow-support into main
2 parents 44f56db + 00d6a65 commit bb47009

File tree

5 files changed

+123
-31
lines changed

5 files changed

+123
-31
lines changed

DESCRIPTION

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: survex
22
Title: Explainable Machine Learning in Survival Analysis
3-
Version: 1.0.0.9000
3+
Version: 1.0.0.9001
44
Authors@R:
55
c(
66
person("Mikołaj", "Spytek", email = "[email protected]", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")),
@@ -27,7 +27,8 @@ Imports:
2727
treeshap,
2828
pec,
2929
survival,
30-
patchwork
30+
patchwork,
31+
data.table
3132
Suggests:
3233
censored,
3334
covr,

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ export(survival_to_cumulative_hazard)
6868
export(theme_default_survex)
6969
export(theme_vertical_default_survex)
7070
export(transform_to_stepfunction)
71+
import(data.table)
7172
import(ggplot2)
7273
import(patchwork)
7374
import(survival)

R/surv_shap.R

Lines changed: 97 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,16 @@ surv_shap <- function(explainer,
2424
B = 25,
2525
exact = FALSE
2626
) {
27-
# if providing y_true, it must be exactly one single new observation,
28-
# otherwise the indexing of y_true doesn't make any sense
29-
stopifnot(
30-
ifelse(!is.null(y_true), nrow(new_observation) == 1, TRUE),
31-
nrow(new_observation) == 1 # produces nonesense, if more than on new observation
32-
)
27+
# make this code work for multiple observations
28+
stopifnot(ifelse(!is.null(y_true),
29+
ifelse(is.matrix(y_true),
30+
nrow(new_observation) == nrow(y_true),
31+
is.null(dim(y_true)) && length(y_true) == 2L),
32+
TRUE))
3333

3434
test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE)
3535

36-
# make that this also works for 1-row matrix
36+
# make this code also work for 1-row matrix
3737
col_index <- which(colnames(new_observation) %in% colnames(explainer$data))
3838
if (is.matrix(new_observation) && nrow(new_observation) == 1) {
3939
new_observation <- as.matrix(t(new_observation[, col_index]))
@@ -45,8 +45,11 @@ surv_shap <- function(explainer,
4545

4646
if (!is.null(y_true)) {
4747
if (is.matrix(y_true)) {
48-
y_true_ind <- y_true[1, 2]
49-
y_true_time <- y_true[1, 1]
48+
# above, we have already checked that nrows of observations are
49+
# identical to nrows of y_true; thus we do not need to index
50+
# the first row here
51+
y_true_ind <- y_true[, 2]
52+
y_true_time <- y_true[, 1]
5053
} else {
5154
y_true_ind <- y_true[2]
5255
y_true_time <- y_true[1]
@@ -66,6 +69,7 @@ surv_shap <- function(explainer,
6669

6770
res <- list()
6871
res$eval_times <- explainer$times
72+
# to display final object correctly, when is.matrix(new_observation) == TRUE
6973
res$variable_values <- as.data.frame(new_observation)
7074

7175
res$result <- switch(calculation_method,
@@ -175,15 +179,37 @@ aggregate_surv_shap <- function(survshap, method) {
175179
use_kernelshap <- function(explainer, new_observation, ...){
176180

177181
predfun <- function(model, newdata){
178-
explainer$predict_survival_function(model, newdata, times = explainer$times)
182+
explainer$predict_survival_function(
183+
model,
184+
newdata,
185+
times = explainer$times
186+
)
179187
}
180188

181-
tmp_res <- kernelshap::kernelshap(explainer$model, new_observation, bg_X = explainer$data,
182-
pred_fun = predfun, verbose = FALSE)
189+
tmp_res_list <- sapply(
190+
X = as.character(seq_len(nrow(new_observation))),
191+
FUN = function(i) {
192+
tmp_res <- kernelshap::kernelshap(
193+
object = explainer$model,
194+
X = new_observation[as.integer(i), ],
195+
bg_X = explainer$data,
196+
pred_fun = predfun,
197+
verbose = FALSE
198+
)
199+
tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind)))
200+
colnames(tmp_shap_values) <- colnames(tmp_res$X)
201+
rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "")
202+
data.table::as.data.table(tmp_shap_values, keep.rownames = TRUE)
203+
},
204+
USE.NAMES = TRUE,
205+
simplify = FALSE
206+
)
207+
208+
shap_values <- aggregate_shap_multiple_observations(
209+
shap_res_list = tmp_res_list,
210+
feature_names = colnames(new_observation)
211+
)
183212

184-
shap_values <- data.frame(t(sapply(tmp_res$S, cbind)))
185-
colnames(shap_values) <- colnames(tmp_res$X)
186-
rownames(shap_values) <- paste("t=", explainer$times, sep = "")
187213
return(shap_values)
188214
}
189215

@@ -200,21 +226,63 @@ use_treeshap <- function(explainer, new_observation, ...){
200226
data = explainer$data
201227
)
202228

203-
tmp_res <- do.call(
204-
rbind,
205-
lapply(
206-
tmp_unified,
207-
function(m) {
208-
treeshap::treeshap(
209-
unified_model = m,
210-
x = new_observation
211-
)$shaps
212-
}
213-
)
229+
tmp_res_list <- sapply(
230+
X = as.character(seq_len(nrow(new_observation))),
231+
FUN = function(i) {
232+
tmp_res <- do.call(
233+
rbind,
234+
lapply(
235+
tmp_unified,
236+
function(m) {
237+
treeshap::treeshap(
238+
unified_model = m,
239+
x = new_observation
240+
)$shaps
241+
}
242+
)
243+
)
244+
245+
tmp_shap_values <- data.frame(tmp_res)
246+
colnames(tmp_shap_values) <- colnames(tmp_res)
247+
rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "")
248+
data.table::as.data.table(tmp_shap_values, keep.rownames = TRUE)
249+
},
250+
USE.NAMES = TRUE,
251+
simplify = FALSE
252+
)
253+
254+
shap_values <- aggregate_shap_multiple_observations(
255+
shap_res_list = tmp_res_list,
256+
feature_names = colnames(new_observation)
214257
)
215258

216-
shap_values <- data.frame(tmp_res)
217-
colnames(shap_values) <- colnames(tmp_res)
218-
rownames(shap_values) <- paste("t=", explainer$times, sep = "")
259+
return(shap_values)
260+
}
261+
262+
263+
aggregate_shap_multiple_observations <- function(shap_res_list, feature_names) {
264+
265+
if (length(shap_res_list) > 1) {
266+
267+
full_survshap_results <- data.table::rbindlist(
268+
l = shap_res_list,
269+
use.names = TRUE,
270+
idcol = TRUE
271+
)
272+
273+
# compute arithmetic mean for each time-point and feature across
274+
# multiple observations
275+
tmp_res <- full_survshap_results[
276+
, lapply(.SD, mean), by = "rn", .SDcols = feature_names
277+
]
278+
} else {
279+
# no aggregation required
280+
tmp_res <- shap_res_list[[1]]
281+
}
282+
shap_values <- tmp_res[, .SD, .SDcols = setdiff(colnames(tmp_res), "rn")]
283+
# transform to data.frame to make everything compatible with
284+
# previous code
285+
shap_values <- data.frame(shap_values)
286+
rownames(shap_values) <- tmp_res$rn
219287
return(shap_values)
220288
}

R/zzz.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
#' @import data.table
2+
NULL

tests/testthat/test-predict_parts.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,26 @@ test_that("survshap explanations work", {
6262

6363
})
6464

65+
test_that("global survshap explanations with kernelshap work for ranger", {
66+
veteran <- survival::veteran
67+
68+
rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
69+
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE)
70+
71+
parts_ranger <- predict_parts(
72+
rsf_ranger_exp,
73+
veteran[1:40, !colnames(veteran) %in% c("time", "status")],
74+
y_true = Surv(veteran$time[1:40], veteran$status[1:40]),
75+
aggregation_method = "mean_absolute",
76+
calculation_method = "kernelshap"
77+
)
78+
79+
expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap"))
80+
expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp$times))
81+
expect_true(all(colnames(parts_ranger$result) == colnames(rsf_ranger_exp$data)))
82+
83+
})
84+
6585

6686
test_that("survlime explanations work", {
6787

0 commit comments

Comments
 (0)