Skip to content

Commit 10ab14d

Browse files
committed
first pass adding ranger support for multi way importance, untested
1 parent 3fbce5a commit 10ab14d

9 files changed

+122
-31
lines changed

NAMESPACE

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(measure_importance,randomForest)
4+
S3method(measure_importance,ranger)
35
S3method(min_depth_distribution,randomForest)
46
S3method(min_depth_distribution,ranger)
7+
S3method(min_depth_interactions,randomForest)
8+
S3method(min_depth_interactions,ranger)
9+
S3method(plot_predict_interaction,randomForest)
10+
S3method(plot_predict_interaction,ranger)
511
export(explain_forest)
612
export(important_variables)
713
export(measure_importance)

R/measure_importance.R

Lines changed: 87 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ measure_min_depth <- function(min_depth_frame, mean_sample){
77
}
88

99
# Calculate the number of nodes split on each variable for a data frame with the whole forest
10+
# randomForest
1011
measure_no_of_nodes <- function(forest_table){
1112
`split var` <- NULL
1213
frame <- dplyr::group_by(forest_table, `split var`) %>% dplyr::summarize(n())
@@ -16,7 +17,19 @@ measure_no_of_nodes <- function(forest_table){
1617
return(frame)
1718
}
1819

20+
# Calculate the number of nodes split on each variable for a data frame with the whole forest
21+
# ranger
22+
measure_no_of_nodes_ranger <- function(forest_table){
23+
splitvarName <- NULL
24+
frame <- dplyr::group_by(forest_table, splitvarName) %>% dplyr::summarize(n())
25+
colnames(frame) <- c("variable", "no_of_nodes")
26+
frame <- as.data.frame(frame[!is.na(frame$variable),])
27+
frame$variable <- as.character(frame$variable)
28+
return(frame)
29+
}
30+
1931
# Extract randomForest variable importance measures
32+
# randomForest
2033
measure_vimp <- function(forest, only_nonlocal = FALSE){
2134
if(forest$type == "classification"){
2235
if(dim(forest$importance)[2] == 1){
@@ -44,6 +57,20 @@ measure_vimp <- function(forest, only_nonlocal = FALSE){
4457
return(frame)
4558
}
4659

60+
# Extract ranger variable importance measures
61+
# ranger
62+
measure_vimp_ranger <- function(forest){
63+
if (forest$importance.mode == "none") {
64+
stop("No variable importance available, regenerate forest by ranger(..., importance='impurity').")
65+
}
66+
frame <- data.frame(importance=forest$variable.importance,
67+
variable=names(forest$variable.importance),
68+
stringsAsFactors = FALSE)
69+
colnames(frame)[1] <- forest$importance.mode
70+
# possible values are: 'impurity', 'impurity_corrected', 'permutation'.
71+
return(frame)
72+
}
73+
4774
# Calculate the number of trees using each variable for splitting
4875
measure_no_of_trees <- function(min_depth_frame){
4976
variable <- NULL
@@ -87,15 +114,19 @@ measure_p_value <- function(importance_frame){
87114
#'
88115
#' @return A data frame with rows corresponding to variables and columns to various measures of importance of variables
89116
#'
90-
#' @import dplyr
91-
#' @importFrom data.table rbindlist
92-
#'
93117
#' @examples
94118
#' forest <- randomForest::randomForest(Species ~ ., data = iris, localImp = TRUE, ntree = 300)
95119
#' measure_importance(forest)
96120
#'
97121
#' @export
98122
measure_importance <- function(forest, mean_sample = "top_trees", measures = NULL){
123+
UseMethod("measure_importance")
124+
}
125+
126+
#' @import dplyr
127+
#' @importFrom data.table rbindlist
128+
#' @export
129+
measure_importance.randomForest <- function(forest, mean_sample = "top_trees", measures = NULL){
99130
tree <- NULL; `split var` <- NULL; depth <- NULL
100131
if(is.null(measures)){
101132
if(forest$type == "classification"){
@@ -156,6 +187,54 @@ measure_importance <- function(forest, mean_sample = "top_trees", measures = NUL
156187
return(importance_frame)
157188
}
158189

190+
#' @import dplyr
191+
#' @importFrom data.table rbindlist
192+
#' @export
193+
measure_importance.ranger <- function(forest, mean_sample = "top_trees", measures = NULL){
194+
tree <- NULL; splitvarName <- NULL; depth <- NULL
195+
if(is.null(measures)){
196+
measures <- c("mean_min_depth", "no_of_nodes", forest$importance.mode, "no_of_trees", "times_a_root", "p_value")
197+
}
198+
if(("p_value" %in% measures) && !("no_of_nodes" %in% measures)){
199+
measures <- c(measures, "no_of_nodes")
200+
}
201+
importance_frame <- data.frame(variable = names(forest$variable.importance), stringsAsFactors = FALSE)
202+
# Get objects necessary to calculate importance measures based on the tree structure
203+
if(any(c("mean_min_depth", "no_of_nodes", "no_of_trees", "times_a_root", "p_value") %in% measures)){
204+
forest_table <-
205+
lapply(1:forest$num.trees, function(i) ranger::treeInfo(forest, tree = i) %>%
206+
calculate_tree_depth_ranger() %>% cbind(tree = i)) %>% rbindlist()
207+
min_depth_frame <- dplyr::group_by(forest_table, tree, splitvarName) %>%
208+
dplyr::summarize(min(depth))
209+
colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
210+
min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable),])
211+
}
212+
# Add each importance measure to the table (if it was requested)
213+
if("mean_min_depth" %in% measures){
214+
importance_frame <- merge(importance_frame, measure_min_depth(min_depth_frame, mean_sample), all = TRUE)
215+
}
216+
if("no_of_nodes" %in% measures){
217+
importance_frame <- merge(importance_frame, measure_no_of_nodes_ranger(forest_table), all = TRUE)
218+
importance_frame[is.na(importance_frame$no_of_nodes), "no_of_nodes"] <- 0
219+
}
220+
if(forest$importance.mode %in% measures){
221+
importance_frame <- merge(importance_frame, measure_vimp_ranger(forest), all = TRUE)
222+
}
223+
if("no_of_trees" %in% measures){
224+
importance_frame <- merge(importance_frame, measure_no_of_trees(min_depth_frame), all = TRUE)
225+
importance_frame[is.na(importance_frame$no_of_trees), "no_of_trees"] <- 0
226+
}
227+
if("times_a_root" %in% measures){
228+
importance_frame <- merge(importance_frame, measure_times_a_root(min_depth_frame), all = TRUE)
229+
importance_frame[is.na(importance_frame$times_a_root), "times_a_root"] <- 0
230+
}
231+
if("p_value" %in% measures){
232+
importance_frame$p_value <- measure_p_value(importance_frame)
233+
importance_frame$variable <- as.factor(importance_frame$variable)
234+
}
235+
return(importance_frame)
236+
}
237+
159238
#' Extract k most important variables in a random forest
160239
#'
161240
#' Get the names of k variables with highest sum of rankings based on the specified importance measures
@@ -174,13 +253,16 @@ measure_importance <- function(forest, mean_sample = "top_trees", measures = NUL
174253
#' important_variables(measure_importance(forest), k = 2)
175254
#'
176255
#' @export
177-
important_variables <- function(importance_frame, k = 15, measures = names(importance_frame)[2:5],
256+
important_variables <- function(importance_frame, k = 15,
257+
measures = names(importance_frame)[2:min(5, ncol(importance_frame))],
178258
ties_action = "all"){
179259
if("randomForest" %in% class(importance_frame)){
180260
importance_frame <- measure_importance(importance_frame)
181261
if("predicted" %in% measures){
182262
measures <- names(importance_frame)[2:5]
183263
}
264+
} else if ("ranger" %in% class(importance_frame)) {
265+
importance_frame <- measure_importance(importance_frame)
184266
}
185267
rankings <- data.frame(variable = importance_frame$variable, mean_min_depth =
186268
frankv(importance_frame$mean_min_depth, ties.method = "dense"),
@@ -232,7 +314,7 @@ plot_multi_way_importance <- function(importance_frame, x_measure = "mean_min_de
232314
min_no_of_trees = 0, no_of_labels = 10,
233315
main = "Multi-way importance plot"){
234316
variable <- NULL
235-
if("randomForest" %in% class(importance_frame)){
317+
if(any(c("randomForest", "ranger") %in% class(importance_frame))){
236318
importance_frame <- measure_importance(importance_frame)
237319
}
238320
data <- importance_frame[importance_frame$no_of_trees > min_no_of_trees, ]

R/min_depth_interactions.R

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ min_depth_interactions_values <- function(forest, vars){
6161
interactions_frame <-
6262
lapply(1:forest$ntree, function(i) randomForest::getTree(forest, k = i, labelVar = T) %>%
6363
calculate_tree_depth() %>% cbind(., tree = i, number = 1:nrow(.))) %>%
64-
data.table::rbindlist() %>% as.data.frame()
64+
data.table::rbindlist() %>% as.data.frame()
6565
interactions_frame[vars] <- as.numeric(NA)
6666
interactions_frame <-
6767
data.table::as.data.table(interactions_frame)[, conditional_depth(as.data.frame(.SD), vars), by = tree] %>% as.data.frame()
@@ -128,7 +128,7 @@ min_depth_interactions <- function(){
128128
#' @importFrom data.table rbindlist
129129
#' @export
130130
min_depth_interactions.randomForest <- function(forest, vars = important_variables(measure_importance(forest)),
131-
mean_sample = "top_trees", uncond_mean_sample = mean_sample){
131+
mean_sample = "top_trees", uncond_mean_sample = mean_sample){
132132
variable <- NULL; `.` <- NULL; tree <- NULL; `split var` <- NULL; depth <- NULL
133133
min_depth_interactions_frame <- min_depth_interactions_values(forest, vars)
134134
mean_tree_depth <- min_depth_interactions_frame[[2]]
@@ -145,7 +145,7 @@ min_depth_interactions.randomForest <- function(forest, vars = important_variabl
145145
non_occurrences[, -1] <- forest$ntree - occurrences[, -1]
146146
interactions_frame[is.na(as.matrix(interactions_frame))] <- 0
147147
interactions_frame[, -1] <- (interactions_frame[, -1] * occurrences[, -1] +
148-
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth))/forest$ntree
148+
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth))/forest$ntree
149149
} else if(mean_sample == "top_trees"){
150150
non_occurrences <- occurrences
151151
non_occurrences[, -1] <- forest$ntree - occurrences[, -1]
@@ -177,7 +177,7 @@ min_depth_interactions.randomForest <- function(forest, vars = important_variabl
177177
#' @importFrom data.table rbindlist
178178
#' @export
179179
min_depth_interactions.ranger <- function(forest, vars = important_variables(measure_importance(forest)),
180-
mean_sample = "top_trees", uncond_mean_sample = mean_sample){
180+
mean_sample = "top_trees", uncond_mean_sample = mean_sample){
181181
variable <- NULL; `.` <- NULL; tree <- NULL; splitvarName <- NULL; depth <- NULL
182182
min_depth_interactions_frame <- min_depth_interactions_values_ranger(forest, vars)
183183
mean_tree_depth <- min_depth_interactions_frame[[2]]
@@ -217,7 +217,7 @@ min_depth_interactions.ranger <- function(forest, vars = important_variables(mea
217217
dplyr::summarize(min(depth))
218218
colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
219219
min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable),])
220-
importance_frame <- get_min_depth_means_ranger(min_depth_frame, min_depth_count_ranger(min_depth_frame), uncond_mean_sample)
220+
importance_frame <- get_min_depth_means_ranger(min_depth_frame, min_depth_count(min_depth_frame), uncond_mean_sample)
221221
colnames(importance_frame)[2] <- "uncond_mean_min_depth"
222222
interactions_frame <- merge(interactions_frame, importance_frame)
223223
}
@@ -254,7 +254,7 @@ plot_min_depth_interactions <- function(interactions_frame, k = 30,
254254
aes(x = interaction, y = mean_min_depth, fill = occurrences)) +
255255
geom_bar(stat = "identity") +
256256
geom_pointrange(aes(ymin = pmin(mean_min_depth, uncond_mean_min_depth), y = uncond_mean_min_depth,
257-
ymax = pmax(mean_min_depth, uncond_mean_min_depth), shape = "unconditional"), fatten = 2, size = 1) +
257+
ymax = pmax(mean_min_depth, uncond_mean_min_depth), shape = "unconditional"), fatten = 2, size = 1) +
258258
geom_hline(aes(yintercept = minimum, linetype = "minimum"), color = "red", size = 1.5) +
259259
scale_linetype_manual(name = NULL, values = 1) + theme_bw() +
260260
scale_shape_manual(name = NULL, values = 19) +
@@ -267,7 +267,7 @@ plot_min_depth_interactions <- function(interactions_frame, k = 30,
267267

268268
#' Plot the prediction of the forest for a grid of values of two numerical variables
269269
#'
270-
#' @param forest A randomForest object
270+
#' @param forest A randomForest or ranger object
271271
#' @param data The data frame on which forest was trained
272272
#' @param variable1 A character string with the name of a numerical predictor that will be on the X-axis
273273
#' @param variable2 A character string with the name of a numerical predictor that will be on the Y-axis
@@ -295,8 +295,8 @@ plot_predict_interaction <- function(forest, data, variable1, variable2, grid =
295295
#' @importFrom stats as.formula
296296
#' @export
297297
plot_predict_interaction.randomForest <- function(forest, data, variable1, variable2, grid = 100,
298-
main = paste0("Prediction of the forest for different values of ",
299-
paste0(variable1, paste0(" and ", variable2)))){
298+
main = paste0("Prediction of the forest for different values of ",
299+
paste0(variable1, paste0(" and ", variable2)))){
300300
newdata <- expand.grid(seq(min(data[[variable1]]), max(data[[variable1]]), length.out = grid),
301301
seq(min(data[[variable2]]), max(data[[variable2]]), length.out = grid))
302302
colnames(newdata) <- c(variable1, variable2)
@@ -342,8 +342,8 @@ plot_predict_interaction.randomForest <- function(forest, data, variable1, varia
342342
#' @importFrom stats as.formula
343343
#' @export
344344
plot_predict_interaction.ranger <- function(forest, data, variable1, variable2, grid = 100,
345-
main = paste0("Prediction of the forest for different values of ",
346-
paste0(variable1, paste0(" and ", variable2)))){
345+
main = paste0("Prediction of the forest for different values of ",
346+
paste0(variable1, paste0(" and ", variable2)))){
347347
newdata <- expand.grid(seq(min(data[[variable1]]), max(data[[variable1]]), length.out = grid),
348348
seq(min(data[[variable2]]), max(data[[variable2]]), length.out = grid))
349349
colnames(newdata) <- c(variable1, variable2)

man/explain_forest.Rd

Lines changed: 6 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/important_variables.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/min_depth_interactions.Rd

Lines changed: 1 addition & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/plot_min_depth_distribution.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/plot_multi_way_importance.Rd

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/plot_predict_interaction.Rd

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)