Skip to content

Commit 7b5aa10

Browse files
Split measure_importance into smaller functions, add possibility of supplying the forest to plotting functions
1 parent 4930f51 commit 7b5aa10

15 files changed

+551
-85
lines changed

R/explain_forest.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
#' @import DT
1616
#'
1717
#' @examples
18-
#' explain_forest(randomForest::randomForest(Species ~ ., data = iris, localImp = TRUE), vars = names(iris), interactions = TRUE)
18+
#' forest <- randomForest::randomForest(Species ~ ., data = iris, localImp = TRUE)
19+
#' explain_forest(forest, vars = names(iris), interactions = TRUE)
1920
#'
2021
#' @export
2122
explain_forest <- function(forest, interactions = FALSE, data = NULL, vars = NULL, no_of_pred_plots = 3, pred_grid = 100,

R/measure_importance.R

Lines changed: 163 additions & 51 deletions
Large diffs are not rendered by default.

R/min_depth_distribution.R

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Calculate the depth of each node in a single tree obtained from a forest with randomForest::getTree
22
calculate_tree_depth <- function(frame){
3-
if(!is.data.frame(frame)) stop("The object is not a data frame!")
43
if(!all(c("right daughter", "left daughter") %in% names(frame))){
54
stop("The data frame has to contain columns called 'right daughter' and 'left daughter'!
65
It should be a product of the function getTree(..., labelVar = T).")
@@ -31,7 +30,6 @@ calculate_tree_depth <- function(frame){
3130
#'
3231
#' @export
3332
min_depth_distribution <- function(forest){
34-
if(!("randomForest" %in% class(forest))) stop("The object you supplied is not a random forest!")
3533
forest_table <-
3634
lapply(1:forest$ntree, function(i) randomForest::getTree(forest, k = i, labelVar = T) %>%
3735
calculate_tree_depth() %>% cbind(tree = i)) %>% rbindlist()
@@ -84,7 +82,7 @@ get_min_depth_means <- function(min_depth_frame, min_depth_count_list, mean_samp
8482

8583
#' Plot the distribution of minimal depth in a random forest
8684
#'
87-
#' @param min_depth_frame A data frame output of min_depth_distribution function
85+
#' @param min_depth_frame A data frame output of min_depth_distribution function or a randomForest object
8886
#' @param k The maximal number of variables with lowest mean minimal depth to be used for plotting
8987
#' @param min_no_of_trees The minimal number of trees in which a variable has to be used for splitting to be used for plotting
9088
#' @param mean_sample The sample of trees on which mean minimal depth is calculated, possible values are "all_trees", "top_trees", "relevant_trees"
@@ -98,12 +96,16 @@ get_min_depth_means <- function(min_depth_frame, min_depth_count_list, mean_samp
9896
#' @import dplyr
9997
#'
10098
#' @examples
101-
#' plot_min_depth_distribution(min_depth_distribution(randomForest::randomForest(Species ~ ., data = iris)))
99+
#' forest <- randomForest::randomForest(Species ~ ., data = iris)
100+
#' plot_min_depth_distribution(min_depth_distribution(forest))
102101
#'
103102
#' @export
104103
plot_min_depth_distribution <- function(min_depth_frame, k = 10, min_no_of_trees = 0,
105104
mean_sample = "top_trees", mean_scale = FALSE, mean_round = 2,
106105
main = "Distribution of minimal depth and its mean"){
106+
if("randomForest" %in% class(min_depth_frame)){
107+
min_depth_frame <- min_depth_distribution(min_depth_frame)
108+
}
107109
min_depth_count_list <- min_depth_count(min_depth_frame)
108110
min_depth_means <- get_min_depth_means(min_depth_frame, min_depth_count_list, mean_sample)
109111
frame_with_means <- merge(min_depth_count_list[[1]], min_depth_means)

R/min_depth_interactions.R

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ conditional_depth <- function(frame, vars){
2626

2727
# Get a data frame with values of minimal depth conditional on selected variables for the whole forest
2828
min_depth_interactions_values <- function(forest, vars){
29-
if(!("randomForest" %in% class(forest))) stop("The object you supplied is not a random forest!")
3029
interactions_frame <-
3130
lapply(1:forest$ntree, function(i) randomForest::getTree(forest, k = i, labelVar = T) %>%
3231
calculate_tree_depth() %>% cbind(., tree = i, number = 1:nrow(.))) %>%
@@ -53,7 +52,7 @@ min_depth_interactions_values <- function(forest, vars){
5352
#' Calculate mean conditional minimal depth with respect to a vector of variables
5453
#'
5554
#' @param forest A randomForest object
56-
#' @param vars A character vector with variables with respect to which conditional minimal depth will be calculated
55+
#' @param vars A character vector with variables with respect to which conditional minimal depth will be calculated; by default it is extracted by the important_variables function but this may be time consuming
5756
#' @param mean_sample The sample of trees on which conditional mean minimal depth is calculated, possible values are "all_trees", "top_trees", "relevant_trees"
5857
#' @param uncond_mean_sample The sample of trees on which unconditional mean minimal depth is calculated, possible values are "all_trees", "top_trees", "relevant_trees"
5958
#'
@@ -63,12 +62,12 @@ min_depth_interactions_values <- function(forest, vars){
6362
#' @importFrom data.table rbindlist
6463
#'
6564
#' @examples
66-
#' min_depth_interactions(randomForest::randomForest(Species ~ ., data = iris), vars = names(iris))
65+
#' forest <- randomForest::randomForest(Species ~ ., data = iris)
66+
#' min_depth_interactions(forest, names(iris))
6767
#'
6868
#' @export
69-
min_depth_interactions <- function(forest, vars, mean_sample = "top_trees",
70-
uncond_mean_sample = mean_sample){
71-
if(!("randomForest" %in% class(forest))) stop("The object you supplied is not a random forest!")
69+
min_depth_interactions <- function(forest, vars = important_variables(measure_importance(forest)),
70+
mean_sample = "top_trees", uncond_mean_sample = mean_sample){
7271
min_depth_interactions_frame <- min_depth_interactions_values(forest, vars)
7372
mean_tree_depth <- min_depth_interactions_frame[[2]]
7473
min_depth_interactions_frame <- min_depth_interactions_frame[[1]]
@@ -114,7 +113,7 @@ min_depth_interactions <- function(forest, vars, mean_sample = "top_trees",
114113

115114
#' Plot the top mean conditional minimal depth
116115
#'
117-
#' @param interactions_frame A data frame produced by the min_depth_interactions_means() function
116+
#' @param interactions_frame A data frame produced by the min_depth_interactions() function or a randomForest object
118117
#' @param k The number of best interactions to plot, if set to NULL then all are plotted
119118
#' @param main A string to be used as title of the plot
120119
#'
@@ -123,12 +122,16 @@ min_depth_interactions <- function(forest, vars, mean_sample = "top_trees",
123122
#' @import ggplot2
124123
#'
125124
#' @examples
126-
#' plot_min_depth_interactions(min_depth_interactions(randomForest::randomForest(Species ~ ., data = iris), vars = names(iris)))
125+
#' forest <- randomForest::randomForest(Species ~ ., data = iris)
126+
#' plot_min_depth_interactions(min_depth_interactions(forest, names(iris)))
127127
#'
128128
#' @export
129129
plot_min_depth_interactions <- function(interactions_frame, k = 30,
130130
main = paste0("Mean minimal depth for ",
131131
paste0(k, " most frequent interactions"))){
132+
if("randomForest" %in% class(interactions_frame)){
133+
interactions_frame <- min_depth_interactions(interactions_frame)
134+
}
132135
interactions_frame$interaction <- factor(interactions_frame$interaction, levels =
133136
interactions_frame[
134137
order(interactions_frame$occurrences, decreasing = TRUE), "interaction"])
@@ -164,7 +167,8 @@ plot_min_depth_interactions <- function(interactions_frame, k = 30,
164167
#' @import ggplot2
165168
#'
166169
#' @examples
167-
#' plot_predict_interaction(randomForest::randomForest(Species ~., data = iris), iris, "Petal.Width", "Sepal.Width")
170+
#' forest <- randomForest::randomForest(Species ~., data = iris)
171+
#' plot_predict_interaction(forest, iris, "Petal.Width", "Sepal.Width")
168172
#'
169173
#' @export
170174
plot_predict_interaction <- function(forest, data, variable1, variable2, grid = 100,

man/explain_forest.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/important_variables.Rd

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/measure_importance.Rd

Lines changed: 5 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/min_depth_interactions.Rd

Lines changed: 6 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/plot_importance_ggpairs.Rd

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/plot_importance_rankings.Rd

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)