@@ -7,6 +7,7 @@ measure_min_depth <- function(min_depth_frame, mean_sample){
77}
88
99# Calculate the number of nodes split on each variable for a data frame with the whole forest
10+ # randomForest
1011measure_no_of_nodes <- function (forest_table ){
1112 `split var` <- NULL
1213 frame <- dplyr :: group_by(forest_table , `split var` ) %> % dplyr :: summarize(n())
@@ -16,7 +17,19 @@ measure_no_of_nodes <- function(forest_table){
1617 return (frame )
1718}
1819
20+ # Calculate the number of nodes split on each variable for a data frame with the whole forest
21+ # randomForest
22+ measure_no_of_nodes_ranger <- function (forest_table ){
23+ splitvarName <- NULL
24+ frame <- dplyr :: group_by(forest_table , splitvarName ) %> % dplyr :: summarize(n())
25+ colnames(frame ) <- c(" variable" , " no_of_nodes" )
26+ frame <- as.data.frame(frame [! is.na(frame $ variable ),])
27+ frame $ variable <- as.character(frame $ variable )
28+ return (frame )
29+ }
30+
1931# Extract randomForest variable importance measures
32+ # randomForest
2033measure_vimp <- function (forest , only_nonlocal = FALSE ){
2134 if (forest $ type == " classification" ){
2235 if (dim(forest $ importance )[2 ] == 1 ){
@@ -44,6 +57,20 @@ measure_vimp <- function(forest, only_nonlocal = FALSE){
4457 return (frame )
4558}
4659
60+ # Extract randomForest variable importance measures
61+ # ranger
62+ measure_vimp_ranger <- function (forest ){
63+ if (forest $ importance.mode == " none" ){
64+ stop(" No variable importance available, regenerate forest by ranger(..., importance='impurity')." )
65+ }
66+ frame <- data.frame (importance = forest $ variable.importance ,
67+ variable = names(forest $ variable.importance ),
68+ stringsAsFactors = FALSE )
69+ colnames(frame )[1 ] <- forest $ importance.mode
70+ # possible values are: impurity, 'impurity_corrected', 'permutation'.
71+ return (frame )
72+ }
73+
4774# Calculate the number of trees using each variable for splitting
4875measure_no_of_trees <- function (min_depth_frame ){
4976 variable <- NULL
@@ -68,8 +95,8 @@ measure_times_a_root <- function(min_depth_frame){
6895measure_p_value <- function (importance_frame ){
6996 total_no_of_nodes <- sum(importance_frame $ no_of_nodes )
7097 p_value <- unlist(lapply(importance_frame $ no_of_nodes ,
71- function (x ) stats :: binom.test(x , total_no_of_nodes , 1 / nrow(importance_frame ),
72- alternative = " greater" )$ p.value ))
98+ function (x ) stats :: binom.test(x , total_no_of_nodes , 1 / nrow(importance_frame ),
99+ alternative = " greater" )$ p.value ))
73100 return (p_value )
74101}
75102
@@ -87,21 +114,25 @@ measure_p_value <- function(importance_frame){
87114# '
88115# ' @return A data frame with rows corresponding to variables and columns to various measures of importance of variables
89116# '
90- # ' @import dplyr
91- # ' @importFrom data.table rbindlist
92- # '
93117# ' @examples
94118# ' forest <- randomForest::randomForest(Species ~ ., data = iris, localImp = TRUE, ntree = 300)
95119# ' measure_importance(forest)
96120# '
97121# ' @export
98122measure_importance <- function (forest , mean_sample = " top_trees" , measures = NULL ){
123+ UseMethod(" measure_importance" )
124+ }
125+
126+ # ' @import dplyr
127+ # ' @importFrom data.table rbindlist
128+ # ' @export
129+ measure_importance.randomForest <- function (forest , mean_sample = " top_trees" , measures = NULL ){
99130 tree <- NULL ; `split var` <- NULL ; depth <- NULL
100131 if (is.null(measures )){
101132 if (forest $ type == " classification" ){
102133 measures <- c(" mean_min_depth" , " no_of_nodes" , " accuracy_decrease" ,
103134 " gini_decrease" , " no_of_trees" , " times_a_root" , " p_value" )
104- } else if (forest $ type == " regression" ) {
135+ } else if (forest $ type == " regression" ){
105136 measures <- c(" mean_min_depth" , " no_of_nodes" , " mse_increase" , " node_purity_increase" ,
106137 " no_of_trees" , " times_a_root" , " p_value" )
107138 }
@@ -130,7 +161,7 @@ measure_importance <- function(forest, mean_sample = "top_trees", measures = NUL
130161 }
131162 if (forest $ type == " classification" ){
132163 vimp <- c(" accuracy_decrease" , " gini_decrease" )
133- } else if (forest $ type == " regression" ) {
164+ } else if (forest $ type == " regression" ){
134165 vimp <- c(" mse_increase" , " node_purity_increase" )
135166 }
136167 if (all(vimp %in% measures )){
@@ -156,6 +187,54 @@ measure_importance <- function(forest, mean_sample = "top_trees", measures = NUL
156187 return (importance_frame )
157188}
158189
190+ # ' @import dplyr
191+ # ' @importFrom data.table rbindlist
192+ # ' @export
193+ measure_importance.ranger <- function (forest , mean_sample = " top_trees" , measures = NULL ){
194+ tree <- NULL ; splitvarName <- NULL ; depth <- NULL
195+ if (is.null(measures )){
196+ measures <- c(" mean_min_depth" , " no_of_nodes" , forest $ importance.mode , " no_of_trees" , " times_a_root" , " p_value" )
197+ }
198+ if ((" p_value" %in% measures ) && ! (" no_of_nodes" %in% measures )){
199+ measures <- c(measures , " no_of_nodes" )
200+ }
201+ importance_frame <- data.frame (variable = names(forest $ variable.importance ), stringsAsFactors = FALSE )
202+ # Get objects necessary to calculate importance measures based on the tree structure
203+ if (any(c(" mean_min_depth" , " no_of_nodes" , " no_of_trees" , " times_a_root" , " p_value" ) %in% measures )){
204+ forest_table <-
205+ lapply(1 : forest $ num.trees , function (i ) ranger :: treeInfo(forest , tree = i ) %> %
206+ calculate_tree_depth_ranger() %> % cbind(tree = i )) %> % rbindlist()
207+ min_depth_frame <- dplyr :: group_by(forest_table , tree , splitvarName ) %> %
208+ dplyr :: summarize(min(depth ))
209+ colnames(min_depth_frame ) <- c(" tree" , " variable" , " minimal_depth" )
210+ min_depth_frame <- as.data.frame(min_depth_frame [! is.na(min_depth_frame $ variable ),])
211+ }
212+ # Add each importance measure to the table (if it was requested)
213+ if (" mean_min_depth" %in% measures ){
214+ importance_frame <- merge(importance_frame , measure_min_depth(min_depth_frame , mean_sample ), all = TRUE )
215+ }
216+ if (" no_of_nodes" %in% measures ){
217+ importance_frame <- merge(importance_frame , measure_no_of_nodes_ranger(forest_table ), all = TRUE )
218+ importance_frame [is.na(importance_frame $ no_of_nodes ), " no_of_nodes" ] <- 0
219+ }
220+ if (forest $ importance.mode %in% measures ){
221+ importance_frame <- merge(importance_frame , measure_vimp_ranger(forest ), all = TRUE )
222+ }
223+ if (" no_of_trees" %in% measures ){
224+ importance_frame <- merge(importance_frame , measure_no_of_trees(min_depth_frame ), all = TRUE )
225+ importance_frame [is.na(importance_frame $ no_of_trees ), " no_of_trees" ] <- 0
226+ }
227+ if (" times_a_root" %in% measures ){
228+ importance_frame <- merge(importance_frame , measure_times_a_root(min_depth_frame ), all = TRUE )
229+ importance_frame [is.na(importance_frame $ times_a_root ), " times_a_root" ] <- 0
230+ }
231+ if (" p_value" %in% measures ){
232+ importance_frame $ p_value <- measure_p_value(importance_frame )
233+ importance_frame $ variable <- as.factor(importance_frame $ variable )
234+ }
235+ return (importance_frame )
236+ }
237+
159238# ' Extract k most important variables in a random forest
160239# '
161240# ' Get the names of k variables with highest sum of rankings based on the specified importance measures
@@ -174,13 +253,16 @@ measure_importance <- function(forest, mean_sample = "top_trees", measures = NUL
174253# ' important_variables(measure_importance(forest), k = 2)
175254# '
176255# ' @export
177- important_variables <- function (importance_frame , k = 15 , measures = names(importance_frame )[2 : 5 ],
256+ important_variables <- function (importance_frame , k = 15 ,
257+ measures = names(importance_frame )[2 : min(5 , ncol(importance_frame ))],
178258 ties_action = " all" ){
179259 if (" randomForest" %in% class(importance_frame )){
180260 importance_frame <- measure_importance(importance_frame )
181261 if (" predicted" %in% measures ){
182262 measures <- names(importance_frame )[2 : 5 ]
183263 }
264+ } else if (" ranger" %in% class(importance_frame )){
265+ importance_frame <- measure_importance(importance_frame )
184266 }
185267 rankings <- data.frame (variable = importance_frame $ variable , mean_min_depth =
186268 frankv(importance_frame $ mean_min_depth , ties.method = " dense" ),
@@ -232,7 +314,7 @@ plot_multi_way_importance <- function(importance_frame, x_measure = "mean_min_de
232314 min_no_of_trees = 0 , no_of_labels = 10 ,
233315 main = " Multi-way importance plot" ){
234316 variable <- NULL
235- if (" randomForest" %in% class(importance_frame )){
317+ if (any(c( " randomForest" , " ranger " ) %in% class(importance_frame ) )){
236318 importance_frame <- measure_importance(importance_frame )
237319 }
238320 data <- importance_frame [importance_frame $ no_of_trees > min_no_of_trees , ]
@@ -294,14 +376,16 @@ plot_multi_way_importance <- function(importance_frame, x_measure = "mean_min_de
294376# ' plot_importance_ggpairs(frame, measures = c("mean_min_depth", "times_a_root"))
295377# '
296378# ' @export
297- plot_importance_ggpairs <- function (importance_frame , measures =
298- names(importance_frame )[c(2 , 4 , 5 , 3 , 7 )],
379+ plot_importance_ggpairs <- function (importance_frame , measures = NULL ,
299380 main = " Relations between measures of importance" ){
300- if (" randomForest" %in% class(importance_frame )){
381+ if (any(c( " randomForest" , " ranger " ) %in% class(importance_frame ) )){
301382 importance_frame <- measure_importance(importance_frame )
302- if (" predicted" %in% measures ){
303- names(importance_frame )[c(2 , 4 , 5 , 3 , 7 )]
304- }
383+ }
384+ if (is.null(measures )){
385+ default_measures <- c(" gini_decrease" , " node_purity_increase" , # randomForest
386+ " impurity" , " impurity_corrected" , " permutation" , # ranger
387+ " mean_min_depth" , " no_of_trees" , " no_of_nodes" , " p_value" )
388+ measures <- intersect(default_measures , colnames(importance_frame ))
305389 }
306390 plot <- ggpairs(importance_frame [, measures ]) + theme_bw()
307391 if (! is.null(main )){
@@ -315,7 +399,7 @@ plot_importance_ggpairs <- function(importance_frame, measures =
315399# ' Plot against each other rankings of variables according to various measures of importance
316400# '
317401# ' @param importance_frame A result of using the function measure_importance() to a random forest or a randomForest object
318- # ' @param measures A character vector specifying the measures of importance to be used
402+ # ' @param measures A character vector specifying the measures of importance to be used.
319403# ' @param main A string to be used as title of the plot
320404# '
321405# ' @return A ggplot object
@@ -329,22 +413,29 @@ plot_importance_ggpairs <- function(importance_frame, measures =
329413# ' plot_importance_ggpairs(frame, measures = c("mean_min_depth", "times_a_root"))
330414# '
331415# ' @export
332- plot_importance_rankings <- function (importance_frame , measures =
333- names(importance_frame )[c(2 , 4 , 5 , 3 , 7 )],
416+ plot_importance_rankings <- function (importance_frame , measures = NULL ,
334417 main = " Relations between rankings according to different measures" ){
335- if (" randomForest" %in% class(importance_frame )){
418+ if (any(c( " randomForest" , " ranger " ) %in% class(importance_frame ) )){
336419 importance_frame <- measure_importance(importance_frame )
337- if (" predicted" %in% measures ){
338- names(importance_frame )[c(2 , 4 , 5 , 3 , 7 )]
339- }
340420 }
341- rankings <- data.frame (variable = importance_frame $ variable , mean_min_depth =
342- frankv(importance_frame $ mean_min_depth , ties.method = " dense" ),
343- apply(importance_frame [, - c(1 , 2 )], 2 ,
421+ if (is.null(measures )){
422+ default_measures <- c(" gini_decrease" , " node_purity_increase" , # randomForest
423+ " impurity" , " impurity_corrected" , " permutation" , # ranger
424+ " mean_min_depth" , " no_of_trees" , " no_of_nodes" , " p_value" )
425+ measures <- intersect(default_measures , colnames(importance_frame ))
426+ }
427+ rankings <- data.frame (variable = importance_frame $ variable ,
428+ apply(importance_frame [, ! colnames(importance_frame ) %in% c(" variable" , " mean_min_depth" , " p_value" )], 2 ,
344429 function (x ) frankv(x , order = - 1 , ties.method = " dense" )))
430+ if (" mean_min_depth" %in% measures ){
431+ rankings $ mean_min_depth = frankv(importance_frame $ mean_min_depth , ties.method = " dense" )
432+ }
433+ if (" p_value" %in% measures ){
434+ rankings $ p_value = frankv(importance_frame $ p_value , ties.method = " dense" )
435+ }
345436 plot <- ggpairs(rankings [, measures ], lower = list (continuous = function (data , mapping ){
346437 ggplot(data = data , mapping = mapping ) + geom_point() + geom_smooth(method = " loess" )
347- }))+ theme_bw()
438+ })) + theme_bw()
348439 if (! is.null(main )){
349440 plot <- plot + ggtitle(main )
350441 }
0 commit comments