@@ -14,22 +14,43 @@ calculate_tree_depth <- function(frame){
1414 return (frame )
1515}
1616
# Calculate the depth of each node in a single tree obtained from a forest
# with ranger::treeInfo().
#
# frame: a data frame with columns 'nodeID', 'leftChild' and 'rightChild';
#   the first row is treated as the root of the tree.
# Returns `frame` with an added numeric column 'depth' (0 for the first row,
#   parent's depth + 1 for every other row).
#
# NOTE(review): the loop visits rows in order, so a parent row must appear
# before its children for the parent's depth to be available -- presumably
# guaranteed by the row order of ranger::treeInfo(); confirm.
calculate_tree_depth_ranger <- function(frame) {
  if (!all(c("rightChild", "leftChild") %in% names(frame))) {
    stop("The data frame has to contain columns called 'rightChild' and 'leftChild'!
         It should be a product of the function ranger::treeInfo().")
  }
  n_nodes <- nrow(frame)
  # Pull the columns out once instead of re-indexing the data frame on
  # every iteration.
  node_id <- frame[["nodeID"]]
  left_child <- frame[["leftChild"]]
  right_child <- frame[["rightChild"]]
  has_left <- !is.na(left_child)
  has_right <- !is.na(right_child)

  depth <- rep(NA_real_, n_nodes)
  if (n_nodes > 0) {
    depth[1] <- 0
  }
  # seq_len(n)[-1] is empty for a tree that consists of the root only,
  # whereas 2:nrow(frame) would iterate over c(2, 1) and fail.
  for (i in seq_len(n_nodes)[-1]) {
    # The parent of node i is the row listing it as its left or right child.
    is_parent <- (has_left & left_child == node_id[i]) |
      (has_right & right_child == node_id[i])
    depth[i] <- depth[is_parent] + 1
  }
  frame$depth <- depth
  frame
}
32+
1733# ' Calculate minimal depth distribution of a random forest
1834# '
1935# ' Get minimal depth values for all trees in a random forest
2036# '
21- # ' @param forest A randomForest object
37+ # ' @param forest A randomForest or ranger object
2238# '
2339# ' @return A data frame with the value of minimal depth for every variable in every tree
2440# '
25- # ' @import dplyr
26- # ' @importFrom data.table rbindlist
27- # '
2841# ' @examples
2942# ' min_depth_distribution(randomForest::randomForest(Species ~ ., data = iris))
43+ # ' min_depth_distribution(ranger::ranger(Species ~ ., data = iris))
3044# '
3145# ' @export
min_depth_distribution <- function(forest) {
  # S3 generic: the implementation is chosen by the class of `forest`.
  UseMethod("min_depth_distribution", forest)
}
49+
50+ # ' @import dplyr
51+ # ' @importFrom data.table rbindlist
52+ # ' @export
53+ min_depth_distribution.randomForest <- function (forest ){
3354 tree <- NULL ; `split var` <- NULL ; depth <- NULL
3455 forest_table <-
3556 lapply(1 : forest $ ntree , function (i ) randomForest :: getTree(forest , k = i , labelVar = T ) %> %
@@ -41,6 +62,21 @@ min_depth_distribution <- function(forest){
4162 return (min_depth_frame )
4263}
4364
# Method of min_depth_distribution() for ranger forests.
#
# For every tree index in seq_len(forest$num.trees) the node table from
# ranger::treeInfo() is annotated with node depths and a `tree` column, the
# tables are stacked, and the smallest depth per (tree, splitvarName) pair is
# returned as a data frame with columns 'tree', 'variable', 'minimal_depth'.
#' @import dplyr
#' @importFrom data.table rbindlist
#' @export
min_depth_distribution.ranger <- function(forest) {
  # Local bindings for the column names used unquoted in the dplyr calls
  # below (presumably to avoid "no visible binding" notes -- confirm).
  tree <- NULL; splitvarName <- NULL; depth <- NULL
  # seq_len() instead of 1:n so that a zero count yields an empty sequence.
  forest_table <-
    lapply(seq_len(forest$num.trees), function(i) {
      ranger::treeInfo(forest, tree = i) %>%
        calculate_tree_depth_ranger() %>%
        cbind(tree = i)
    }) %>%
    rbindlist()
  min_depth_frame <- dplyr::group_by(forest_table, tree, splitvarName) %>%
    dplyr::summarize(min(depth))
  colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
  # Drop rows without a split variable (presumably terminal nodes -- verify
  # against ranger::treeInfo()).
  min_depth_frame <-
    as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable), ])
  min_depth_frame
}
79+
4480# Count the trees in which each variable had a given minimal depth
4581min_depth_count <- function (min_depth_frame ){
4682 tree <- NULL ; minimal_depth <- NULL ; variable <- NULL
0 commit comments