@@ -26,7 +26,6 @@ conditional_depth <- function(frame, vars){
2626
2727# Get a data frame with values of minimal depth conditional on selected variables for the whole forest
2828min_depth_interactions_values <- function (forest , vars ){
29- if (! (" randomForest" %in% class(forest ))) stop(" The object you supplied is not a random forest!" )
3029 interactions_frame <-
3130 lapply(1 : forest $ ntree , function (i ) randomForest :: getTree(forest , k = i , labelVar = T ) %> %
3231 calculate_tree_depth() %> % cbind(. , tree = i , number = 1 : nrow(. ))) %> %
@@ -53,7 +52,7 @@ min_depth_interactions_values <- function(forest, vars){
5352# ' Calculate mean conditional minimal depth with respect to a vector of variables
5453# '
5554# ' @param forest A randomForest object
56- # ' @param vars A character vector with variables with respect to which conditional minimal depth will be calculated
55+ # ' @param vars A character vector with variables with respect to which conditional minimal depth will be calculated; by defalt it is extracted by the important_variables function but this may be time consuming
5756# ' @param mean_sample The sample of trees on which conditional mean minimal depth is calculated, possible values are "all_trees", "top_trees", "relevant_trees"
5857# ' @param uncond_mean_sample The sample of trees on which unconditional mean minimal depth is calculated, possible values are "all_trees", "top_trees", "relevant_trees"
5958# '
@@ -63,12 +62,12 @@ min_depth_interactions_values <- function(forest, vars){
6362# ' @importFrom data.table rbindlist
6463# '
6564# ' @examples
66- # ' min_depth_interactions(randomForest::randomForest(Species ~ ., data = iris), vars = names(iris))
65+ # ' forest <- randomForest::randomForest(Species ~ ., data = iris)
66+ # ' min_depth_interactions(forest, names(iris))
6767# '
6868# ' @export
69- min_depth_interactions <- function (forest , vars , mean_sample = " top_trees" ,
70- uncond_mean_sample = mean_sample ){
71- if (! (" randomForest" %in% class(forest ))) stop(" The object you supplied is not a random forest!" )
69+ min_depth_interactions <- function (forest , vars = important_variables(measure_importance(forest )),
70+ mean_sample = " top_trees" , uncond_mean_sample = mean_sample ){
7271 min_depth_interactions_frame <- min_depth_interactions_values(forest , vars )
7372 mean_tree_depth <- min_depth_interactions_frame [[2 ]]
7473 min_depth_interactions_frame <- min_depth_interactions_frame [[1 ]]
@@ -114,7 +113,7 @@ min_depth_interactions <- function(forest, vars, mean_sample = "top_trees",
114113
115114# ' Plot the top mean conditional minimal depth
116115# '
117- # ' @param interactions_frame A data frame produced by the min_depth_interactions_means () function
116+ # ' @param interactions_frame A data frame produced by the min_depth_interactions () function or a randomForest object
118117# ' @param k The number of best interactions to plot, if set to NULL then all plotted
119118# ' @param main A string to be used as title of the plot
120119# '
@@ -123,12 +122,16 @@ min_depth_interactions <- function(forest, vars, mean_sample = "top_trees",
123122# ' @import ggplot2
124123# '
125124# ' @examples
126- # ' plot_min_depth_interactions(min_depth_interactions(randomForest::randomForest(Species ~ ., data = iris), vars = names(iris)))
125+ # ' forest <- randomForest::randomForest(Species ~ ., data = iris)
126+ # ' plot_min_depth_interactions(min_depth_interactions(forest, names(iris)))
127127# '
128128# ' @export
129129plot_min_depth_interactions <- function (interactions_frame , k = 30 ,
130130 main = paste0(" Mean minimal depth for " ,
131131 paste0(k , " most frequent interactions" ))){
132+ if (" randomForest" %in% class(interactions_frame )){
133+ interactions_frame <- min_depth_interactions(interactions_frame )
134+ }
132135 interactions_frame $ interaction <- factor (interactions_frame $ interaction , levels =
133136 interactions_frame [
134137 order(interactions_frame $ occurrences , decreasing = TRUE ), " interaction" ])
@@ -164,7 +167,8 @@ plot_min_depth_interactions <- function(interactions_frame, k = 30,
164167# ' @import ggplot2
165168# '
166169# ' @examples
167- # ' plot_predict_interaction(randomForest::randomForest(Species ~., data = iris), iris, "Petal.Width", "Sepal.Width")
170+ # ' forest <- randomForest::randomForest(Species ~., data = iris)
171+ # ' plot_predict_interaction(forest, iris, "Petal.Width", "Sepal.Width")
168172# '
169173# ' @export
170174plot_predict_interaction <- function (forest , data , variable1 , variable2 , grid = 100 ,
0 commit comments