Skip to content

Commit 0f6bf31

Browse files
authored
Merge pull request #86 from ModelOriented/fixes
hot fixes for predict parts functions
2 parents 0b2f4f5 + cca9530 commit 0f6bf31

File tree

8 files changed

+35
-23
lines changed

8 files changed

+35
-23
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: survex
22
Title: Explainable Machine Learning in Survival Analysis
3-
Version: 1.1.3.9000
3+
Version: 1.1.3.9002
44
Authors@R:
55
c(
66
person("Mikołaj", "Spytek", email = "[email protected]", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")),

R/plot_surv_shap.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ plot_shap_global_beeswarm <- function(x,
293293
max_vars = 7,
294294
colors = NULL) {
295295
df <- as.data.frame(do.call(rbind, x$aggregate))
296-
cols <- names(sort(colMeans(abs(df))))[1:min(max_vars, length(df))]
296+
cols <- names(sort(colMeans(abs(df)), decreasing = TRUE))[1:min(max_vars, length(df))]
297297
df <- df[, cols]
298298
df <- stack(df)
299299
colnames(df) <- c("shap_value", "variable")

R/predict_parts.R

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
#' @param explainer an explainer object - model preprocessed by the `explain()` function
66
#' @param new_observation a new observation for which prediction need to be explained
77
#' @param ... other parameters which are passed to `iBreakDown::break_down` if `output_type=="risk"`, or if `output_type=="survival"` to `surv_shap()` or `surv_lime()` functions depending on the selected type
8-
#' @param N the maximum number of observations used for calculation of attributions. If `NULL` (default) all observations will be used.
8+
#' @param N the number of observations used for calculation of attributions. If `NULL` (default) all explainer data will be used for SurvSHAP(t) and 100 neigbours for SurvLIME.
99
#' @param type if `output_type == "survival"` must be either `"survshap"` or `"survlime"`, otherwise refer to the `DALEX::predict_parts`
10-
#' @param output_type either `"survival"` or `"risk"` the type of survival model output that should be considered for explanations. If `"survival"` the explanations are based on the survival function. Otherwise the scalar risk predictions are used by the `DALEX::predict_parts` function.
10+
#' @param output_type either `"survival"`, `"chf"` or `"risk"` the type of survival model output that should be considered for explanations. If `"survival"` the explanations are based on the survival function. If `"chf"` the explanations are based on the cumulative hazard function. Otherwise the scalar risk predictions are used by the `DALEX::predict_parts` function.
1111
#' @param explanation_label a label that can overwrite explainer label (useful for multiple explanations for the same explainer/model)
1212
#'
1313
#' @return An object of class `"predict_parts_survival"` and additional classes depending on the type of explanations. It is a list with the element `result` containing the results of the calculation.
@@ -27,7 +27,6 @@
2727
#' * `categorical_variables` - character vector, names of variables that should be treated as categories (factors are included by default)
2828
#' * `k` - a small positive number > 1, added to chf before taking log, so that weigths aren't negative
2929
#' * for `survshap`
30-
#' * `timestamps` - a numeric vector, time points at which the survival function will be evaluated
3130
#' * `y_true` - a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting
3231
#' * `calculation_method` - a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements) or `"exact_kernel"` for exact Kernel SHAP estimation
3332
#' * `aggregation_method` - a character, either `"mean_absolute"` or `"integral"`, `"max_absolute"`, `"sum_of_squares"`
@@ -75,8 +74,8 @@ predict_parts.surv_explainer <- function(explainer, new_observation, ..., N = NU
7574
))
7675
} else {
7776
res <- switch(type,
78-
"survshap" = surv_shap(explainer, new_observation, output_type, ...),
79-
"survlime" = surv_lime(explainer, new_observation, ...),
77+
"survshap" = surv_shap(explainer, new_observation, output_type, ..., N = N),
78+
"survlime" = surv_lime(explainer, new_observation, ..., N = N),
8079
stop("Only `survshap` and `survlime` methods are implemented for now")
8180
)
8281
}

R/surv_lime.R

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ surv_lime <- function(explainer, new_observation,
3333
test_explainer(explainer, "surv_lime", has_data = TRUE, has_y = TRUE, has_chf = TRUE)
3434
new_observation <- new_observation[, colnames(new_observation) %in% colnames(explainer$data)]
3535
if (ncol(explainer$data) != ncol(new_observation)) stop("New observation and data have different number of columns (variables)")
36+
if (is.null(N)) N <- 100
3637

3738
predicted_sf <- explainer$predict_survival_function(explainer$model, new_observation, explainer$times)
3839

@@ -57,12 +58,11 @@ surv_lime <- function(explainer, new_observation,
5758

5859
distances <- apply(scaled_data, 1, dist, scaled_data[1, ])
5960

60-
if (is.null(kernel_width)) kernel_width <- sqrt(ncol(scaled_data) * 0.75)
61+
if (is.null(kernel_width)) kernel_width <- sqrt(ncol(scaled_data)) * 0.75
6162

6263
weights <- sqrt(exp(-(distances^2) / (kernel_width^2)))
6364
na_est <- survival::basehaz(survival::coxph(explainer$y ~ 1))
6465

65-
6666
model_chfs <- explainer$predict_cumulative_hazard_function(explainer$model, neighbourhood$inverse, na_est$time) + k
6767
log_chfs <- log(model_chfs)
6868
weights_v <- model_chfs / log_chfs
@@ -175,10 +175,13 @@ generate_neighbourhood <- function(data_org,
175175
data <- data[, colnames(data_row)]
176176

177177
if (length(categorical_variables) > 0) {
178+
inverse_as_factor <- inverse
179+
inverse_as_factor[additional_categorical_variables] <-
180+
lapply(inverse_as_factor[additional_categorical_variables], as.factor)
178181
expr <- paste0("~", paste(categorical_variables, collapse = "+"))
179-
categorical_matrix <- model.matrix(as.formula(expr), data = inverse)[, -1]
182+
categorical_matrix <- model.matrix(as.formula(expr), data = inverse_as_factor)[, -1]
180183
inverse_ohe <- cbind(inverse, categorical_matrix)
181-
inverse_ohe[, factor_variables] <- NULL
184+
inverse_ohe[, categorical_variables] <- NULL
182185
} else {
183186
inverse_ohe <- inverse
184187
}

R/surv_shap.R

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#' @param new_observation new observations for which predictions need to be explained
55
#' @param output_type a character, either `"survival"` or `"chf"`. Determines which type of prediction should be used for explanations.
66
#' @param ... additional parameters, passed to internal functions
7+
#' @param N a positive integer, number of observations used as the background data
78
#' @param y_true a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting
89
#' @param calculation_method a character, either `"kernelshap"` for use of `kernelshap` library (providing faster Kernel SHAP with refinements), `"exact_kernel"` for exact Kernel SHAP estimation, or `"treeshap"` for use of `treeshap` library (efficient implementation to compute SHAP values for tree-based models).
910
#' @param aggregation_method a character, either `"integral"`, `"integral_absolute"`, `"mean_absolute"`, `"max_absolute"`, or `"sum_of_squares"`
@@ -18,6 +19,7 @@ surv_shap <- function(explainer,
1819
new_observation,
1920
output_type,
2021
...,
22+
N = NULL,
2123
y_true = NULL,
2224
calculation_method = c("kernelshap", "exact_kernel", "treeshap"),
2325
aggregation_method = c("integral", "mean_absolute", "max_absolute", "sum_of_squares")
@@ -62,7 +64,6 @@ surv_shap <- function(explainer,
6264
}
6365

6466
test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE)
65-
6667
# make this code also work for 1-row matrix
6768
col_index <- which(colnames(new_observation) %in% colnames(explainer$data))
6869
if (is.matrix(new_observation) && nrow(new_observation) == 1) {
@@ -128,11 +129,12 @@ surv_shap <- function(explainer,
128129
return(res)
129130
}
130131

131-
use_exact_shap <- function(explainer, new_observation, output_type, ...) {
132+
133+
use_exact_shap <- function(explainer, new_observation, output_type, N, ...) {
132134
shap_values <- sapply(
133135
X = as.character(seq_len(nrow(new_observation))),
134136
FUN = function(i) {
135-
as.data.frame(shap_kernel(explainer, new_observation[as.integer(i), ], output_type, ...))
137+
as.data.frame(shap_kernel(explainer, new_observation[as.integer(i), ], output_type, N, ...))
136138
},
137139
USE.NAMES = TRUE,
138140
simplify = FALSE
@@ -142,24 +144,24 @@ use_exact_shap <- function(explainer, new_observation, output_type, ...) {
142144
}
143145

144146

145-
shap_kernel <- function(explainer, new_observation, output_type, ...) {
147+
shap_kernel <- function(explainer, new_observation, output_type, N, ...) {
146148
timestamps <- explainer$times
147149
p <- ncol(explainer$data)
148-
150+
if (is.null(N)) N <- nrow(explainer$data)
151+
background_data <- explainer$data[sample(1:nrow(explainer$data), N),]
149152

150153
target_sf <- predict(explainer, new_observation, times = timestamps, output_type = output_type)
151-
sfs <- predict(explainer, explainer$data, times = timestamps, output_type = output_type)
154+
sfs <- predict(explainer, background_data, times = timestamps, output_type = output_type)
152155
baseline_sf <- apply(sfs, 2, mean)
153156

154-
155157
permutations <- expand.grid(rep(list(0:1), p))
156158
kernel_weights <- generate_shap_kernel_weights(permutations, p)
157159

158160
shap_values <- calculate_shap_values(
159161
explainer,
160162
explainer$model,
161163
baseline_sf,
162-
as.data.frame(explainer$data),
164+
as.data.frame(background_data),
163165
permutations, kernel_weights,
164166
as.data.frame(new_observation),
165167
timestamps
@@ -227,7 +229,7 @@ aggregate_surv_shap <- function(survshap, times, method, ...) {
227229
}
228230

229231

230-
use_kernelshap <- function(explainer, new_observation, output_type, observation_aggregation_method, ...) {
232+
use_kernelshap <- function(explainer, new_observation, output_type, N, ...) {
231233
predfun <- function(model, newdata) {
232234

233235
if (output_type == "survival"){
@@ -257,6 +259,9 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_
257259
explainer_data <- data.frame(explainer_data)
258260
}
259261

262+
if (is.null(N)) N <- nrow(explainer$data)
263+
background_data <- explainer$data[sample(1:nrow(explainer$data), N),]
264+
260265
shap_values <- sapply(
261266
X = as.character(seq_len(nrow(new_observation))),
262267
FUN = function(i) {

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ If you use `survex`, please cite [our preprint](https://arxiv.org/abs/2308.16113
102102
- W. Chen, B. Zhou, C. Y. Jeon, F. Xie, Y-C. Lin, R. K. Butler, Y. Zhou, T. Q. Luong, E. Lustigova, J. R. Pisegna, B. U. Wu. [Machine learning versus regression for prediction of sporadic pancreatic cancer](https://doi.org/10.1016/j.pan.2023.04.009). *Pancreatology*, 2023.
103103
- M. Nachit, Y. Horsmans, R. M. Summers, I. A. Leclercq, P. J. Pickhardt. [AI-based CT Body Composition Identifies Myosteatosis as Key Mortality Predictor in Asymptomatic Adults](https://doi.org/10.1148/radiol.222008). *Radiology*, 2023.
104104
- R. Passera, S. Zompi, J. Gill, A. Busca. [Explainable Machine Learning (XAI) for Survival in Bone Marrow Transplantation Trials: A Technical Report](https://doi.org/10.3390/biomedinformatics3030048). *BioMedInformatics*, 2023.
105+
- P. Donizy, M. Spytek, M. Krzyziński, K. Kotowski, A. Markiewicz, B. Romanowska-Dixon, P. Biecek, M. P. Hoang. [Ki67 is a better marker than PRAME in risk stratification of BAP1-positive and BAP1-loss uveal melanomas](http://dx.doi.org/10.1136/bjo-2023-323816). *British Journal of Ophthalmology*, 2023.
106+
- X. Qi, Y. Ge, A. Yang, Y. Liu, Q. Wang & G. Wu. [Potential value of mitochondrial regulatory pathways in the clinical application of clear cell renal cell carcinoma: a machine learning-based study](https://doi.org/10.1007/s00432-023-05393-8). *Journal of Cancer Research and Clinical Oncology*, 2023.
107+
105108
- Share it with us!
106109

107110
## Related work

man/predict_parts.surv_explainer.Rd

Lines changed: 2 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/surv_shap.Rd

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)