Skip to content

Commit d6db6a4

Browse files
Yue-Jiang authored and pbiecek committed
add ranger compatibility (#10)
* make `plot_min_depth_distribution` work with ranger forests
* first pass adding ranger support for min_depth_interactions, untested
* first pass adding ranger support for multi way importance, untested
* further clean up to the point explain_forest works for ranger
1 parent b12e54d commit d6db6a4

15 files changed

+430
-79
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@ Imports:
2020
ggrepel (>= 0.6.5),
2121
MASS (>= 7.3.47),
2222
randomForest (>= 4.6.12),
23+
ranger(>= 0.9.0),
2324
reshape2 (>= 1.4.2),
2425
rmarkdown (>= 1.5)
2526
Suggests:
2627
knitr
2728
VignetteBuilder: knitr
28-
RoxygenNote: 6.0.1
29+
RoxygenNote: 6.1.1
2930
URL: https://github.com/MI2DataLab/randomForestExplainer

NAMESPACE

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(measure_importance,randomForest)
4+
S3method(measure_importance,ranger)
5+
S3method(min_depth_distribution,randomForest)
6+
S3method(min_depth_distribution,ranger)
7+
S3method(min_depth_interactions,randomForest)
8+
S3method(min_depth_interactions,ranger)
9+
S3method(plot_predict_interaction,randomForest)
10+
S3method(plot_predict_interaction,ranger)
311
export(explain_forest)
412
export(important_variables)
513
export(measure_importance)

R/explain_forest.R

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,29 @@
2222
#'
2323
#' @export
2424
explain_forest <- function(forest, interactions = FALSE, data = NULL, vars = NULL, no_of_pred_plots = 3, pred_grid = 100,
25-
measures = if(forest$type == "classification")
26-
c("mean_min_depth", "accuracy_decrease", "gini_decrease", "no_of_nodes", "times_a_root") else
27-
c("mean_min_depth", "mse_increase", "node_purity_increase", "no_of_nodes", "times_a_root")){
28-
if(any(c("accuracy_decrease", "mse_increase") %in% measures) & dim(forest$importance)[2] == 1) {
25+
measures = NULL){
26+
if(is.null(measures)){
27+
if("randomForest" %in% class(forest)){
28+
if(forest$type == "classification"){
29+
measures <- c("mean_min_depth", "accuracy_decrease", "gini_decrease", "no_of_nodes", "times_a_root")
30+
} else{
31+
measures <- c("mean_min_depth", "mse_increase", "node_purity_increase", "no_of_nodes", "times_a_root")
32+
}
33+
} else if("ranger" %in% class(forest)){
34+
measures <- c("mean_min_depth", forest$importance.mode, "no_of_nodes", "times_a_root")
35+
}
36+
}
37+
if("randomForest" %in% class(forest) && dim(forest$importance)[2] == 1){
2938
stop(paste("Your forest does not contain information on local importance so",
30-
paste(intersect(c("accuracy_decrease", "mse_increase"), measures), sep=", "),
39+
ifelse(forest$type == "classification", "accuracy_decrease", "mse_increase"),
3140
"measure cannot be extracted.",
3241
"To add it regrow the forest with the option localImp = TRUE and run this function again."))
3342
}
43+
if("ranger" %in% class(forest) && forest$importance.mode == "none"){
44+
stop(paste("Your forest does not contain importance information so",
45+
"importance cannot be extracted.",
46+
"To add it regrow the forest with the option importance other than 'none' and run this function again."))
47+
}
3448
environment <- new.env()
3549
environment$forest <- forest
3650
environment$data <- data
@@ -42,7 +56,7 @@ explain_forest <- function(forest, interactions = FALSE, data = NULL, vars = NUL
4256
directory <- getwd()
4357
path_to_templates <- file.path(path.package("randomForestExplainer"), "templates")
4458
template_name <- grep('explain_forest_template.rmd', list.files(path_to_templates),
45-
ignore.case = TRUE, value = TRUE)
59+
ignore.case = TRUE, value = TRUE)
4660

4761
rmarkdown::render(file.path(path_to_templates, template_name),
4862
"html_document", output_file = paste0(directory, "/Your_forest_explained.html"),

R/measure_importance.R

Lines changed: 117 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ measure_min_depth <- function(min_depth_frame, mean_sample){
77
}
88

99
# Calculate the number of nodes split on each variable for a data frame with the whole forest
10+
# randomForest
1011
measure_no_of_nodes <- function(forest_table){
1112
`split var` <- NULL
1213
frame <- dplyr::group_by(forest_table, `split var`) %>% dplyr::summarize(n())
@@ -16,7 +17,19 @@ measure_no_of_nodes <- function(forest_table){
1617
return(frame)
1718
}
1819

20+
# Calculate the number of nodes split on each variable for a data frame with the whole forest
21+
# ranger
22+
measure_no_of_nodes_ranger <- function(forest_table){
23+
splitvarName <- NULL
24+
frame <- dplyr::group_by(forest_table, splitvarName) %>% dplyr::summarize(n())
25+
colnames(frame) <- c("variable", "no_of_nodes")
26+
frame <- as.data.frame(frame[!is.na(frame$variable),])
27+
frame$variable <- as.character(frame$variable)
28+
return(frame)
29+
}
30+
1931
# Extract randomForest variable importance measures
32+
# randomForest
2033
measure_vimp <- function(forest, only_nonlocal = FALSE){
2134
if(forest$type == "classification"){
2235
if(dim(forest$importance)[2] == 1){
@@ -44,6 +57,20 @@ measure_vimp <- function(forest, only_nonlocal = FALSE){
4457
return(frame)
4558
}
4659

60+
# Extract ranger variable importance measures
61+
# ranger
62+
measure_vimp_ranger <- function(forest){
63+
if (forest$importance.mode == "none"){
64+
stop("No variable importance available, regenerate forest by ranger(..., importance='impurity').")
65+
}
66+
frame <- data.frame(importance=forest$variable.importance,
67+
variable=names(forest$variable.importance),
68+
stringsAsFactors = FALSE)
69+
colnames(frame)[1] <- forest$importance.mode
70+
# Possible values are: 'impurity', 'impurity_corrected', 'permutation'.
71+
return(frame)
72+
}
73+
4774
# Calculate the number of trees using each variable for splitting
4875
measure_no_of_trees <- function(min_depth_frame){
4976
variable <- NULL
@@ -68,8 +95,8 @@ measure_times_a_root <- function(min_depth_frame){
6895
measure_p_value <- function(importance_frame){
6996
total_no_of_nodes <- sum(importance_frame$no_of_nodes)
7097
p_value <- unlist(lapply(importance_frame$no_of_nodes,
71-
function(x) stats::binom.test(x, total_no_of_nodes, 1/nrow(importance_frame),
72-
alternative = "greater")$p.value))
98+
function(x) stats::binom.test(x, total_no_of_nodes, 1/nrow(importance_frame),
99+
alternative = "greater")$p.value))
73100
return(p_value)
74101
}
75102

@@ -87,21 +114,25 @@ measure_p_value <- function(importance_frame){
87114
#'
88115
#' @return A data frame with rows corresponding to variables and columns to various measures of importance of variables
89116
#'
90-
#' @import dplyr
91-
#' @importFrom data.table rbindlist
92-
#'
93117
#' @examples
94118
#' forest <- randomForest::randomForest(Species ~ ., data = iris, localImp = TRUE, ntree = 300)
95119
#' measure_importance(forest)
96120
#'
97121
#' @export
98122
measure_importance <- function(forest, mean_sample = "top_trees", measures = NULL){
123+
UseMethod("measure_importance")
124+
}
125+
126+
#' @import dplyr
127+
#' @importFrom data.table rbindlist
128+
#' @export
129+
measure_importance.randomForest <- function(forest, mean_sample = "top_trees", measures = NULL){
99130
tree <- NULL; `split var` <- NULL; depth <- NULL
100131
if(is.null(measures)){
101132
if(forest$type == "classification"){
102133
measures <- c("mean_min_depth", "no_of_nodes", "accuracy_decrease",
103134
"gini_decrease", "no_of_trees", "times_a_root", "p_value")
104-
} else if(forest$type =="regression") {
135+
} else if(forest$type =="regression"){
105136
measures <- c("mean_min_depth", "no_of_nodes", "mse_increase", "node_purity_increase",
106137
"no_of_trees", "times_a_root", "p_value")
107138
}
@@ -130,7 +161,7 @@ measure_importance <- function(forest, mean_sample = "top_trees", measures = NUL
130161
}
131162
if(forest$type == "classification"){
132163
vimp <- c("accuracy_decrease", "gini_decrease")
133-
} else if(forest$type =="regression") {
164+
} else if(forest$type =="regression"){
134165
vimp <- c("mse_increase", "node_purity_increase")
135166
}
136167
if(all(vimp %in% measures)){
@@ -156,6 +187,54 @@ measure_importance <- function(forest, mean_sample = "top_trees", measures = NUL
156187
return(importance_frame)
157188
}
158189

190+
#' @import dplyr
191+
#' @importFrom data.table rbindlist
192+
#' @export
193+
measure_importance.ranger <- function(forest, mean_sample = "top_trees", measures = NULL){
194+
tree <- NULL; splitvarName <- NULL; depth <- NULL
195+
if(is.null(measures)){
196+
measures <- c("mean_min_depth", "no_of_nodes", forest$importance.mode, "no_of_trees", "times_a_root", "p_value")
197+
}
198+
if(("p_value" %in% measures) && !("no_of_nodes" %in% measures)){
199+
measures <- c(measures, "no_of_nodes")
200+
}
201+
importance_frame <- data.frame(variable = names(forest$variable.importance), stringsAsFactors = FALSE)
202+
# Get objects necessary to calculate importance measures based on the tree structure
203+
if(any(c("mean_min_depth", "no_of_nodes", "no_of_trees", "times_a_root", "p_value") %in% measures)){
204+
forest_table <-
205+
lapply(1:forest$num.trees, function(i) ranger::treeInfo(forest, tree = i) %>%
206+
calculate_tree_depth_ranger() %>% cbind(tree = i)) %>% rbindlist()
207+
min_depth_frame <- dplyr::group_by(forest_table, tree, splitvarName) %>%
208+
dplyr::summarize(min(depth))
209+
colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
210+
min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable),])
211+
}
212+
# Add each importance measure to the table (if it was requested)
213+
if("mean_min_depth" %in% measures){
214+
importance_frame <- merge(importance_frame, measure_min_depth(min_depth_frame, mean_sample), all = TRUE)
215+
}
216+
if("no_of_nodes" %in% measures){
217+
importance_frame <- merge(importance_frame, measure_no_of_nodes_ranger(forest_table), all = TRUE)
218+
importance_frame[is.na(importance_frame$no_of_nodes), "no_of_nodes"] <- 0
219+
}
220+
if(forest$importance.mode %in% measures){
221+
importance_frame <- merge(importance_frame, measure_vimp_ranger(forest), all = TRUE)
222+
}
223+
if("no_of_trees" %in% measures){
224+
importance_frame <- merge(importance_frame, measure_no_of_trees(min_depth_frame), all = TRUE)
225+
importance_frame[is.na(importance_frame$no_of_trees), "no_of_trees"] <- 0
226+
}
227+
if("times_a_root" %in% measures){
228+
importance_frame <- merge(importance_frame, measure_times_a_root(min_depth_frame), all = TRUE)
229+
importance_frame[is.na(importance_frame$times_a_root), "times_a_root"] <- 0
230+
}
231+
if("p_value" %in% measures){
232+
importance_frame$p_value <- measure_p_value(importance_frame)
233+
importance_frame$variable <- as.factor(importance_frame$variable)
234+
}
235+
return(importance_frame)
236+
}
237+
159238
#' Extract k most important variables in a random forest
160239
#'
161240
#' Get the names of k variables with highest sum of rankings based on the specified importance measures
@@ -174,13 +253,16 @@ measure_importance <- function(forest, mean_sample = "top_trees", measures = NUL
174253
#' important_variables(measure_importance(forest), k = 2)
175254
#'
176255
#' @export
177-
important_variables <- function(importance_frame, k = 15, measures = names(importance_frame)[2:5],
256+
important_variables <- function(importance_frame, k = 15,
257+
measures = names(importance_frame)[2:min(5, ncol(importance_frame))],
178258
ties_action = "all"){
179259
if("randomForest" %in% class(importance_frame)){
180260
importance_frame <- measure_importance(importance_frame)
181261
if("predicted" %in% measures){
182262
measures <- names(importance_frame)[2:5]
183263
}
264+
} else if ("ranger" %in% class(importance_frame)){
265+
importance_frame <- measure_importance(importance_frame)
184266
}
185267
rankings <- data.frame(variable = importance_frame$variable, mean_min_depth =
186268
frankv(importance_frame$mean_min_depth, ties.method = "dense"),
@@ -232,7 +314,7 @@ plot_multi_way_importance <- function(importance_frame, x_measure = "mean_min_de
232314
min_no_of_trees = 0, no_of_labels = 10,
233315
main = "Multi-way importance plot"){
234316
variable <- NULL
235-
if("randomForest" %in% class(importance_frame)){
317+
if(any(c("randomForest", "ranger") %in% class(importance_frame))){
236318
importance_frame <- measure_importance(importance_frame)
237319
}
238320
data <- importance_frame[importance_frame$no_of_trees > min_no_of_trees, ]
@@ -294,14 +376,16 @@ plot_multi_way_importance <- function(importance_frame, x_measure = "mean_min_de
294376
#' plot_importance_ggpairs(frame, measures = c("mean_min_depth", "times_a_root"))
295377
#'
296378
#' @export
297-
plot_importance_ggpairs <- function(importance_frame, measures =
298-
names(importance_frame)[c(2, 4, 5, 3, 7)],
379+
plot_importance_ggpairs <- function(importance_frame, measures = NULL,
299380
main = "Relations between measures of importance"){
300-
if("randomForest" %in% class(importance_frame)){
381+
if(any(c("randomForest", "ranger") %in% class(importance_frame))){
301382
importance_frame <- measure_importance(importance_frame)
302-
if("predicted" %in% measures){
303-
names(importance_frame)[c(2, 4, 5, 3, 7)]
304-
}
383+
}
384+
if (is.null(measures)){
385+
default_measures <- c("gini_decrease", "node_purity_increase", # randomForest
386+
"impurity", "impurity_corrected", "permutation", # ranger
387+
"mean_min_depth", "no_of_trees", "no_of_nodes", "p_value")
388+
measures <- intersect(default_measures, colnames(importance_frame))
305389
}
306390
plot <- ggpairs(importance_frame[, measures]) + theme_bw()
307391
if(!is.null(main)){
@@ -315,7 +399,7 @@ plot_importance_ggpairs <- function(importance_frame, measures =
315399
#' Plot against each other rankings of variables according to various measures of importance
316400
#'
317401
#' @param importance_frame A result of using the function measure_importance() to a random forest or a randomForest object
318-
#' @param measures A character vector specifying the measures of importance to be used
402+
#' @param measures A character vector specifying the measures of importance to be used.
319403
#' @param main A string to be used as title of the plot
320404
#'
321405
#' @return A ggplot object
@@ -329,22 +413,29 @@ plot_importance_ggpairs <- function(importance_frame, measures =
329413
#' plot_importance_ggpairs(frame, measures = c("mean_min_depth", "times_a_root"))
330414
#'
331415
#' @export
332-
plot_importance_rankings <- function(importance_frame, measures =
333-
names(importance_frame)[c(2, 4, 5, 3, 7)],
416+
plot_importance_rankings <- function(importance_frame, measures = NULL,
334417
main = "Relations between rankings according to different measures"){
335-
if("randomForest" %in% class(importance_frame)){
418+
if(any(c("randomForest", "ranger") %in% class(importance_frame))){
336419
importance_frame <- measure_importance(importance_frame)
337-
if("predicted" %in% measures){
338-
names(importance_frame)[c(2, 4, 5, 3, 7)]
339-
}
340420
}
341-
rankings <- data.frame(variable = importance_frame$variable, mean_min_depth =
342-
frankv(importance_frame$mean_min_depth, ties.method = "dense"),
343-
apply(importance_frame[, -c(1, 2)], 2,
421+
if (is.null(measures)){
422+
default_measures <- c("gini_decrease", "node_purity_increase", # randomForest
423+
"impurity", "impurity_corrected", "permutation", # ranger
424+
"mean_min_depth", "no_of_trees", "no_of_nodes", "p_value")
425+
measures <- intersect(default_measures, colnames(importance_frame))
426+
}
427+
rankings <- data.frame(variable = importance_frame$variable,
428+
apply(importance_frame[, !colnames(importance_frame) %in% c("variable", "mean_min_depth", "p_value")], 2,
344429
function(x) frankv(x, order = -1, ties.method = "dense")))
430+
if ("mean_min_depth" %in% measures){
431+
rankings$mean_min_depth = frankv(importance_frame$mean_min_depth, ties.method = "dense")
432+
}
433+
if ("p_value" %in% measures){
434+
rankings$p_value = frankv(importance_frame$p_value, ties.method = "dense")
435+
}
345436
plot <- ggpairs(rankings[, measures], lower = list(continuous = function(data, mapping){
346437
ggplot(data = data, mapping = mapping) + geom_point() + geom_smooth(method = "loess")
347-
}))+ theme_bw()
438+
})) + theme_bw()
348439
if(!is.null(main)){
349440
plot <- plot + ggtitle(main)
350441
}

R/min_depth_distribution.R

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,43 @@ calculate_tree_depth <- function(frame){
1414
return(frame)
1515
}
1616

17+
# Calculate the depth of each node in a single tree obtained from a forest with ranger::treeInfo
18+
calculate_tree_depth_ranger <- function(frame){
19+
if(!all(c("rightChild", "leftChild") %in% names(frame))){
20+
stop("The data frame has to contain columns called 'rightChild' and 'leftChild'!
21+
It should be a product of the function ranger::treeInfo().")
22+
}
23+
frame$depth <- NA
24+
frame$depth[1] <- 0
25+
for(i in 2:nrow(frame)){
26+
frame[i, "depth"] <-
27+
frame[(!is.na(frame[, "leftChild"]) & frame[, "leftChild"] == frame[i, "nodeID"]) |
28+
(!is.na(frame[, "rightChild"]) & frame[, "rightChild"] == frame[i, "nodeID"]), "depth"] + 1
29+
}
30+
return(frame)
31+
}
32+
1733
#' Calculate minimal depth distribution of a random forest
1834
#'
1935
#' Get minimal depth values for all trees in a random forest
2036
#'
21-
#' @param forest A randomForest object
37+
#' @param forest A randomForest or ranger object
2238
#'
2339
#' @return A data frame with the value of minimal depth for every variable in every tree
2440
#'
25-
#' @import dplyr
26-
#' @importFrom data.table rbindlist
27-
#'
2841
#' @examples
2942
#' min_depth_distribution(randomForest::randomForest(Species ~ ., data = iris))
43+
#' min_depth_distribution(ranger::ranger(Species ~ ., data = iris))
3044
#'
3145
#' @export
3246
min_depth_distribution <- function(forest){
47+
UseMethod("min_depth_distribution")
48+
}
49+
50+
#' @import dplyr
51+
#' @importFrom data.table rbindlist
52+
#' @export
53+
min_depth_distribution.randomForest <- function(forest){
3354
tree <- NULL; `split var` <- NULL; depth <- NULL
3455
forest_table <-
3556
lapply(1:forest$ntree, function(i) randomForest::getTree(forest, k = i, labelVar = T) %>%
@@ -41,6 +62,21 @@ min_depth_distribution <- function(forest){
4162
return(min_depth_frame)
4263
}
4364

65+
#' @import dplyr
66+
#' @importFrom data.table rbindlist
67+
#' @export
68+
min_depth_distribution.ranger <- function(forest){
69+
tree <- NULL; splitvarName <- NULL; depth <- NULL
70+
forest_table <-
71+
lapply(1:forest$num.trees, function(i) ranger::treeInfo(forest, tree = i) %>%
72+
calculate_tree_depth_ranger() %>% cbind(tree = i)) %>% rbindlist()
73+
min_depth_frame <- dplyr::group_by(forest_table, tree, splitvarName) %>%
74+
dplyr::summarize(min(depth))
75+
colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
76+
min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable),])
77+
return(min_depth_frame)
78+
}
79+
4480
# Count the trees in which each variable had a given minimal depth
4581
min_depth_count <- function(min_depth_frame){
4682
tree <- NULL; minimal_depth <- NULL; variable <- NULL

0 commit comments

Comments
 (0)