@@ -24,14 +24,32 @@ surv_shap <- function(explainer,
2424 B = 25 ,
2525 exact = FALSE
2626) {
27+ # make this code work for multiple observations
28+ stopifnot(ifelse(! is.null(y_true ),
29+ ifelse(is.matrix(y_true ),
30+ nrow(new_observation ) == nrow(y_true ),
31+ is.null(dim(y_true )) && length(y_true ) == 2L ),
32+ TRUE ))
33+
2734 test_explainer(explainer , " surv_shap" , has_data = TRUE , has_y = TRUE , has_survival = TRUE )
28- new_observation <- new_observation [, colnames(new_observation ) %in% colnames(explainer $ data )]
35+
36+ # make this code also work for 1-row matrix
37+ col_index <- which(colnames(new_observation ) %in% colnames(explainer $ data ))
38+ if (is.matrix(new_observation ) && nrow(new_observation ) == 1 ) {
39+ new_observation <- as.matrix(t(new_observation [, col_index ]))
40+ } else {
41+ new_observation <- new_observation [, col_index ]
42+ }
43+
2944 if (ncol(explainer $ data ) != ncol(new_observation )) stop(" New observation and data have different number of columns (variables)" )
3045
3146 if (! is.null(y_true )) {
3247 if (is.matrix(y_true )) {
33- y_true_ind <- y_true [1 , 2 ]
34- y_true_time <- y_true [1 , 1 ]
48+ # above, we have already checked that nrows of observations are
49+ # identical to nrows of y_true; thus we do not need to index
50+ # the first row here
51+ y_true_ind <- y_true [, 2 ]
52+ y_true_time <- y_true [, 1 ]
3553 } else {
3654 y_true_ind <- y_true [2 ]
3755 y_true_time <- y_true [1 ]
@@ -40,7 +58,8 @@ surv_shap <- function(explainer,
4058
4159 res <- list ()
4260 res $ eval_times <- explainer $ times
43- res $ variable_values <- new_observation
61+ # to display final object correctly, when is.matrix(new_observation) == TRUE
62+ res $ variable_values <- as.data.frame(new_observation )
4463
4564 res $ result <- switch (calculation_method ,
4665 " exact_kernel" = shap_kernel(explainer , new_observation , ... ),
@@ -148,14 +167,64 @@ aggregate_surv_shap <- function(survshap, method) {
148167use_kernelshap <- function (explainer , new_observation , ... ){
149168
150169 predfun <- function (model , newdata ){
151- explainer $ predict_survival_function(model , newdata , times = explainer $ times )
170+ explainer $ predict_survival_function(
171+ model ,
172+ newdata ,
173+ times = explainer $ times
174+ )
152175 }
153176
154- tmp_res <- kernelshap :: kernelshap(explainer $ model , new_observation , bg_X = explainer $ data ,
155- pred_fun = predfun , verbose = FALSE )
177+ tmp_res_list <- sapply(
178+ X = as.character(seq_len(nrow(new_observation ))),
179+ FUN = function (i ) {
180+ tmp_res <- kernelshap :: kernelshap(
181+ object = explainer $ model ,
182+ X = new_observation [as.integer(i ), ],
183+ bg_X = explainer $ data ,
184+ pred_fun = predfun ,
185+ verbose = FALSE
186+ )
187+ tmp_shap_values <- data.frame (t(sapply(tmp_res $ S , cbind )))
188+ colnames(tmp_shap_values ) <- colnames(tmp_res $ X )
189+ rownames(tmp_shap_values ) <- paste(" t=" , explainer $ times , sep = " " )
190+ data.table :: as.data.table(tmp_shap_values , keep.rownames = TRUE )
191+ },
192+ USE.NAMES = TRUE ,
193+ simplify = FALSE
194+ )
195+
196+ shap_values <- aggregate_shap_multiple_observations(
197+ shap_res_list = tmp_res_list ,
198+ feature_names = colnames(new_observation )
199+ )
200+
201+ return (shap_values )
202+ }
203+
204+
205+ aggregate_shap_multiple_observations <- function (shap_res_list , feature_names ) {
156206
157- shap_values <- data.frame (t(sapply(tmp_res $ S , cbind )))
158- colnames(shap_values ) <- colnames(tmp_res $ X )
159- rownames(shap_values ) <- paste(" t=" , explainer $ times , sep = " " )
207+ if (length(shap_res_list ) > 1 ) {
208+
209+ full_survshap_results <- data.table :: rbindlist(
210+ l = shap_res_list ,
211+ use.names = TRUE ,
212+ idcol = TRUE
213+ )
214+
215+ # compute arithmetic mean for each time-point and feature across
216+ # multiple observations
217+ tmp_res <- full_survshap_results [
218+ , lapply(.SD , mean ), by = " rn" , .SDcols = feature_names
219+ ]
220+ } else {
221+ # no aggregation required
222+ tmp_res <- shap_res_list [[1 ]]
223+ }
224+ shap_values <- tmp_res [, .SD , .SDcols = setdiff(colnames(tmp_res ), " rn" )]
225+ # transform to data.frame to make everything compatible with
226+ # previous code
227+ shap_values <- data.frame (shap_values )
228+ rownames(shap_values ) <- tmp_res $ rn
160229 return (shap_values )
161230}
0 commit comments