55# ' @param output_type a character, either `"survival"` or `"chf"`. Determines which type of prediction should be used for explanations.
66# ' @param ... additional parameters, passed to internal functions
77# ' @param y_true a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting
8- # ' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements) or `"exact_kernel"` for exact Kernel SHAP estimation
8+ # ' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements), `"exact_kernel"` for exact Kernel SHAP estimation,
9+ # ' or `"treeshap"` for use of `treeshap` library (efficient implementation to compute SHAP values for tree-based models).
910# ' @param aggregation_method a character, either `"integral"`, `"integral_absolute"`, `"mean_absolute"`, `"max_absolute"`, or `"sum_of_squares"`
1011# '
1112# ' @return A list, containing the calculated SurvSHAP(t) results in the `result` field
@@ -19,10 +20,15 @@ surv_shap <- function(explainer,
1920 output_type ,
2021 ... ,
2122 y_true = NULL ,
22- calculation_method = " kernelshap" ,
23- aggregation_method = " integral" ) {
23+ calculation_method = c(" kernelshap" , " exact_kernel" , " treeshap" ),
24+ aggregation_method = c(" integral" , " mean_absolute" , " max_absolute" , " sum_of_squares" )
25+ ) {
26+ calculation_method <- match.arg(calculation_method )
27+ aggregation_method <- match.arg(aggregation_method )
28+
29+ # make this code work for multiple observations
2430 stopifnot(
25- " `y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse(
31+ " `y_true` must be either a matrix with one row per observation in `new_observation` or a vector of length == 2" = ifelse(
2632 ! is.null(y_true ),
2733 ifelse(
2834 is.matrix(y_true ),
@@ -33,14 +39,40 @@ surv_shap <- function(explainer,
3339 )
3440 )
3541
42+ if (calculation_method == " kernelshap" ) {
43+ if (! requireNamespace(" kernelshap" , quietly = TRUE )) {
44+ stop(
45+ paste0(
46+ " Package \" kernelshap\" must be installed to use " ,
47+ " 'calculation_method = \" kernelshap\" '."
48+ ),
49+ call. = FALSE
50+ )
51+ }
52+ }
53+ if (calculation_method == " treeshap" ) {
54+ if (! requireNamespace(" treeshap" , quietly = TRUE )) {
55+ stop(
56+ paste0(
57+ " Package \" treeshap\" must be installed to use " ,
58+ " 'calculation_method = \" treeshap\" '."
59+ ),
60+ call. = FALSE
61+ )
62+ }
63+ }
64+
3665 test_explainer(explainer , " surv_shap" , has_data = TRUE , has_y = TRUE , has_survival = TRUE )
3766
3867 # make this code also work for 1-row matrix
3968 col_index <- which(colnames(new_observation ) %in% colnames(explainer $ data ))
4069 if (is.matrix(new_observation ) && nrow(new_observation ) == 1 ) {
41- new_observation <- as.matrix(t(new_observation [, col_index ]))
70+ new_observation <- data.frame ( as.matrix(t(new_observation [, col_index ]) ))
4271 } else {
4372 new_observation <- new_observation [, col_index ]
73+ if (! inherits(new_observation , " data.frame" )) {
74+ new_observation <- data.frame (new_observation )
75+ }
4476 }
4577
4678 if (ncol(explainer $ data ) != ncol(new_observation )) {
@@ -59,14 +91,9 @@ surv_shap <- function(explainer,
5991 }
6092 }
6193
62- # hack to use rf-model death times as explainer death times, as
63- # treeshap::ranger_surv_fun.unify extracts survival times directly
64- # from the ranger object for calculating the predictions
6594 if (calculation_method == " treeshap" ) {
66- if (inherits(explainer $ model , " ranger" )) {
67- explainer $ times <- explainer $ model $ unique.death.times
68- } else {
69- stop(" Calculation method `treeshap` is currently only implemented for `ranger`." )
95+ if (! inherits(explainer $ model , " ranger" )) {
96+ stop(" Calculation method `treeshap` is currently only implemented for `ranger` survival models." )
7097 }
7198 }
7299
@@ -75,9 +102,14 @@ surv_shap <- function(explainer,
75102 # to display final object correctly, when is.matrix(new_observation) == TRUE
76103 res $ variable_values <- as.data.frame(new_observation )
77104 res $ result <- switch (calculation_method ,
78- " exact_kernel" = use_exact_shap(explainer , new_observation , output_type , ... ),
79- " kernelshap" = use_kernelshap(explainer , new_observation , output_type , ... ),
80- stop(" Only `exact_kernel` and `kernelshap` calculation methods are implemented" )
105+ " exact_kernel" = use_exact_shap(explainer , new_observation , output_type , ... ),
106+ " kernelshap" = use_kernelshap(explainer , new_observation , output_type , ... ),
107+ " treeshap" = use_treeshap(explainer , new_observation , ... ),
108+ stop(" Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented" ))
109+ # quality-check here
110+ stopifnot(
111+ " Number of rows of SurvSHAP table are not identical with length(eval_times)" =
112+ nrow(res $ result ) == length(res $ eval_times )
81113 )
82114
83115 if (! is.null(y_true )) res $ y_true <- c(y_true_time = y_true_time , y_true_ind = y_true_ind )
@@ -97,7 +129,7 @@ surv_shap <- function(explainer,
97129 return (res )
98130}
99131
100- use_exact_shap <- function (explainer , new_observation , output_type , observation_aggregation_method , ... ) {
132+ use_exact_shap <- function (explainer , new_observation , output_type , ... ) {
101133 shap_values <- sapply(
102134 X = as.character(seq_len(nrow(new_observation ))),
103135 FUN = function (i ) {
@@ -134,11 +166,8 @@ shap_kernel <- function(explainer, new_observation, output_type, ...) {
134166 timestamps
135167 )
136168
137-
138-
139169 shap_values <- as.data.frame(shap_values , row.names = colnames(explainer $ data ))
140170 colnames(shap_values ) <- paste(" t=" , timestamps , sep = " " )
141-
142171 return (t(shap_values ))
143172}
144173
@@ -215,19 +244,31 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_
215244 times = explainer $ times
216245 )
217246 }
247+ }
248+
249+ stopifnot(
250+ " new_observation must be a data.frame" = inherits(
251+ new_observation , " data.frame" )
252+ )
218253
254+ # get explainer data to be able to make class checks and transformations
255+ explainer_data <- explainer $ data
256+ # ensure that classes of explainer$data and new_observation are equal
257+ if (! inherits(explainer_data , " data.frame" )) {
258+ explainer_data <- data.frame (explainer_data )
219259 }
220260
221261 shap_values <- sapply(
222262 X = as.character(seq_len(nrow(new_observation ))),
223263 FUN = function (i ) {
224264 tmp_res <- kernelshap :: kernelshap(
225265 object = explainer $ model ,
226- X = new_observation [as.integer(i ), ],
227- bg_X = explainer $ data ,
266+ X = new_observation [as.integer(i ), ], # data.frame
267+ bg_X = explainer_data , # data.frame
228268 pred_fun = predfun ,
229269 verbose = FALSE
230270 )
271+ # kernelshap-test: is.matrix(X) == is.matrix(bg_X) should evaluate to `TRUE`
231272 tmp_shap_values <- data.frame (t(sapply(tmp_res $ S , cbind )))
232273 colnames(tmp_shap_values ) <- colnames(tmp_res $ X )
233274 rownames(tmp_shap_values ) <- paste(" t=" , explainer $ times , sep = " " )
@@ -240,6 +281,68 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_
240281 return (shap_values )
241282}
242283
284+ use_treeshap <- function (explainer , new_observation , ... ){
285+
286+ stopifnot(
287+ " new_observation must be a data.frame" = inherits(
288+ new_observation , " data.frame" )
289+ )
290+
291+ # init unify_append_args
292+ unify_append_args <- list ()
293+
294+ if (inherits(explainer $ model , " ranger" )) {
295+ # UNIFY_FUN to prepare code for easy Integration of other ml algorithms
296+ # that are supported by treeshap
297+ UNIFY_FUN <- treeshap :: ranger_surv.unify
298+ unify_append_args <- list (type = " survival" , times = explainer $ times )
299+ } else {
300+ stop(" Support for `treeshap` is currently only implemented for `ranger`." )
301+ }
302+
303+ unify_args <- list (
304+ rf_model = explainer $ model ,
305+ data = explainer $ data
306+ )
307+
308+ if (length(unify_append_args ) > 0 ) {
309+ unify_args <- c(unify_args , unify_append_args )
310+ }
311+
312+ tmp_unified <- do.call(UNIFY_FUN , unify_args )
313+
314+ shap_values <- sapply(
315+ X = as.character(seq_len(nrow(new_observation ))),
316+ FUN = function (i ) {
317+ tmp_res <- do.call(
318+ rbind ,
319+ lapply(
320+ tmp_unified ,
321+ function (m ) {
322+ new_obs_mat <- as.matrix(new_observation [as.integer(i ), ])
323+ # ensure that matrix has expected dimensions; as.integer is
324+ # necessary for valid comparison with "identical"
325+ stopifnot(identical(dim(new_obs_mat ), as.integer(c(1L , ncol(new_observation )))))
326+ treeshap :: treeshap(
327+ unified_model = m ,
328+ x = new_obs_mat
329+ )$ shaps
330+ }
331+ )
332+ )
333+
334+ tmp_shap_values <- data.frame (tmp_res )
335+ colnames(tmp_shap_values ) <- colnames(tmp_res )
336+ rownames(tmp_shap_values ) <- paste(" t=" , explainer $ times , sep = " " )
337+ tmp_shap_values
338+ },
339+ USE.NAMES = TRUE ,
340+ simplify = FALSE
341+ )
342+
343+ return (shap_values )
344+
345+ }
243346
244347# ' @keywords internal
245348aggregate_shap_multiple_observations <- function (shap_res_list , feature_names , aggregation_function ) {
0 commit comments