diff --git a/R/predict_surrogate.R b/R/predict_surrogate.R index b7a0b8f..393710f 100644 --- a/R/predict_surrogate.R +++ b/R/predict_surrogate.R @@ -73,9 +73,13 @@ predict_surrogate_lime <- function(explainer, new_observation, n_features = 4, n # https://github.com/ModelOriented/DALEXtra/issues/73 new_observation <- new_observation[, intersect(colnames(explainer$data), colnames(new_observation))] + should_bin_continuous <- ifelse(!is.null(explainer$bin_continuous), explainer$bin_continuous, TRUE) - lime_model <- lime::lime(x = explainer$data[, colnames(new_observation)], - model = explainer) + lime_model <- lime::lime( + x = explainer$data[, colnames(new_observation)], + model = explainer, + bin_continuous = should_bin_continuous + ) lime_expl <- lime::explain(x = new_observation, explainer = lime_model,