55# ' @param output_type a character, either `"survival"` or `"chf"`. Determines which type of prediction should be used for explanations.
66# ' @param ... additional parameters, passed to internal functions
77# ' @param y_true a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting
8- # ' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements) or `"exact_kernel"` for exact Kernel SHAP estimation
8+ # ' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements), `"exact_kernel"` for exact Kernel SHAP estimation, or `"treeshap"` for use of `treeshap` library (efficient implementation to compute SHAP values for tree-based models).
99# ' @param aggregation_method a character, either `"integral"`, `"integral_absolute"`, `"mean_absolute"`, `"max_absolute"`, or `"sum_of_squares"`
1010# '
1111# ' @return A list, containing the calculated SurvSHAP(t) results in the `result` field
@@ -19,10 +19,15 @@ surv_shap <- function(explainer,
1919 output_type ,
2020 ... ,
2121 y_true = NULL ,
22- calculation_method = " kernelshap" ,
23- aggregation_method = " integral" ) {
22+ calculation_method = c(" kernelshap" , " exact_kernel" , " treeshap" ),
23+ aggregation_method = c(" integral" , " mean_absolute" , " max_absolute" , " sum_of_squares" )
24+ ) {
25+ calculation_method <- match.arg(calculation_method )
26+ aggregation_method <- match.arg(aggregation_method )
27+
28+ # make this code work for multiple observations
2429 stopifnot(
25- " `y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse(
30+ " `y_true` must be either a matrix with one row per observation in `new_observation` or a vector of length == 2" = ifelse(
2631 ! is.null(y_true ),
2732 ifelse(
2833 is.matrix(y_true ),
@@ -33,14 +38,40 @@ surv_shap <- function(explainer,
3338 )
3439 )
3540
41+ if (calculation_method == " kernelshap" ) {
42+ if (! requireNamespace(" kernelshap" , quietly = TRUE )) {
43+ stop(
44+ paste0(
45+ " Package \" kernelshap\" must be installed to use " ,
46+ " 'calculation_method = \" kernelshap\" '."
47+ ),
48+ call. = FALSE
49+ )
50+ }
51+ }
52+ if (calculation_method == " treeshap" ) {
53+ if (! requireNamespace(" treeshap" , quietly = TRUE )) {
54+ stop(
55+ paste0(
56+ " Package \" treeshap\" must be installed to use " ,
57+ " 'calculation_method = \" treeshap\" '."
58+ ),
59+ call. = FALSE
60+ )
61+ }
62+ }
63+
3664 test_explainer(explainer , " surv_shap" , has_data = TRUE , has_y = TRUE , has_survival = TRUE )
3765
3866 # make this code also work for 1-row matrix
3967 col_index <- which(colnames(new_observation ) %in% colnames(explainer $ data ))
4068 if (is.matrix(new_observation ) && nrow(new_observation ) == 1 ) {
41- new_observation <- as.matrix(t(new_observation [, col_index ]))
69+ new_observation <- data.frame ( as.matrix(t(new_observation [, col_index ]) ))
4270 } else {
4371 new_observation <- new_observation [, col_index ]
72+ if (! inherits(new_observation , " data.frame" )) {
73+ new_observation <- data.frame (new_observation )
74+ }
4475 }
4576
4677 if (ncol(explainer $ data ) != ncol(new_observation )) {
@@ -59,14 +90,25 @@ surv_shap <- function(explainer,
5990 }
6091 }
6192
93+ if (calculation_method == " treeshap" ) {
94+ if (! inherits(explainer $ model , " ranger" )) {
95+ stop(" Calculation method `treeshap` is currently only implemented for `ranger` survival models." )
96+ }
97+ }
98+
6299 res <- list ()
63100 res $ eval_times <- explainer $ times
64101 # to display final object correctly, when is.matrix(new_observation) == TRUE
65102 res $ variable_values <- as.data.frame(new_observation )
66103 res $ result <- switch (calculation_method ,
67- " exact_kernel" = use_exact_shap(explainer , new_observation , output_type , ... ),
68- " kernelshap" = use_kernelshap(explainer , new_observation , output_type , ... ),
69- stop(" Only `exact_kernel` and `kernelshap` calculation methods are implemented" )
104+ " exact_kernel" = use_exact_shap(explainer , new_observation , output_type , ... ),
105+ " kernelshap" = use_kernelshap(explainer , new_observation , output_type , ... ),
106+ " treeshap" = use_treeshap(explainer , new_observation , ... ),
107+ stop(" Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented" ))
108+ # quality-check here
109+ stopifnot(
110+ " Number of rows of SurvSHAP table are not identical with length(eval_times)" =
111+ nrow(res $ result ) == length(res $ eval_times )
70112 )
71113
72114 if (! is.null(y_true )) res $ y_true <- c(y_true_time = y_true_time , y_true_ind = y_true_ind )
@@ -86,7 +128,7 @@ surv_shap <- function(explainer,
86128 return (res )
87129}
88130
89- use_exact_shap <- function (explainer , new_observation , output_type , observation_aggregation_method , ... ) {
131+ use_exact_shap <- function (explainer , new_observation , output_type , ... ) {
90132 shap_values <- sapply(
91133 X = as.character(seq_len(nrow(new_observation ))),
92134 FUN = function (i ) {
@@ -123,11 +165,8 @@ shap_kernel <- function(explainer, new_observation, output_type, ...) {
123165 timestamps
124166 )
125167
126-
127-
128168 shap_values <- as.data.frame(shap_values , row.names = colnames(explainer $ data ))
129169 colnames(shap_values ) <- paste(" t=" , timestamps , sep = " " )
130-
131170 return (t(shap_values ))
132171}
133172
@@ -204,19 +243,31 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_
204243 times = explainer $ times
205244 )
206245 }
246+ }
247+
248+ stopifnot(
249+ " new_observation must be a data.frame" = inherits(
250+ new_observation , " data.frame" )
251+ )
207252
253+ # get explainer data to be able to make class checks and transformations
254+ explainer_data <- explainer $ data
255+ # ensure that classes of explainer$data and new_observation are equal
256+ if (! inherits(explainer_data , " data.frame" )) {
257+ explainer_data <- data.frame (explainer_data )
208258 }
209259
210260 shap_values <- sapply(
211261 X = as.character(seq_len(nrow(new_observation ))),
212262 FUN = function (i ) {
213263 tmp_res <- kernelshap :: kernelshap(
214264 object = explainer $ model ,
215- X = new_observation [as.integer(i ), ],
216- bg_X = explainer $ data ,
265+ X = new_observation [as.integer(i ), ], # data.frame
266+ bg_X = explainer_data , # data.frame
217267 pred_fun = predfun ,
218268 verbose = FALSE
219269 )
270+ # kernelshap-test: is.matrix(X) == is.matrix(bg_X) should evaluate to `TRUE`
220271 tmp_shap_values <- data.frame (t(sapply(tmp_res $ S , cbind )))
221272 colnames(tmp_shap_values ) <- colnames(tmp_res $ X )
222273 rownames(tmp_shap_values ) <- paste(" t=" , explainer $ times , sep = " " )
@@ -229,6 +280,69 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_
229280 return (shap_values )
230281}
231282
283+ use_treeshap <- function (explainer , new_observation , ... ){
284+
285+ stopifnot(
286+ " new_observation must be a data.frame" = inherits(
287+ new_observation , " data.frame" )
288+ )
289+
290+ # init unify_append_args
291+ unify_append_args <- list ()
292+
293+ if (inherits(explainer $ model , " ranger" )) {
294+ # UNIFY_FUN to prepare code for easy Integration of other ml algorithms
295+ # that are supported by treeshap
296+ UNIFY_FUN <- treeshap :: ranger_surv.unify
297+ unify_append_args <- list (type = " survival" , times = explainer $ times )
298+ } else {
299+ stop(" Support for `treeshap` is currently only implemented for `ranger`." )
300+ }
301+
302+ unify_args <- list (
303+ rf_model = explainer $ model ,
304+ data = explainer $ data
305+ )
306+
307+ if (length(unify_append_args ) > 0 ) {
308+ unify_args <- c(unify_args , unify_append_args )
309+ }
310+
311+ tmp_unified <- do.call(UNIFY_FUN , unify_args )
312+
313+ shap_values <- sapply(
314+ X = as.character(seq_len(nrow(new_observation ))),
315+ FUN = function (i ) {
316+ tmp_res <- do.call(
317+ rbind ,
318+ lapply(
319+ tmp_unified ,
320+ function (m ) {
321+ new_obs_mat <- new_observation [as.integer(i ), ]
322+ # ensure that matrix has expected dimensions; as.integer is
323+ # necessary for valid comparison with "identical"
324+ stopifnot(identical(dim(new_obs_mat ), as.integer(c(1L , ncol(new_observation )))))
325+ treeshap :: treeshap(
326+ unified_model = m ,
327+ x = new_obs_mat
328+ )$ shaps
329+ }
330+ )
331+ )
332+
333+ tmp_shap_values <- data.frame (tmp_res )
334+ colnames(tmp_shap_values ) <- colnames(tmp_res )
335+ rownames(tmp_shap_values ) <- paste(" t=" , explainer $ times , sep = " " )
336+ tmp_shap_values
337+ },
338+ USE.NAMES = TRUE ,
339+ simplify = FALSE
340+ )
341+
342+ return (shap_values )
343+
344+ }
345+
232346# ' @keywords internal
233347aggregate_shap_multiple_observations <- function (shap_res_list , feature_names , aggregation_function ) {
234348 if (length(shap_res_list ) > 1 ) {
0 commit comments