Skip to content

Commit 0e8c85c

Browse files
committed
fix: added missing output_type arguments to switch of calculation method in surv_shap
which caused error of github action unittests when using 'exact_kernel'
1 parent a9269f0 commit 0e8c85c

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

R/surv_shap.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ surv_shap <- function(explainer,
2828

2929
# make this code work for multiple observations
3030
stopifnot(
31-
"`y_true` must be either a matrix with one per observation in `new_observation` or a vector of length == 2" = ifelse(
31+
"`y_true` must be either a matrix with one row per observation in `new_observation` or a vector of length == 2" = ifelse(
3232
!is.null(y_true),
3333
ifelse(
3434
is.matrix(y_true),
@@ -102,8 +102,8 @@ surv_shap <- function(explainer,
102102
# to display final object correctly, when is.matrix(new_observation) == TRUE
103103
res$variable_values <- as.data.frame(new_observation)
104104
res$result <- switch(calculation_method,
105-
"exact_kernel" = use_exact_shap(explainer, new_observation, ...),
106-
"kernelshap" = use_kernelshap(explainer, new_observation, ...),
105+
"exact_kernel" = use_exact_shap(explainer, new_observation, output_type, ...),
106+
"kernelshap" = use_kernelshap(explainer, new_observation, output_type, ...),
107107
"treeshap" = use_treeshap(explainer, new_observation, ...),
108108
stop("Only `exact_kernel`, `kernelshap` and `treeshap` calculation methods are implemented"))
109109
# quality-check here
@@ -129,7 +129,7 @@ surv_shap <- function(explainer,
129129
return(res)
130130
}
131131

132-
use_exact_shap <- function(explainer, new_observation, output_type, observation_aggregation_method, ...) {
132+
use_exact_shap <- function(explainer, new_observation, output_type, ...) {
133133
shap_values <- sapply(
134134
X = as.character(seq_len(nrow(new_observation))),
135135
FUN = function(i) {

tests/testthat/test-predict_parts.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ test_that("survshap explanations with output_type = 'chf' work", {
103103
plot(parts_cph, rug = "censors")
104104
plot(parts_cph, rug = "none")
105105

106+
# test global exact
107+
parts_cph_glob <- predict_parts(cph_exp, veteran[1:3, !colnames(veteran) %in% c("time", "status")], y_true = as.matrix(veteran[1:3, c("time", "status")]), calculation_method = "exact_kernel", aggregation_method = "max_absolute", output_type = "chf")
108+
plot(parts_cph_glob)
109+
106110
parts_ranger <- predict_parts(rsf_ranger_exp, veteran[2, !colnames(veteran) %in% c("time", "status")], y_true = c(100, 1), aggregation_method = "mean_absolute", output_type = "chf")
107111
plot(parts_ranger)
108112

0 commit comments

Comments
 (0)