Skip to content

Commit 52a9fa1

Browse files
committed
further clean up to the point explain_forest works for ranger
1 parent 10ab14d commit 52a9fa1

9 files changed

+118
-66
lines changed

R/explain_forest.R

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,29 @@
2222
#'
2323
#' @export
2424
explain_forest <- function(forest, interactions = FALSE, data = NULL, vars = NULL, no_of_pred_plots = 3, pred_grid = 100,
25-
measures = if(forest$type == "classification")
26-
c("mean_min_depth", "accuracy_decrease", "gini_decrease", "no_of_nodes", "times_a_root") else
27-
c("mean_min_depth", "mse_increase", "node_purity_increase", "no_of_nodes", "times_a_root")){
28-
if(any(c("accuracy_decrease", "mse_increase") %in% measures) & dim(forest$importance)[2] == 1) {
25+
measures = NULL){
26+
if(is.null(measures)){
27+
if("randomForest" %in% class(forest)){
28+
if(forest$type == "classification"){
29+
measures <- c("mean_min_depth", "accuracy_decrease", "gini_decrease", "no_of_nodes", "times_a_root")
30+
} else{
31+
measures <- c("mean_min_depth", "mse_increase", "node_purity_increase", "no_of_nodes", "times_a_root")
32+
}
33+
} else if("ranger" %in% class(forest)){
34+
measures <- c("mean_min_depth", forest$importance.mode, "no_of_nodes", "times_a_root")
35+
}
36+
}
37+
if("randomForest" %in% class(forest) && dim(forest$importance)[2] == 1){
2938
stop(paste("Your forest does not contain information on local importance so",
30-
paste(intersect(c("accuracy_decrease", "mse_increase"), measures), sep=", "),
39+
ifelse(forest$type == "classification", "accuracy_decrease", "mse_increase"),
3140
"measure cannot be extracted.",
3241
"To add it regrow the forest with the option localImp = TRUE and run this function again."))
3342
}
43+
if("ranger" %in% class(forest) && forest$importance.mode == "none"){
44+
stop(paste("Your forest does not contain importance information so",
45+
"importance cannot be extracted.",
46+
"To add it regrow the forest with the option importance other than 'none' and run this function again."))
47+
}
3448
environment <- new.env()
3549
environment$forest <- forest
3650
environment$data <- data
@@ -42,7 +56,7 @@ explain_forest <- function(forest, interactions = FALSE, data = NULL, vars = NUL
4256
directory <- getwd()
4357
path_to_templates <- file.path(path.package("randomForestExplainer"), "templates")
4458
template_name <- grep('explain_forest_template.rmd', list.files(path_to_templates),
45-
ignore.case = TRUE, value = TRUE)
59+
ignore.case = TRUE, value = TRUE)
4660

4761
rmarkdown::render(file.path(path_to_templates, template_name),
4862
"html_document", output_file = paste0(directory, "/Your_forest_explained.html"),

R/measure_importance.R

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ measure_vimp <- function(forest, only_nonlocal = FALSE){
6060
# Extract randomForest variable importance measures
6161
# ranger
6262
measure_vimp_ranger <- function(forest){
63-
if (forest$importance.mode == "none") {
63+
if (forest$importance.mode == "none"){
6464
stop("No variable importance available, regenerate forest by ranger(..., importance='impurity').")
6565
}
6666
frame <- data.frame(importance=forest$variable.importance,
@@ -95,8 +95,8 @@ measure_times_a_root <- function(min_depth_frame){
9595
measure_p_value <- function(importance_frame){
9696
total_no_of_nodes <- sum(importance_frame$no_of_nodes)
9797
p_value <- unlist(lapply(importance_frame$no_of_nodes,
98-
function(x) stats::binom.test(x, total_no_of_nodes, 1/nrow(importance_frame),
99-
alternative = "greater")$p.value))
98+
function(x) stats::binom.test(x, total_no_of_nodes, 1/nrow(importance_frame),
99+
alternative = "greater")$p.value))
100100
return(p_value)
101101
}
102102

@@ -132,7 +132,7 @@ measure_importance.randomForest <- function(forest, mean_sample = "top_trees", m
132132
if(forest$type == "classification"){
133133
measures <- c("mean_min_depth", "no_of_nodes", "accuracy_decrease",
134134
"gini_decrease", "no_of_trees", "times_a_root", "p_value")
135-
} else if(forest$type =="regression") {
135+
} else if(forest$type =="regression"){
136136
measures <- c("mean_min_depth", "no_of_nodes", "mse_increase", "node_purity_increase",
137137
"no_of_trees", "times_a_root", "p_value")
138138
}
@@ -161,7 +161,7 @@ measure_importance.randomForest <- function(forest, mean_sample = "top_trees", m
161161
}
162162
if(forest$type == "classification"){
163163
vimp <- c("accuracy_decrease", "gini_decrease")
164-
} else if(forest$type =="regression") {
164+
} else if(forest$type =="regression"){
165165
vimp <- c("mse_increase", "node_purity_increase")
166166
}
167167
if(all(vimp %in% measures)){
@@ -193,7 +193,7 @@ measure_importance.randomForest <- function(forest, mean_sample = "top_trees", m
193193
measure_importance.ranger <- function(forest, mean_sample = "top_trees", measures = NULL){
194194
tree <- NULL; splitvarName <- NULL; depth <- NULL
195195
if(is.null(measures)){
196-
measures <- c("mean_min_depth", "no_of_nodes", forest$importance.mode, "no_of_trees", "times_a_root", "p_value")
196+
measures <- c("mean_min_depth", "no_of_nodes", forest$importance.mode, "no_of_trees", "times_a_root", "p_value")
197197
}
198198
if(("p_value" %in% measures) && !("no_of_nodes" %in% measures)){
199199
measures <- c(measures, "no_of_nodes")
@@ -261,7 +261,7 @@ important_variables <- function(importance_frame, k = 15,
261261
if("predicted" %in% measures){
262262
measures <- names(importance_frame)[2:5]
263263
}
264-
} else if ("ranger" %in% class(importance_frame)) {
264+
} else if ("ranger" %in% class(importance_frame)){
265265
importance_frame <- measure_importance(importance_frame)
266266
}
267267
rankings <- data.frame(variable = importance_frame$variable, mean_min_depth =
@@ -376,14 +376,16 @@ plot_multi_way_importance <- function(importance_frame, x_measure = "mean_min_de
376376
#' plot_importance_ggpairs(frame, measures = c("mean_min_depth", "times_a_root"))
377377
#'
378378
#' @export
379-
plot_importance_ggpairs <- function(importance_frame, measures =
380-
names(importance_frame)[c(2, 4, 5, 3, 7)],
379+
plot_importance_ggpairs <- function(importance_frame, measures = NULL,
381380
main = "Relations between measures of importance"){
382-
if("randomForest" %in% class(importance_frame)){
381+
if(any(c("randomForest", "ranger") %in% class(importance_frame))){
383382
importance_frame <- measure_importance(importance_frame)
384-
if("predicted" %in% measures){
385-
names(importance_frame)[c(2, 4, 5, 3, 7)]
386-
}
383+
}
384+
if (is.null(measures)){
385+
default_measures <- c("gini_decrease", "node_purity_increase", # randomForest
386+
"impurity", "impurity_corrected", "permutation", # ranger
387+
"mean_min_depth", "no_of_trees", "no_of_nodes", "p_value")
388+
measures <- intersect(default_measures, colnames(importance_frame))
387389
}
388390
plot <- ggpairs(importance_frame[, measures]) + theme_bw()
389391
if(!is.null(main)){
@@ -397,7 +399,7 @@ plot_importance_ggpairs <- function(importance_frame, measures =
397399
#' Plot against each other rankings of variables according to various measures of importance
398400
#'
399401
#' @param importance_frame A result of applying the function measure_importance() to a random forest, or a randomForest or ranger object
400-
#' @param measures A character vector specifying the measures of importance to be used
402+
#' @param measures A character vector specifying the measures of importance to be used.
401403
#' @param main A string to be used as title of the plot
402404
#'
403405
#' @return A ggplot object
@@ -411,22 +413,29 @@ plot_importance_ggpairs <- function(importance_frame, measures =
411413
#' plot_importance_ggpairs(frame, measures = c("mean_min_depth", "times_a_root"))
412414
#'
413415
#' @export
414-
plot_importance_rankings <- function(importance_frame, measures =
415-
names(importance_frame)[c(2, 4, 5, 3, 7)],
416+
plot_importance_rankings <- function(importance_frame, measures = NULL,
416417
main = "Relations between rankings according to different measures"){
417-
if("randomForest" %in% class(importance_frame)){
418+
if(any(c("randomForest", "ranger") %in% class(importance_frame))){
418419
importance_frame <- measure_importance(importance_frame)
419-
if("predicted" %in% measures){
420-
names(importance_frame)[c(2, 4, 5, 3, 7)]
421-
}
422420
}
423-
rankings <- data.frame(variable = importance_frame$variable, mean_min_depth =
424-
frankv(importance_frame$mean_min_depth, ties.method = "dense"),
425-
apply(importance_frame[, -c(1, 2)], 2,
421+
if (is.null(measures)){
422+
default_measures <- c("gini_decrease", "node_purity_increase", # randomForest
423+
"impurity", "impurity_corrected", "permutation", # ranger
424+
"mean_min_depth", "no_of_trees", "no_of_nodes", "p_value")
425+
measures <- intersect(default_measures, colnames(importance_frame))
426+
}
427+
rankings <- data.frame(variable = importance_frame$variable,
428+
apply(importance_frame[, !colnames(importance_frame) %in% c("variable", "mean_min_depth", "p_value")], 2,
426429
function(x) frankv(x, order = -1, ties.method = "dense")))
430+
if ("mean_min_depth" %in% measures){
431+
rankings$mean_min_depth = frankv(importance_frame$mean_min_depth, ties.method = "dense")
432+
}
433+
if ("p_value" %in% measures){
434+
rankings$p_value = frankv(importance_frame$p_value, ties.method = "dense")
435+
}
427436
plot <- ggpairs(rankings[, measures], lower = list(continuous = function(data, mapping){
428437
ggplot(data = data, mapping = mapping) + geom_point() + geom_smooth(method = "loess")
429-
}))+ theme_bw()
438+
})) + theme_bw()
430439
if(!is.null(main)){
431440
plot <- plot + ggtitle(main)
432441
}

R/min_depth_interactions.R

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ conditional_depth_ranger <- function(frame, vars){
3939
df <- frame[begin:nrow(frame), setdiff(names(frame), setdiff(vars, j))]
4040
df[[j]][1] <- 0
4141
for(k in 2:nrow(df)){
42-
if(length(df[df[, "leftChild"] == as.numeric(df[k, "number"]) |
43-
df[, "rightChild"] == as.numeric(df[k, "number"]), j]) != 0){
42+
if(length(df[(!is.na(df[, "leftChild"]) & df[, "leftChild"] == as.numeric(df[k, "number"])) |
43+
(!is.na(df[, "rightChild"]) & df[, "rightChild"] == as.numeric(df[k, "number"])), j]) != 0){
4444
df[k, j] <-
45-
df[df[, "leftChild"] == as.numeric(df[k, "number"]) |
46-
df[, "rightChild"] == as.numeric(df[k, "number"]), j] + 1
45+
df[(!is.na(df[, "leftChild"]) & df[, "leftChild"] == as.numeric(df[k, "number"])) |
46+
(!is.na(df[, "rightChild"]) & df[, "rightChild"] == as.numeric(df[k, "number"])), j] + 1
4747
}
4848
}
4949
frame[begin:nrow(frame), setdiff(names(frame), setdiff(vars, j))] <- df
@@ -84,8 +84,8 @@ min_depth_interactions_values <- function(forest, vars){
8484
min_depth_interactions_values_ranger <- function(forest, vars){
8585
`.` <- NULL; .SD <- NULL; tree <- NULL; splitvarName <- NULL
8686
interactions_frame <-
87-
lapply(1:forest$num.trees, function(i) ranger::treeInfo(forest, k = i, labelVar = T) %>%
88-
calculate_tree_depth() %>% cbind(., tree = i, number = 1:nrow(.))) %>%
87+
lapply(1:forest$num.trees, function(i) ranger::treeInfo(forest, tree = i) %>%
88+
calculate_tree_depth_ranger() %>% cbind(., tree = i, number = 1:nrow(.))) %>%
8989
data.table::rbindlist() %>% as.data.frame()
9090
interactions_frame[vars] <- as.numeric(NA)
9191
interactions_frame <-
@@ -95,10 +95,10 @@ min_depth_interactions_values_ranger <- function(forest, vars){
9595
mean_tree_depth[mean_tree_depth == -Inf] <- NA
9696
mean_tree_depth <- colMeans(mean_tree_depth[, vars], na.rm = TRUE)
9797
min_depth_interactions_frame <-
98-
interactions_frame %>% dplyr::group_by(tree, `split var`) %>%
98+
interactions_frame %>% dplyr::group_by(tree, splitvarName) %>%
9999
dplyr::summarize_at(vars, funs(min(., na.rm = TRUE))) %>% as.data.frame()
100100
min_depth_interactions_frame[min_depth_interactions_frame == Inf] <- NA
101-
min_depth_interactions_frame <- min_depth_interactions_frame[!is.na(min_depth_interactions_frame$`split var`), ]
101+
min_depth_interactions_frame <- min_depth_interactions_frame[!is.na(min_depth_interactions_frame$splitvarName), ]
102102
colnames(min_depth_interactions_frame)[2] <- "variable"
103103
min_depth_interactions_frame[, -c(1:2)] <- min_depth_interactions_frame[, -c(1:2)] - 1
104104
return(list(min_depth_interactions_frame, mean_tree_depth))
@@ -120,7 +120,8 @@ min_depth_interactions_values_ranger <- function(forest, vars){
120120
#' min_depth_interactions(forest, c("Petal.Width", "Petal.Length"))
121121
#'
122122
#' @export
123-
min_depth_interactions <- function(){
123+
min_depth_interactions <- function(forest, vars = important_variables(measure_importance(forest)),
124+
mean_sample = "top_trees", uncond_mean_sample = mean_sample){
124125
UseMethod("min_depth_interactions")
125126
}
126127

@@ -217,7 +218,7 @@ min_depth_interactions.ranger <- function(forest, vars = important_variables(mea
217218
dplyr::summarize(min(depth))
218219
colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
219220
min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable),])
220-
importance_frame <- get_min_depth_means_ranger(min_depth_frame, min_depth_count(min_depth_frame), uncond_mean_sample)
221+
importance_frame <- get_min_depth_means(min_depth_frame, min_depth_count(min_depth_frame), uncond_mean_sample)
221222
colnames(importance_frame)[2] <- "uncond_mean_min_depth"
222223
interactions_frame <- merge(interactions_frame, importance_frame)
223224
}
@@ -273,6 +274,9 @@ plot_min_depth_interactions <- function(interactions_frame, k = 30,
273274
#' @param variable2 A character string with the name of a numerical predictor that will be on the Y-axis
274275
#' @param grid The number of points on the one-dimensional grid on x and y-axis
275276
#' @param main A string to be used as title of the plot
277+
#' @param time A numeric value specifying the time at which to predict survival probability, only
278+
#' applies to survival forests. If not specified, the time closest to predicted median survival
279+
#' time is used
276280
#'
277281
#' @return A ggplot2 object
278282
#'
@@ -285,7 +289,8 @@ plot_min_depth_interactions <- function(interactions_frame, k = 30,
285289
#' @export
286290
plot_predict_interaction <- function(forest, data, variable1, variable2, grid = 100,
287291
main = paste0("Prediction of the forest for different values of ",
288-
paste0(variable1, paste0(" and ", variable2)))){
292+
paste0(variable1, paste0(" and ", variable2))),
293+
time = NULL){
289294
UseMethod("plot_predict_interaction")
290295
}
291296

@@ -296,7 +301,8 @@ plot_predict_interaction <- function(forest, data, variable1, variable2, grid =
296301
#' @export
297302
plot_predict_interaction.randomForest <- function(forest, data, variable1, variable2, grid = 100,
298303
main = paste0("Prediction of the forest for different values of ",
299-
paste0(variable1, paste0(" and ", variable2)))){
304+
paste0(variable1, paste0(" and ", variable2))),
305+
time = NULL){
300306
newdata <- expand.grid(seq(min(data[[variable1]]), max(data[[variable1]]), length.out = grid),
301307
seq(min(data[[variable2]]), max(data[[variable2]]), length.out = grid))
302308
colnames(newdata) <- c(variable1, variable2)
@@ -343,7 +349,8 @@ plot_predict_interaction.randomForest <- function(forest, data, variable1, varia
343349
#' @export
344350
plot_predict_interaction.ranger <- function(forest, data, variable1, variable2, grid = 100,
345351
main = paste0("Prediction of the forest for different values of ",
346-
paste0(variable1, paste0(" and ", variable2)))){
352+
paste0(variable1, paste0(" and ", variable2))),
353+
time = NULL){
347354
newdata <- expand.grid(seq(min(data[[variable1]]), max(data[[variable1]]), length.out = grid),
348355
seq(min(data[[variable2]]), max(data[[variable2]]), length.out = grid))
349356
colnames(newdata) <- c(variable1, variable2)
@@ -358,7 +365,7 @@ plot_predict_interaction.ranger <- function(forest, data, variable1, variable2,
358365
newdata[[i]] <- data[[i]][sample(1:n, nrow(newdata), replace = TRUE)]
359366
}
360367
if(forest$treetype == "Regression"){
361-
newdata$prediction <- predict(forest, newdata, type = "response")
368+
newdata$prediction <- predict(forest, newdata, type = "response")$predictions
362369
plot <- ggplot(newdata, aes_string(x = variable1, y = variable2, fill = "prediction")) +
363370
geom_raster() + theme_bw() +
364371
scale_fill_gradient2(midpoint = min(newdata$prediction) + 0.5 * (max(newdata$prediction) - min(newdata$prediction)),
@@ -377,8 +384,23 @@ plot_predict_interaction.ranger <- function(forest, data, variable1, variable2,
377384
geom_raster() + theme_bw() + facet_wrap(~ variable) +
378385
scale_fill_gradient2(midpoint = min(newdata$prediction) + 0.5 * (max(newdata$prediction) - min(newdata$prediction)),
379386
low = "blue", high = "red")
380-
} else if(forest$treetype == "Classification") {
387+
} else if(forest$treetype == "Classification"){
381388
stop("Ranger forest for classification needs to be generated by ranger(..., probability = TRUE).")
389+
} else if(forest$treetype == "Survival"){
390+
pred <- predict(forest, newdata, type = "response")
391+
if (is.null(time)) {
392+
time <- pred$unique.death.times[which.min(abs(colMeans(pred$survival, na.rm = TRUE) - 0.5))]
393+
message(sprintf("Using unique death time %s which is the closest to predicted median survival time.", time))
394+
} else if (!time %in% pred$unique.death.times) {
395+
new_time <- pred$unique.death.times[which.min(abs(pred$unique.death.times - time))]
396+
message(sprintf("Using closest unique death time %s instead of %s.", new_time, time))
397+
time <- new_time
398+
}
399+
newdata$prediction <- pred$survival[, pred$unique.death.times == time, drop = TRUE]
400+
plot <- ggplot(newdata, aes_string(x = variable1, y = variable2, fill = "prediction")) +
401+
geom_raster() + theme_bw() +
402+
scale_fill_gradient2(midpoint = min(newdata$prediction) + 0.5 * (max(newdata$prediction) - min(newdata$prediction)),
403+
low = "blue", high = "red")
382404
} else {
383405
stop(sprintf("Ranger forest type '%s' is currently not supported.", forest$treetype))
384406
}

0 commit comments

Comments
 (0)