@@ -24,16 +24,16 @@ surv_shap <- function(explainer,
2424 B = 25 ,
2525 exact = FALSE
2626) {
27- # if providing y_true, it must be exactly one single new observation,
28- # otherwise the indexing of y_true doesn't make any sense
29- stopifnot(
30- ifelse( ! is.null( y_true ), nrow(new_observation ) == 1 , TRUE ),
31- nrow( new_observation ) == 1 # produces nonesense, if more than on new observation
32- )
27+ # make this code work for multiple observations
28+ stopifnot(ifelse( ! is.null( y_true ),
29+ ifelse(is.matrix( y_true ),
30+ nrow(new_observation ) == nrow( y_true ),
31+ is.null(dim( y_true )) && length( y_true ) == 2L ),
32+ TRUE ) )
3333
3434 test_explainer(explainer , " surv_shap" , has_data = TRUE , has_y = TRUE , has_survival = TRUE )
3535
36- # make that this also works for 1-row matrix
36+ # make this code also work for 1-row matrix
3737 col_index <- which(colnames(new_observation ) %in% colnames(explainer $ data ))
3838 if (is.matrix(new_observation ) && nrow(new_observation ) == 1 ) {
3939 new_observation <- as.matrix(t(new_observation [, col_index ]))
@@ -45,8 +45,11 @@ surv_shap <- function(explainer,
4545
4646 if (! is.null(y_true )) {
4747 if (is.matrix(y_true )) {
48- y_true_ind <- y_true [1 , 2 ]
49- y_true_time <- y_true [1 , 1 ]
48+ # above, we have already checked that nrows of observations are
49+ # identical to nrows of y_true; thus we do not need to index
50+ # the first row here
51+ y_true_ind <- y_true [, 2 ]
52+ y_true_time <- y_true [, 1 ]
5053 } else {
5154 y_true_ind <- y_true [2 ]
5255 y_true_time <- y_true [1 ]
@@ -66,6 +69,7 @@ surv_shap <- function(explainer,
6669
6770 res <- list ()
6871 res $ eval_times <- explainer $ times
72+ # to display final object correctly, when is.matrix(new_observation) == TRUE
6973 res $ variable_values <- as.data.frame(new_observation )
7074
7175 res $ result <- switch (calculation_method ,
@@ -175,15 +179,37 @@ aggregate_surv_shap <- function(survshap, method) {
# Compute SurvSHAP values with the kernelshap backend.
#
# Every row of `new_observation` is explained on its own with
# kernelshap::kernelshap(), using `explainer$data` as background data and
# the explainer's survival function (evaluated at `explainer$times`) as the
# prediction function. The per-row tables are then combined by
# aggregate_shap_multiple_observations().
#
# Arguments:
#   explainer        object providing `model`, `data`, `times` and
#                    `predict_survival_function`.
#   new_observation  matrix or data.frame holding the rows to explain.
#   ...              not used inside this function.
#
# Returns the value of aggregate_shap_multiple_observations() for the
# per-row results (rows labelled "t=<time>", one column per feature).
use_kernelshap <- function(explainer, new_observation, ...) {

  predfun <- function(model, newdata) {
    explainer$predict_survival_function(
      model,
      newdata,
      times = explainer$times
    )
  }

  row_ids <- seq_len(nrow(new_observation))

  tmp_res_list <- lapply(
    X = row_ids,
    FUN = function(i) {
      tmp_res <- kernelshap::kernelshap(
        object = explainer$model,
        # drop = FALSE keeps the selected row two-dimensional; without it a
        # matrix row (or a row of a one-column data.frame) collapses to a
        # plain vector, which is not a valid `X` for kernelshap().
        X = new_observation[i, , drop = FALSE],
        bg_X = explainer$data,
        pred_fun = predfun,
        verbose = FALSE
      )
      # `tmp_res$S` is iterated element-wise and bound column-wise, then
      # transposed so that rows correspond to the entries of
      # `explainer$times` and columns to the features.
      tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind)))
      colnames(tmp_shap_values) <- colnames(tmp_res$X)
      rownames(tmp_shap_values) <- paste0("t=", explainer$times)
      # keep.rownames = TRUE moves the time labels into a column named "rn"
      data.table::as.data.table(tmp_shap_values, keep.rownames = TRUE)
    }
  )
  # name the list elements by row index ("1", "2", ...)
  names(tmp_res_list) <- as.character(row_ids)

  aggregate_shap_multiple_observations(
    shap_res_list = tmp_res_list,
    feature_names = colnames(new_observation)
  )
}
189215
@@ -200,21 +226,63 @@ use_treeshap <- function(explainer, new_observation, ...){
200226 data = explainer $ data
201227 )
202228
203- tmp_res <- do.call(
204- rbind ,
205- lapply(
206- tmp_unified ,
207- function (m ) {
208- treeshap :: treeshap(
209- unified_model = m ,
210- x = new_observation
211- )$ shaps
212- }
213- )
229+ tmp_res_list <- sapply(
230+ X = as.character(seq_len(nrow(new_observation ))),
231+ FUN = function (i ) {
232+ tmp_res <- do.call(
233+ rbind ,
234+ lapply(
235+ tmp_unified ,
236+ function (m ) {
237+ treeshap :: treeshap(
238+ unified_model = m ,
239+ x = new_observation
240+ )$ shaps
241+ }
242+ )
243+ )
244+
245+ tmp_shap_values <- data.frame (tmp_res )
246+ colnames(tmp_shap_values ) <- colnames(tmp_res )
247+ rownames(tmp_shap_values ) <- paste(" t=" , explainer $ times , sep = " " )
248+ data.table :: as.data.table(tmp_shap_values , keep.rownames = TRUE )
249+ },
250+ USE.NAMES = TRUE ,
251+ simplify = FALSE
252+ )
253+
254+ shap_values <- aggregate_shap_multiple_observations(
255+ shap_res_list = tmp_res_list ,
256+ feature_names = colnames(new_observation )
214257 )
215258
216- shap_values <- data.frame (tmp_res )
217- colnames(shap_values ) <- colnames(tmp_res )
218- rownames(shap_values ) <- paste(" t=" , explainer $ times , sep = " " )
259+ return (shap_values )
260+ }
261+
262+
# Combine per-observation SHAP tables into a single data.frame.
#
# Arguments:
#   shap_res_list  non-empty list of data.tables, one per observation. Each
#                  table holds a column `rn` (the former row names, i.e. the
#                  time-point labels) and one column per feature.
#   feature_names  character vector naming the feature columns that are
#                  averaged when more than one observation is supplied.
#
# Returns a data.frame whose row names are the `rn` values. With several
# observations every cell is the arithmetic mean, across observations, for
# the given time point and feature; with a single observation the table is
# passed through without aggregation. Column names are kept exactly as
# supplied (no make.names() mangling).
#
# Raises an error if `shap_res_list` is empty.
aggregate_shap_multiple_observations <- function(shap_res_list, feature_names) {

  if (length(shap_res_list) == 0L) {
    stop("`shap_res_list` must contain at least one element.", call. = FALSE)
  }

  if (length(shap_res_list) > 1) {

    full_survshap_results <- data.table::rbindlist(
      l = shap_res_list,
      use.names = TRUE,
      idcol = TRUE
    )

    # compute arithmetic mean for each time-point and feature across
    # multiple observations
    tmp_res <- full_survshap_results[
      , lapply(.SD, mean), by = "rn", .SDcols = feature_names
    ]
  } else {
    # no aggregation required
    tmp_res <- shap_res_list[[1]]
  }

  shap_values <- tmp_res[, .SD, .SDcols = setdiff(colnames(tmp_res), "rn")]
  # transform to data.frame to make everything compatible with previous
  # code; as.data.frame() is used instead of data.frame() because the latter
  # applies check.names = TRUE and would rewrite non-syntactic feature names
  shap_values <- as.data.frame(shap_values)
  rownames(shap_values) <- tmp_res$rn
  shap_values
}
0 commit comments