Skip to content

Commit fb35f7f

Browse files
committed
make plot_min_depth_distribution work with ranger forests
1 parent eb0644a commit fb35f7f

File tree

4 files changed

+46
-6
lines changed

4 files changed

+46
-6
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@ Imports:
2020
ggrepel (>= 0.6.5),
2121
MASS (>= 7.3.47),
2222
randomForest (>= 4.6.12),
23+
ranger(>= 0.9.0),
2324
reshape2 (>= 1.4.2),
2425
rmarkdown (>= 1.5)
2526
Suggests:
2627
knitr
2728
VignetteBuilder: knitr
28-
RoxygenNote: 6.0.1
29+
RoxygenNote: 6.1.1
2930
URL: https://github.com/MI2DataLab/randomForestExplainer

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(min_depth_distribution,randomForest)
4+
S3method(min_depth_distribution,ranger)
35
export(explain_forest)
46
export(important_variables)
57
export(measure_importance)

R/min_depth_distribution.R

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,43 @@ calculate_tree_depth <- function(frame){
1414
return(frame)
1515
}
1616

17+
# Calculate the depth of each node in a single tree obtained from a forest with ranger::treeInfo
18+
calculate_tree_depth_ranger <- function(frame){
19+
if(!all(c("rightChild", "leftChild") %in% names(frame))){
20+
stop("The data frame has to contain columns called 'rightChild' and 'leftChild'!
21+
It should be a product of the function ranger::treeInfo().")
22+
}
23+
frame$depth <- NA
24+
frame$depth[1] <- 0
25+
for(i in 2:nrow(frame)){
26+
frame[i, "depth"] <-
27+
frame[(!is.na(frame[, "leftChild"]) & frame[, "leftChild"] == frame[i, "nodeID"]) |
28+
(!is.na(frame[, "rightChild"]) & frame[, "rightChild"] == frame[i, "nodeID"]), "depth"] + 1
29+
}
30+
return(frame)
31+
}
32+
1733
#' Calculate minimal depth distribution of a random forest
1834
#'
1935
#' Get minimal depth values for all trees in a random forest
2036
#'
21-
#' @param forest A randomForest object
37+
#' @param forest A randomForest or ranger object
2238
#'
2339
#' @return A data frame with the value of minimal depth for every variable in every tree
2440
#'
25-
#' @import dplyr
26-
#' @importFrom data.table rbindlist
27-
#'
2841
#' @examples
2942
#' min_depth_distribution(randomForest::randomForest(Species ~ ., data = iris))
43+
#' min_depth_distribution(ranger::ranger(Species ~ ., data = iris))
3044
#'
3145
#' @export
3246
min_depth_distribution <- function(forest){
47+
UseMethod("min_depth_distribution")
48+
}
49+
50+
#' @import dplyr
51+
#' @importFrom data.table rbindlist
52+
#' @export
53+
min_depth_distribution.randomForest <- function(forest){
3354
tree <- NULL; `split var` <- NULL; depth <- NULL
3455
forest_table <-
3556
lapply(1:forest$ntree, function(i) randomForest::getTree(forest, k = i, labelVar = T) %>%
@@ -41,6 +62,21 @@ min_depth_distribution <- function(forest){
4162
return(min_depth_frame)
4263
}
4364

65+
#' @import dplyr
66+
#' @importFrom data.table rbindlist
67+
#' @export
68+
min_depth_distribution.ranger <- function(forest){
69+
tree <- NULL; splitvarName <- NULL; depth <- NULL
70+
forest_table <-
71+
lapply(1:forest$num.trees, function(i) ranger::treeInfo(forest, tree = i) %>%
72+
calculate_tree_depth_ranger() %>% cbind(tree = i)) %>% rbindlist()
73+
min_depth_frame <- dplyr::group_by(forest_table, tree, splitvarName) %>%
74+
dplyr::summarize(min(depth))
75+
colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
76+
min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable),])
77+
return(min_depth_frame)
78+
}
79+
4480
# Count the trees in which each variable had a given minimal depth
4581
min_depth_count <- function(min_depth_frame){
4682
tree <- NULL; minimal_depth <- NULL; variable <- NULL

man/min_depth_distribution.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)