Skip to content

Commit 23e57da

Browse files
committed
fix: another try to fix kernelshap data, now X and bg_X as data.frame
1 parent a496a67 commit 23e57da

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

R/surv_shap.R

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,16 +251,24 @@ use_kernelshap <- function(explainer, new_observation, output_type, observation_
251251
new_observation, "data.frame")
252252
)
253253

254+
# get explainer data to be able to make class checks and transformations
255+
explainer_data <- explainer$data
256+
# ensure that classes of explainer$data and new_observation are equal
257+
if (!inherits(explainer_data, "data.frame")) {
258+
explainer_data <- data.frame(explainer_data)
259+
}
260+
254261
shap_values <- sapply(
255262
X = as.character(seq_len(nrow(new_observation))),
256263
FUN = function(i) {
257264
tmp_res <- kernelshap::kernelshap(
258265
object = explainer$model,
259-
X = as.matrix(new_observation[as.integer(i), ]),
260-
bg_X = as.matrix(explainer$data),
266+
X = new_observation[as.integer(i), ], # data.frame
267+
bg_X = explainer_data, # data.frame
261268
pred_fun = predfun,
262269
verbose = FALSE
263270
)
271+
# kernelshap-test: is.matrix(X) == is.matrix(bg_X) should evaluate to `TRUE`
264272
tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind)))
265273
colnames(tmp_shap_values) <- colnames(tmp_res$X)
266274
rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "")

0 commit comments

Comments
 (0)