11# Calculate conditional depth in a tree with respect to all variables from vector vars
2+ # randomForest
23conditional_depth <- function (frame , vars ){
34 `.SD` <- NULL ; depth <- NULL ; `split var` <- NULL
45 index <- data.table :: as.data.table(frame )[, .SD [which.min(depth ), " number" ], by = `split var` ]
@@ -25,7 +26,36 @@ conditional_depth <- function(frame, vars){
2526 return (frame )
2627}
2728
# Calculate conditional depth in a single ranger tree with respect to every
# variable in `vars`.
#
# frame: data frame describing ONE tree, with the columns `leftChild`,
#   `rightChild`, `splitvarName` and `depth`, plus `nodeID` and/or `number`
#   to identify nodes. Rows must be in tree order (a parent precedes its
#   children). Columns named after `vars` are filled in; missing ones are
#   created as NA.
# vars: character vector of variable names to condition on.
#
# Returns `frame` in which, for each variable j in `vars`, column j holds the
# number of edges between a node and the shallowest split on j. Nodes outside
# that subtree, and the split node itself, are NA.
conditional_depth_ranger <- function(frame, vars) {
  # ranger::treeInfo() reports leftChild/rightChild as 0-indexed node IDs, so
  # children are matched against `nodeID` when it is available; the row counter
  # `number` is only used as a fallback identifier.
  id_column <- if ("nodeID" %in% names(frame)) "nodeID" else "number"
  required <- c("leftChild", "rightChild", "splitvarName", "depth", id_column)
  missing_columns <- setdiff(required, names(frame))
  if (length(missing_columns) > 0) {
    stop("The data frame lacks the column(s): ",
         paste(missing_columns, collapse = ", "), call. = FALSE)
  }
  n_nodes <- nrow(frame)
  for (j in vars) {
    if (!j %in% names(frame)) {
      frame[[j]] <- NA_real_
    }
    # Shallowest split on j (first one in table order on ties); which() drops
    # the NA comparisons produced by terminal nodes.
    split_rows <- which(frame[["splitvarName"]] == j)
    first_split <- split_rows[which.min(frame[["depth"]][split_rows])]
    if (length(first_split) == 0) {
      next
    }
    subtree <- first_split:n_nodes
    ids <- frame[[id_column]][subtree]
    # Position (within `subtree`) of each node's parent, NA when the parent
    # lies above the conditioning split. match() never matches the NA child
    # IDs of terminal nodes.
    children <- c(frame[["leftChild"]][subtree], frame[["rightChild"]][subtree])
    parent_of <- rep(seq_along(ids), times = 2)[match(ids, children)]
    cond_depth <- frame[[j]][subtree]
    cond_depth[1] <- 0
    # Parents come first, so their conditional depth is already known.
    for (k in seq_along(ids)[-1]) {
      if (!is.na(parent_of[k])) {
        cond_depth[k] <- cond_depth[parent_of[k]] + 1
      }
    }
    frame[[j]][subtree] <- cond_depth
  }
  # Blank the zero at each conditioning split. NOTE(review): as in the
  # original, this also turns zeros in every other column into NA.
  frame[frame == 0] <- NA
  frame
}
56+
2857# Get a data frame with values of minimal depth conditional on selected variables for the whole forest
58+ # randomForest
2959min_depth_interactions_values <- function (forest , vars ){
3060 `.` <- NULL ; .SD <- NULL ; tree <- NULL ; `split var` <- NULL
3161 interactions_frame <-
@@ -49,6 +79,31 @@ min_depth_interactions_values <- function(forest, vars){
4979 return (list (min_depth_interactions_frame , mean_tree_depth ))
5080}
5181
# Get a data frame with values of minimal depth conditional on selected
# variables for the whole forest (ranger version).
#
# forest: a ranger forest; every tree is read with ranger::treeInfo().
# vars: character vector of variables to condition on.
#
# Returns a list of two elements:
#   [[1]] a data frame with the columns `tree`, `variable` and one column per
#         element of `vars` holding, for each tree and split variable, the
#         minimal conditional depth minus one;
#   [[2]] a named numeric vector giving, for each element of `vars`, the mean
#         over trees of the maximal conditional depth found in the tree.
min_depth_interactions_values_ranger <- function(forest, vars) {
  `.` <- NULL; .SD <- NULL; tree <- NULL; splitvarName <- NULL
  # max()/min() over all-NA input warn and return -Inf/Inf; give NA directly.
  max_or_na <- function(x) if (all(is.na(x))) NA_real_ else max(x, na.rm = TRUE)
  min_or_na <- function(x) if (all(is.na(x))) NA_real_ else min(x, na.rm = TRUE)
  # One row per node; `number` is the 1-based row position inside its tree.
  interactions_frame <-
    lapply(seq_len(forest$num.trees), function(i) {
      ranger::treeInfo(forest, tree = i) %>%
        calculate_tree_depth_ranger() %>%
        cbind(., tree = i, number = seq_len(nrow(.)))
    }) %>%
    data.table::rbindlist() %>%
    as.data.frame()
  interactions_frame[vars] <- NA_real_
  interactions_frame <-
    data.table::as.data.table(interactions_frame)[, conditional_depth_ranger(as.data.frame(.SD), vars), by = tree] %>%
    as.data.frame()
  mean_tree_depth <-
    dplyr::group_by(interactions_frame[, c("tree", vars)], tree) %>%
    dplyr::summarize_at(vars, max_or_na) %>%
    as.data.frame()
  # drop = FALSE keeps a data frame (and so colMeans() working) for one variable
  mean_tree_depth <- colMeans(mean_tree_depth[, vars, drop = FALSE], na.rm = TRUE)
  min_depth_interactions_frame <-
    interactions_frame %>%
    dplyr::group_by(tree, splitvarName) %>%
    dplyr::summarize_at(vars, min_or_na) %>%
    as.data.frame()
  # Terminal nodes have no split variable
  min_depth_interactions_frame <-
    min_depth_interactions_frame[!is.na(min_depth_interactions_frame$splitvarName), ]
  colnames(min_depth_interactions_frame)[2] <- "variable"
  # NOTE(review): presumably shifts depths so that a direct child of the
  # conditioning split counts as 0 - confirm.
  min_depth_interactions_frame[, -c(1:2)] <- min_depth_interactions_frame[, -c(1:2)] - 1
  list(min_depth_interactions_frame, mean_tree_depth)
}
106+
52107# ' Calculate mean conditional minimal depth
53108# '
54109# ' Calculate mean conditional minimal depth with respect to a vector of variables
@@ -60,15 +115,19 @@ min_depth_interactions_values <- function(forest, vars){
60115# '
#' @return A data frame with each observation giving the means of conditional minimal depth and the size of sample for a given interaction
62117# '
63- # ' @import dplyr
64- # ' @importFrom data.table rbindlist
65- # '
66118# ' @examples
67119# ' forest <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100)
68120# ' min_depth_interactions(forest, c("Petal.Width", "Petal.Length"))
69121# '
70122# ' @export
71- min_depth_interactions <- function (forest , vars = important_variables(measure_importance(forest )),
min_depth_interactions <- function(forest, vars = important_variables(measure_importance(forest)),
                                   mean_sample = "top_trees", uncond_mean_sample = mean_sample) {
  # S3 generic: dispatches on the class of `forest`. The formals must be
  # declared here, otherwise every call with arguments fails with
  # "unused argument" before a method is ever selected.
  UseMethod("min_depth_interactions")
}
126+
127+ # ' @import dplyr
128+ # ' @importFrom data.table rbindlist
129+ # ' @export
130+ min_depth_interactions.randomForest <- function (forest , vars = important_variables(measure_importance(forest )),
72131 mean_sample = " top_trees" , uncond_mean_sample = mean_sample ){
73132 variable <- NULL ; `.` <- NULL ; tree <- NULL ; `split var` <- NULL ; depth <- NULL
74133 min_depth_interactions_frame <- min_depth_interactions_values(forest , vars )
@@ -114,6 +173,55 @@ min_depth_interactions <- function(forest, vars = important_variables(measure_im
114173 interactions_frame <- merge(interactions_frame , importance_frame )
115174}
116175
#' @import dplyr
#' @importFrom data.table rbindlist
#' @export
min_depth_interactions.ranger <- function(forest, vars = important_variables(measure_importance(forest)),
                                          mean_sample = "top_trees", uncond_mean_sample = mean_sample) {
  # ranger method. Builds one row per (variable, root_variable) pair with the
  # mean conditional minimal depth, the number of trees in which the pair
  # occurs, and the unconditional mean minimal depth of `variable`.
  #   forest             a ranger forest
  #   vars               character vector of conditioning ("root") variables
  #   mean_sample        "all_trees" or "top_trees" fill in trees where the
  #                      interaction is absent; any other value keeps the
  #                      plain mean over trees where it occurs
  #   uncond_mean_sample passed on to get_min_depth_means_ranger()
  variable <- NULL; `.` <- NULL; tree <- NULL; splitvarName <- NULL; depth <- NULL
  conditional_values <- min_depth_interactions_values_ranger(forest, vars)
  mean_tree_depth <- conditional_values[[2]]
  min_depth_interactions_frame <- conditional_values[[1]]
  interactions_frame <-
    min_depth_interactions_frame %>%
    dplyr::group_by(variable) %>%
    dplyr::summarize_at(vars, function(x) mean(x, na.rm = TRUE)) %>%
    as.data.frame()
  # mean() of an all-NA column is NaN; is.na() is TRUE for NaN, so this
  # normalises NaN to NA.
  interactions_frame[is.na(as.matrix(interactions_frame))] <- NA
  occurrences <-
    min_depth_interactions_frame %>%
    dplyr::group_by(variable) %>%
    dplyr::summarize_at(vars, function(x) sum(!is.na(x))) %>%
    as.data.frame()
  if (mean_sample %in% c("all_trees", "top_trees")) {
    non_occurrences <- occurrences
    non_occurrences[, -1] <- forest$num.trees - occurrences[, -1]
    # "top_trees" ignores the trees that every interaction is missing from
    trees_dropped <- if (mean_sample == "top_trees") min(non_occurrences[, -1]) else 0
    non_occurrences[, -1] <- non_occurrences[, -1] - trees_dropped
    interactions_frame[is.na(as.matrix(interactions_frame))] <- 0
    # Trees without the interaction contribute the mean tree depth of the root
    # variable. `nrow` is given explicitly because diag() of a length-one
    # vector would otherwise build an identity matrix of that size.
    interactions_frame[, -1] <-
      (interactions_frame[, -1] * occurrences[, -1] +
         as.matrix(non_occurrences[, -1]) %*%
         diag(mean_tree_depth, nrow = length(mean_tree_depth))) /
      (forest$num.trees - trees_dropped)
  }
  interactions_frame <- reshape2::melt(interactions_frame, id.vars = "variable")
  colnames(interactions_frame)[2:3] <- c("root_variable", "mean_min_depth")
  occurrences <- reshape2::melt(occurrences, id.vars = "variable")
  colnames(occurrences)[2:3] <- c("root_variable", "occurrences")
  interactions_frame <- merge(interactions_frame, occurrences)
  interactions_frame$interaction <-
    paste(interactions_frame$root_variable, interactions_frame$variable, sep = ":")
  # Unconditional minimal depth of every variable in every tree
  forest_table <-
    lapply(seq_len(forest$num.trees), function(i) {
      ranger::treeInfo(forest, tree = i) %>%
        calculate_tree_depth_ranger() %>%
        cbind(tree = i)
    }) %>%
    data.table::rbindlist()
  min_depth_frame <-
    dplyr::group_by(forest_table, tree, splitvarName) %>%
    dplyr::summarize(min(depth))
  colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
  min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable), ])
  importance_frame <- get_min_depth_means_ranger(
    min_depth_frame, min_depth_count_ranger(min_depth_frame), uncond_mean_sample
  )
  colnames(importance_frame)[2] <- "uncond_mean_min_depth"
  # The value of this final assignment is what the function returns (invisibly)
  interactions_frame <- merge(interactions_frame, importance_frame)
}
224+
117225# ' Plot the top mean conditional minimal depth
118226# '
#' @param interactions_frame A data frame produced by the min_depth_interactions() function or a randomForest or ranger object
@@ -133,7 +241,7 @@ plot_min_depth_interactions <- function(interactions_frame, k = 30,
133241 main = paste0(" Mean minimal depth for " ,
134242 paste0(k , " most frequent interactions" ))){
135243 mean_min_depth <- NULL ; occurrences <- NULL ; uncond_mean_min_depth <- NULL
136- if (" randomForest" %in% class(interactions_frame )){
244+ if (any(c( " randomForest" , " ranger " ) %in% class(interactions_frame ) )){
137245 interactions_frame <- min_depth_interactions(interactions_frame )
138246 }
139247 interactions_frame $ interaction <- factor (interactions_frame $ interaction , levels =
@@ -168,19 +276,27 @@ plot_min_depth_interactions <- function(interactions_frame, k = 30,
168276# '
169277# ' @return A ggplot2 object
170278# '
171- # ' @import ggplot2
172- # ' @importFrom stats predict
173- # ' @importFrom stats terms
174- # ' @importFrom stats as.formula
175- # '
176279# ' @examples
177280# ' forest <- randomForest::randomForest(Species ~., data = iris)
178281# ' plot_predict_interaction(forest, iris, "Petal.Width", "Sepal.Width")
#' forest_ranger <- ranger::ranger(Species ~ ., data = iris, probability = TRUE)
#' plot_predict_interaction(forest_ranger, iris, "Petal.Width", "Sepal.Width")
179284# '
180285# ' @export
plot_predict_interaction <- function(forest, data, variable1, variable2, grid = 100,
                                     main = paste0("Prediction of the forest for different values of ",
                                                   paste0(variable1, paste0(" and ", variable2)))) {
  # S3 generic: the method is chosen by the class of `forest`.
  UseMethod("plot_predict_interaction")
}
291+
292+ # ' @import ggplot2
293+ # ' @importFrom stats predict
294+ # ' @importFrom stats terms
295+ # ' @importFrom stats as.formula
296+ # ' @export
297+ plot_predict_interaction.randomForest <- function (forest , data , variable1 , variable2 , grid = 100 ,
298+ main = paste0(" Prediction of the forest for different values of " ,
299+ paste0(variable1 , paste0(" and " , variable2 )))){
184300 newdata <- expand.grid(seq(min(data [[variable1 ]]), max(data [[variable1 ]]), length.out = grid ),
185301 seq(min(data [[variable2 ]]), max(data [[variable2 ]]), length.out = grid ))
186302 colnames(newdata ) <- c(variable1 , variable2 )
@@ -219,3 +335,55 @@ plot_predict_interaction <- function(forest, data, variable1, variable2, grid =
219335 }
220336 return (plot )
221337}
338+
#' @import ggplot2
#' @importFrom stats predict
#' @importFrom stats terms
#' @importFrom stats as.formula
#' @export
plot_predict_interaction.ranger <- function(forest, data, variable1, variable2, grid = 100,
                                            main = paste0("Prediction of the forest for different values of ",
                                                          paste0(variable1, paste0(" and ", variable2)))) {
  # ranger method. Predicts on a grid x grid lattice spanning the observed
  # ranges of `variable1` and `variable2`, with every other predictor drawn at
  # random (with replacement) from `data`, and returns a ggplot2 raster of the
  # predictions. Supports regression and probability forests; stops with an
  # error for any other tree type.
  newdata <- expand.grid(
    seq(min(data[[variable1]]), max(data[[variable1]]), length.out = grid),
    seq(min(data[[variable2]]), max(data[[variable2]]), length.out = grid)
  )
  colnames(newdata) <- c(variable1, variable2)
  # NOTE(review): assumes the model formula is the first argument of the
  # ranger() call stored in forest$call - confirm for non-formula interfaces.
  model_formula <- forest$call[[2]]
  if (as.character(model_formula)[3] == ".") {
    other_vars <- setdiff(names(data), as.character(model_formula)[2])
  } else {
    other_vars <- labels(terms(as.formula(model_formula)))
  }
  other_vars <- setdiff(other_vars, c(variable1, variable2))
  n <- nrow(data)
  for (i in other_vars) {
    newdata[[i]] <- data[[i]][sample(1:n, nrow(newdata), replace = TRUE)]
  }
  # Raster of `prediction` with the colour scale centred on the mid-range
  raster_plot <- function(plot_data) {
    lowest <- min(plot_data$prediction)
    highest <- max(plot_data$prediction)
    ggplot(plot_data, aes_string(x = variable1, y = variable2, fill = "prediction")) +
      geom_raster() + theme_bw() +
      scale_fill_gradient2(midpoint = lowest + 0.5 * (highest - lowest),
                           low = "blue", high = "red")
  }
  if (forest$treetype == "Regression") {
    # predict() on a ranger forest returns a prediction object; the numeric
    # predictions live in its `predictions` element.
    newdata$prediction <- predict(forest, newdata, type = "response")$predictions
    plot <- raster_plot(newdata)
  } else if (forest$treetype == "Probability estimation") {
    id_vars <- colnames(newdata)
    pred <- predict(forest, newdata)$predictions
    if (ncol(pred) == 2) {
      # Two classes: one probability determines the other, so plot one panel
      newdata[, paste0("probability_", colnames(pred)[-1])] <- pred[, -1]
    } else {
      newdata[, paste0("probability_", colnames(pred))] <- pred
    }
    newdata <- reshape2::melt(newdata, id.vars = id_vars)
    newdata$prediction <- newdata$value
    plot <- raster_plot(newdata) + facet_wrap(~ variable)
  } else if (forest$treetype == "Classification") {
    stop("Ranger forest for classification needs to be generated by ranger(..., probability = TRUE).")
  } else {
    stop(sprintf("Ranger forest type '%s' is currently not supported.", forest$treetype))
  }
  if (!is.null(main)) {
    plot <- plot + ggtitle(main)
  }
  plot
}
0 commit comments