66# ' @param ... additional parameters, passed to internal functions
77# ' @param N a positive integer, number of observations used as the background data
88# ' @param y_true a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting
9- # ' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements) or `"exact_kernel"` for exact Kernel SHAP estimation
9+ # ' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements), `"exact_kernel"` for exact Kernel SHAP estimation, or `"treeshap"` for use of `treeshap` library (efficient implementation to compute SHAP values for tree-based models).
1010# ' @param aggregation_method a character, either `"integral"`, `"integral_absolute"`, `"mean_absolute"`, `"max_absolute"`, or `"sum_of_squares"`
1111# '
1212# ' @return A list, containing the calculated SurvSHAP(t) results in the `result` field
@@ -21,10 +21,15 @@ surv_shap <- function(explainer,
2121 ... ,
2222 N = NULL ,
2323 y_true = NULL ,
24- calculation_method = " kernelshap" ,
25- aggregation_method = " integral" ) {
24+ calculation_method = c(" kernelshap" , " exact_kernel" , " treeshap" ),
25+ aggregation_method = c(" integral" , " mean_absolute" , " max_absolute" , " sum_of_squares" )
26+ ) {
27+ calculation_method <- match.arg(calculation_method )
28+ aggregation_method <- match.arg(aggregation_method )
29+
30+ # make this code work for multiple observations
2631 stopifnot(
27- " `y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse(
32+ " `y_true` must be either a matrix with one row per observation in `new_observation` or a vector of length == 2" = ifelse(
2833 ! is.null(y_true ),
2934 ifelse(
3035 is.matrix(y_true ),
@@ -34,13 +39,40 @@ surv_shap <- function(explainer,
3439 TRUE
3540 )
3641 )
42+
43+ if (calculation_method == " kernelshap" ) {
44+ if (! requireNamespace(" kernelshap" , quietly = TRUE )) {
45+ stop(
46+ paste0(
47+ " Package \" kernelshap\" must be installed to use " ,
48+ " 'calculation_method = \" kernelshap\" '."
49+ ),
50+ call. = FALSE
51+ )
52+ }
53+ }
54+ if (calculation_method == " treeshap" ) {
55+ if (! requireNamespace(" treeshap" , quietly = TRUE )) {
56+ stop(
57+ paste0(
58+ " Package \" treeshap\" must be installed to use " ,
59+ " 'calculation_method = \" treeshap\" '."
60+ ),
61+ call. = FALSE
62+ )
63+ }
64+ }
65+
3766 test_explainer(explainer , " surv_shap" , has_data = TRUE , has_y = TRUE , has_survival = TRUE )
3867 # make this code also work for 1-row matrix
3968 col_index <- which(colnames(new_observation ) %in% colnames(explainer $ data ))
4069 if (is.matrix(new_observation ) && nrow(new_observation ) == 1 ) {
41- new_observation <- as.matrix(t(new_observation [, col_index ]))
70+ new_observation <- data.frame ( as.matrix(t(new_observation [, col_index ]) ))
4271 } else {
4372 new_observation <- new_observation [, col_index ]
73+ if (! inherits(new_observation , " data.frame" )) {
74+ new_observation <- data.frame (new_observation )
75+ }
4476 }
4577
4678 if (ncol(explainer $ data ) != ncol(new_observation )) {
@@ -59,14 +91,25 @@ surv_shap <- function(explainer,
5991 }
6092 }
6193
94+ if (calculation_method == " treeshap" ) {
95+ if (! inherits(explainer $ model , " ranger" )) {
96+ stop(" Calculation method `treeshap` is currently only implemented for `ranger` survival models." )
97+ }
98+ }
99+
62100 res <- list ()
63101 res $ eval_times <- explainer $ times
64102 # to display final object correctly, when is.matrix(new_observation) == TRUE
65103 res $ variable_values <- as.data.frame(new_observation )
66104 res $ result <- switch (calculation_method ,
67- " exact_kernel" = use_exact_shap(explainer , new_observation , output_type , N , ... ),
68- " kernelshap" = use_kernelshap(explainer , new_observation , output_type , N , ... ),
69- stop(" Only `exact_kernel` and `kernelshap` calculation methods are implemented" )
105+ " exact_kernel" = use_exact_shap(explainer , new_observation , output_type , ... ),
106+ " kernelshap" = use_kernelshap(explainer , new_observation , output_type , ... ),
107+ " treeshap" = use_treeshap(explainer , new_observation , ... ),
108+ stop(" Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented" ))
109+ # quality-check here
110+ stopifnot(
111+ " Number of rows of SurvSHAP table are not identical with length(eval_times)" =
112+ nrow(res $ result ) == length(res $ eval_times )
70113 )
71114
72115 if (! is.null(y_true )) res $ y_true <- c(y_true_time = y_true_time , y_true_ind = y_true_ind )
@@ -86,6 +129,7 @@ surv_shap <- function(explainer,
86129 return (res )
87130}
88131
132+
89133use_exact_shap <- function (explainer , new_observation , output_type , N , ... ) {
90134 shap_values <- sapply(
91135 X = as.character(seq_len(nrow(new_observation ))),
@@ -125,7 +169,6 @@ shap_kernel <- function(explainer, new_observation, output_type, N, ...) {
125169
126170 shap_values <- as.data.frame(shap_values , row.names = colnames(explainer $ data ))
127171 colnames(shap_values ) <- paste(" t=" , timestamps , sep = " " )
128-
129172 return (t(shap_values ))
130173}
131174
@@ -204,6 +247,18 @@ use_kernelshap <- function(explainer, new_observation, output_type, N, ...) {
204247 }
205248 }
206249
250+ stopifnot(
251+ " new_observation must be a data.frame" = inherits(
252+ new_observation , " data.frame" )
253+ )
254+
255+ # get explainer data to be able to make class checks and transformations
256+ explainer_data <- explainer $ data
257+ # ensure that classes of explainer$data and new_observation are equal
258+ if (! inherits(explainer_data , " data.frame" )) {
259+ explainer_data <- data.frame (explainer_data )
260+ }
261+
207262 if (is.null(N )) N <- nrow(explainer $ data )
208263 background_data <- explainer $ data [sample(1 : nrow(explainer $ data ), N ),]
209264
@@ -212,11 +267,12 @@ use_kernelshap <- function(explainer, new_observation, output_type, N, ...) {
212267 FUN = function (i ) {
213268 tmp_res <- kernelshap :: kernelshap(
214269 object = explainer $ model ,
215- X = new_observation [as.integer(i ), ],
216- bg_X = background_data ,
270+ X = new_observation [as.integer(i ), ], # data.frame
271+ bg_X = explainer_data , # data.frame
217272 pred_fun = predfun ,
218273 verbose = FALSE
219274 )
275+ # kernelshap-test: is.matrix(X) == is.matrix(bg_X) should evaluate to `TRUE`
220276 tmp_shap_values <- data.frame (t(sapply(tmp_res $ S , cbind )))
221277 colnames(tmp_shap_values ) <- colnames(tmp_res $ X )
222278 rownames(tmp_shap_values ) <- paste(" t=" , explainer $ times , sep = " " )
@@ -229,6 +285,69 @@ use_kernelshap <- function(explainer, new_observation, output_type, N, ...) {
229285 return (shap_values )
230286}
231287
288+ use_treeshap <- function (explainer , new_observation , ... ){
289+
290+ stopifnot(
291+ " new_observation must be a data.frame" = inherits(
292+ new_observation , " data.frame" )
293+ )
294+
295+ # init unify_append_args
296+ unify_append_args <- list ()
297+
298+ if (inherits(explainer $ model , " ranger" )) {
299+ # UNIFY_FUN to prepare code for easy Integration of other ml algorithms
300+ # that are supported by treeshap
301+ UNIFY_FUN <- treeshap :: ranger_surv.unify
302+ unify_append_args <- list (type = " survival" , times = explainer $ times )
303+ } else {
304+ stop(" Support for `treeshap` is currently only implemented for `ranger`." )
305+ }
306+
307+ unify_args <- list (
308+ rf_model = explainer $ model ,
309+ data = explainer $ data
310+ )
311+
312+ if (length(unify_append_args ) > 0 ) {
313+ unify_args <- c(unify_args , unify_append_args )
314+ }
315+
316+ tmp_unified <- do.call(UNIFY_FUN , unify_args )
317+
318+ shap_values <- sapply(
319+ X = as.character(seq_len(nrow(new_observation ))),
320+ FUN = function (i ) {
321+ tmp_res <- do.call(
322+ rbind ,
323+ lapply(
324+ tmp_unified ,
325+ function (m ) {
326+ new_obs_mat <- new_observation [as.integer(i ), ]
327+ # ensure that matrix has expected dimensions; as.integer is
328+ # necessary for valid comparison with "identical"
329+ stopifnot(identical(dim(new_obs_mat ), as.integer(c(1L , ncol(new_observation )))))
330+ treeshap :: treeshap(
331+ unified_model = m ,
332+ x = new_obs_mat
333+ )$ shaps
334+ }
335+ )
336+ )
337+
338+ tmp_shap_values <- data.frame (tmp_res )
339+ colnames(tmp_shap_values ) <- colnames(tmp_res )
340+ rownames(tmp_shap_values ) <- paste(" t=" , explainer $ times , sep = " " )
341+ tmp_shap_values
342+ },
343+ USE.NAMES = TRUE ,
344+ simplify = FALSE
345+ )
346+
347+ return (shap_values )
348+
349+ }
350+
232351# ' @keywords internal
233352aggregate_shap_multiple_observations <- function (shap_res_list , feature_names , aggregation_function ) {
234353 if (length(shap_res_list ) > 1 ) {
0 commit comments