Skip to content

Commit 3fbce5a

Browse files
committed
first pass adding ranger support for min_depth_interactions, untested
1 parent fb35f7f commit 3fbce5a

File tree

1 file changed

+178
-10
lines changed

1 file changed

+178
-10
lines changed

R/min_depth_interactions.R

Lines changed: 178 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Calculate conditional depth in a tree with respect to all variables from vector vars
2+
# randomForest
23
conditional_depth <- function(frame, vars){
34
`.SD` <- NULL; depth <- NULL; `split var` <- NULL
45
index <- data.table::as.data.table(frame)[, .SD[which.min(depth), "number"], by = `split var`]
@@ -25,7 +26,36 @@ conditional_depth <- function(frame, vars){
2526
return(frame)
2627
}
2728

29+
# Calculate conditional depth in a tree with respect to all variables from vector vars
30+
# ranger
31+
conditional_depth_ranger <- function(frame, vars){
  # Fill, for every variable j in `vars`, the column `frame[[j]]` with the depth
  # of each node counted from the shallowest node that splits on j (that node
  # itself gets 0, its children 1, and so on). Nodes that are not reached from
  # that node keep NA.
  #
  # frame: data frame describing one ranger tree; must contain the columns
  #        "leftChild", "rightChild", "splitvarName", "depth", "number" and one
  #        (numeric, NA-filled) column per element of `vars`.
  # vars:  character vector of variable names to condition on.
  # Returns `frame` with the `vars` columns filled in; every 0 in the frame
  # (in any column) is replaced by NA before returning.
  #
  # NOTE(review): `number` is used as the 1-based row position of the node
  # (it indexes rows below), while ranger::treeInfo() documents leftChild /
  # rightChild as 0-based node IDs, hence the `- 1` when looking up a parent.
  `.SD` <- NULL; depth <- NULL; splitvarName <- NULL
  # Row number of the shallowest node (first one on ties) for each split variable
  index <- data.table::as.data.table(frame)[, .SD[which.min(depth), "number"], by = splitvarName]
  index <- index[!is.na(index$splitvarName), ]
  if(any(index$splitvarName %in% vars)){
    for(j in vars){
      begin <- as.numeric(index$number[index$splitvarName == j])
      # Skip variables that never split in this tree
      if(length(begin) == 1 && !is.na(begin)){
        # Work on all columns except the conditional-depth columns of other variables
        keep <- setdiff(names(frame), setdiff(vars, j))
        df <- frame[begin:nrow(frame), keep]
        df[[j]][1] <- 0
        for(k in seq_len(nrow(df))[-1]){
          node_id <- as.numeric(df[k, "number"]) - 1
          # which() drops the NA comparisons produced by terminal nodes,
          # whose children are NA in ranger's tree description
          parent <- which(df[, "leftChild"] == node_id | df[, "rightChild"] == node_id)
          if(length(parent) != 0){
            df[k, j] <- df[parent, j] + 1
          }
        }
        frame[begin:nrow(frame), keep] <- df
      }
    }
  }
  frame[frame == 0] <- NA
  return(frame)
}
56+
2857
# Get a data frame with values of minimal depth conditional on selected variables for the whole forest
58+
# randomForest
2959
min_depth_interactions_values <- function(forest, vars){
3060
`.` <- NULL; .SD <- NULL; tree <- NULL; `split var` <- NULL
3161
interactions_frame <-
@@ -49,6 +79,31 @@ min_depth_interactions_values <- function(forest, vars){
4979
return(list(min_depth_interactions_frame, mean_tree_depth))
5080
}
5181

82+
# Get a data frame with values of minimal depth conditional on selected variables for the whole forest
83+
# ranger
84+
min_depth_interactions_values_ranger <- function(forest, vars){
  # Compute conditional minimal depth values for every tree of a ranger forest.
  #
  # forest: a ranger forest (its `num.trees` element and ranger::treeInfo() are used).
  # vars:   character vector of variables to condition on.
  # Returns a list of two elements:
  #   [[1]] data frame with columns tree, variable and one column per element
  #         of `vars` holding the minimal conditional depth (NA if absent);
  #   [[2]] named numeric vector with, per element of `vars`, the mean over
  #         trees of the maximal conditional depth.
  `.` <- NULL; .SD <- NULL; tree <- NULL; splitvarName <- NULL
  # treeInfo() takes the tree index via `tree =`; `number` is the 1-based row
  # position of each node within its tree
  interactions_frame <-
    lapply(seq_len(forest$num.trees), function(i) ranger::treeInfo(forest, tree = i) %>%
             calculate_tree_depth_ranger() %>% cbind(., tree = i, number = seq_len(nrow(.)))) %>%
    data.table::rbindlist() %>% as.data.frame()
  interactions_frame[vars] <- as.numeric(NA)
  interactions_frame <-
    data.table::as.data.table(interactions_frame)[, conditional_depth_ranger(as.data.frame(.SD), vars), by = tree] %>%
    as.data.frame()
  mean_tree_depth <- dplyr::group_by(interactions_frame[, c("tree", vars)], tree) %>%
    dplyr::summarize_at(vars, dplyr::funs(max(., na.rm = TRUE))) %>% as.data.frame()
  # max() over an all-NA column yields -Inf; treat it as missing
  mean_tree_depth[mean_tree_depth == -Inf] <- NA
  # drop = FALSE keeps a data frame when only one variable is requested
  mean_tree_depth <- colMeans(mean_tree_depth[, vars, drop = FALSE], na.rm = TRUE)
  # ranger frames name the split variable column `splitvarName`
  min_depth_interactions_frame <-
    interactions_frame %>% dplyr::group_by(tree, splitvarName) %>%
    dplyr::summarize_at(vars, dplyr::funs(min(., na.rm = TRUE))) %>% as.data.frame()
  # min() over an all-NA column yields Inf; treat it as missing
  min_depth_interactions_frame[min_depth_interactions_frame == Inf] <- NA
  min_depth_interactions_frame <-
    min_depth_interactions_frame[!is.na(min_depth_interactions_frame$splitvarName), ]
  colnames(min_depth_interactions_frame)[2] <- "variable"
  min_depth_interactions_frame[, -c(1:2)] <- min_depth_interactions_frame[, -c(1:2)] - 1
  return(list(min_depth_interactions_frame, mean_tree_depth))
}
106+
52107
#' Calculate mean conditional minimal depth
53108
#'
54109
#' Calculate mean conditional minimal depth with respect to a vector of variables
@@ -60,15 +115,19 @@ min_depth_interactions_values <- function(forest, vars){
60115
#'
61116
#' @return A data frame with each observation giving the means of conditional minimal depth and the size of sample for a given interaction
62117
#'
63-
#' @import dplyr
64-
#' @importFrom data.table rbindlist
65-
#'
66118
#' @examples
67119
#' forest <- randomForest::randomForest(Species ~ ., data = iris, ntree = 100)
68120
#' min_depth_interactions(forest, c("Petal.Width", "Petal.Length"))
69121
#'
70122
#' @export
71-
min_depth_interactions <- function(forest, vars = important_variables(measure_importance(forest)),
123+
# S3 generic: dispatches on the class of `forest`. The formals mirror those of
# the methods so that callers can pass `forest` and the optional arguments;
# the defaults are only evaluated by the method that is dispatched to.
min_depth_interactions <- function(forest, vars = important_variables(measure_importance(forest)),
                                   mean_sample = "top_trees", uncond_mean_sample = mean_sample){
  UseMethod("min_depth_interactions")
}
126+
127+
#' @import dplyr
128+
#' @importFrom data.table rbindlist
129+
#' @export
130+
min_depth_interactions.randomForest <- function(forest, vars = important_variables(measure_importance(forest)),
72131
mean_sample = "top_trees", uncond_mean_sample = mean_sample){
73132
variable <- NULL; `.` <- NULL; tree <- NULL; `split var` <- NULL; depth <- NULL
74133
min_depth_interactions_frame <- min_depth_interactions_values(forest, vars)
@@ -114,6 +173,55 @@ min_depth_interactions <- function(forest, vars = important_variables(measure_im
114173
interactions_frame <- merge(interactions_frame, importance_frame)
115174
}
116175

176+
#' @import dplyr
177+
#' @importFrom data.table rbindlist
178+
#' @export
179+
min_depth_interactions.ranger <- function(forest, vars = important_variables(measure_importance(forest)),
                                          mean_sample = "top_trees", uncond_mean_sample = mean_sample){
  # Mean conditional minimal depth of every variable with respect to each
  # variable in `vars`, for a ranger forest.
  #
  # forest:             a ranger forest.
  # vars:               character vector of conditioning (root) variables.
  # mean_sample:        "all_trees" or "top_trees" select how trees without the
  #                     interaction enter the mean; any other value leaves the
  #                     mean over trees where the interaction occurs.
  # uncond_mean_sample: passed on to get_min_depth_means_ranger().
  # Returns (invisibly, as the value of the final assignment) a data frame with
  # columns including variable, root_variable, mean_min_depth, occurrences,
  # interaction and uncond_mean_min_depth.
  variable <- NULL; `.` <- NULL; tree <- NULL; splitvarName <- NULL; depth <- NULL
  min_depth_interactions_frame <- min_depth_interactions_values_ranger(forest, vars)
  mean_tree_depth <- min_depth_interactions_frame[[2]]
  min_depth_interactions_frame <- min_depth_interactions_frame[[1]]
  interactions_frame <-
    min_depth_interactions_frame %>% dplyr::group_by(variable) %>%
    dplyr::summarize_at(vars, funs(mean(., na.rm = TRUE))) %>% as.data.frame()
  # mean() of an all-NA column is NaN; normalise NaN to NA
  interactions_frame[is.na(as.matrix(interactions_frame))] <- NA
  occurrences <-
    min_depth_interactions_frame %>% dplyr::group_by(variable) %>%
    dplyr::summarize_at(vars, funs(sum(!is.na(.)))) %>% as.data.frame()
  # diag(x) with a length-one x builds an identity matrix of size x, so the
  # dimension is given explicitly to get a proper scaling matrix even when
  # only one variable is requested
  depth_scaling <- diag(mean_tree_depth, nrow = length(mean_tree_depth))
  if(mean_sample == "all_trees"){
    non_occurrences <- occurrences
    non_occurrences[, -1] <- forest$num.trees - occurrences[, -1]
    interactions_frame[is.na(as.matrix(interactions_frame))] <- 0
    interactions_frame[, -1] <- (interactions_frame[, -1] * occurrences[, -1] +
                                   as.matrix(non_occurrences[, -1]) %*% depth_scaling)/forest$num.trees
  } else if(mean_sample == "top_trees"){
    non_occurrences <- occurrences
    non_occurrences[, -1] <- forest$num.trees - occurrences[, -1]
    minimum_non_occurrences <- min(non_occurrences[, -1])
    non_occurrences[, -1] <- non_occurrences[, -1] - minimum_non_occurrences
    interactions_frame[is.na(as.matrix(interactions_frame))] <- 0
    interactions_frame[, -1] <- (interactions_frame[, -1] * occurrences[, -1] +
                                   as.matrix(non_occurrences[, -1]) %*% depth_scaling)/
      (forest$num.trees - minimum_non_occurrences)
  }
  interactions_frame <- reshape2::melt(interactions_frame, id.vars = "variable")
  colnames(interactions_frame)[2:3] <- c("root_variable", "mean_min_depth")
  occurrences <- reshape2::melt(occurrences, id.vars = "variable")
  colnames(occurrences)[2:3] <- c("root_variable", "occurrences")
  interactions_frame <- merge(interactions_frame, occurrences)
  interactions_frame$interaction <- paste(interactions_frame$root_variable, interactions_frame$variable, sep = ":")
  # Unconditional minimal depth of every variable, used for comparison
  forest_table <-
    lapply(seq_len(forest$num.trees), function(i) ranger::treeInfo(forest, tree = i) %>%
             calculate_tree_depth_ranger() %>% cbind(tree = i)) %>% data.table::rbindlist()
  min_depth_frame <- dplyr::group_by(forest_table, tree, splitvarName) %>%
    dplyr::summarize(min(depth))
  colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
  min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable),])
  importance_frame <- get_min_depth_means_ranger(min_depth_frame, min_depth_count_ranger(min_depth_frame), uncond_mean_sample)
  colnames(importance_frame)[2] <- "uncond_mean_min_depth"
  interactions_frame <- merge(interactions_frame, importance_frame)
}
224+
117225
#' Plot the top mean conditional minimal depth
118226
#'
119227
#' @param interactions_frame A data frame produced by the min_depth_interactions() function or a randomForest or ranger object
@@ -133,7 +241,7 @@ plot_min_depth_interactions <- function(interactions_frame, k = 30,
133241
main = paste0("Mean minimal depth for ",
134242
paste0(k, " most frequent interactions"))){
135243
mean_min_depth <- NULL; occurrences <- NULL; uncond_mean_min_depth <- NULL
136-
if("randomForest" %in% class(interactions_frame)){
244+
if(any(c("randomForest", "ranger") %in% class(interactions_frame))){
137245
interactions_frame <- min_depth_interactions(interactions_frame)
138246
}
139247
interactions_frame$interaction <- factor(interactions_frame$interaction, levels =
@@ -168,19 +276,27 @@ plot_min_depth_interactions <- function(interactions_frame, k = 30,
168276
#'
169277
#' @return A ggplot2 object
170278
#'
171-
#' @import ggplot2
172-
#' @importFrom stats predict
173-
#' @importFrom stats terms
174-
#' @importFrom stats as.formula
175-
#'
176279
#' @examples
177280
#' forest <- randomForest::randomForest(Species ~., data = iris)
178281
#' plot_predict_interaction(forest, iris, "Petal.Width", "Sepal.Width")
282+
#' forest_ranger <- ranger::ranger(Species ~., data = iris, probability = TRUE)
283+
#' plot_predict_interaction(forest_ranger, iris, "Petal.Width", "Sepal.Width")
179284
#'
180285
#' @export
181286
# S3 generic for plotting forest predictions over a grid of two variables;
# the actual work is done by the method selected from the class of `forest`.
plot_predict_interaction <- function(forest, data, variable1, variable2, grid = 100,
                                     main = paste0("Prediction of the forest for different values of ",
                                                   paste0(variable1, paste0(" and ", variable2)))){
  # Dispatch explicitly on the first argument
  UseMethod("plot_predict_interaction", forest)
}
291+
292+
#' @import ggplot2
293+
#' @importFrom stats predict
294+
#' @importFrom stats terms
295+
#' @importFrom stats as.formula
296+
#' @export
297+
plot_predict_interaction.randomForest <- function(forest, data, variable1, variable2, grid = 100,
298+
main = paste0("Prediction of the forest for different values of ",
299+
paste0(variable1, paste0(" and ", variable2)))){
184300
newdata <- expand.grid(seq(min(data[[variable1]]), max(data[[variable1]]), length.out = grid),
185301
seq(min(data[[variable2]]), max(data[[variable2]]), length.out = grid))
186302
colnames(newdata) <- c(variable1, variable2)
@@ -219,3 +335,55 @@ plot_predict_interaction <- function(forest, data, variable1, variable2, grid =
219335
}
220336
return(plot)
221337
}
338+
339+
#' @import ggplot2
340+
#' @importFrom stats predict
341+
#' @importFrom stats terms
342+
#' @importFrom stats as.formula
343+
#' @export
344+
plot_predict_interaction.ranger <- function(forest, data, variable1, variable2, grid = 100,
                                            main = paste0("Prediction of the forest for different values of ",
                                                          paste0(variable1, paste0(" and ", variable2)))){
  # Plot the prediction of a ranger forest on a grid of values of two variables.
  #
  # forest:    a ranger forest of type "Regression" or "Probability estimation";
  #            its call is expected to hold the model formula as second element.
  # data:      data frame the forest was trained on.
  # variable1: name of the variable plotted on the x axis.
  # variable2: name of the variable plotted on the y axis.
  # grid:      number of grid points per variable.
  # main:      plot title; NULL for no title.
  # Returns a ggplot2 object. Stops with an error for other forest types.
  newdata <- expand.grid(seq(min(data[[variable1]]), max(data[[variable1]]), length.out = grid),
                         seq(min(data[[variable2]]), max(data[[variable2]]), length.out = grid))
  colnames(newdata) <- c(variable1, variable2)
  # Recover the remaining predictors from the formula stored in the call
  if(as.character(forest$call[[2]])[3] == "."){
    other_vars <- setdiff(names(data), as.character(forest$call[[2]])[2])
  } else {
    other_vars <- labels(terms(as.formula(forest$call[[2]])))
  }
  other_vars <- setdiff(other_vars, c(variable1, variable2))
  n <- nrow(data)
  # Fill the other predictors with values resampled from the data
  for(i in other_vars){
    newdata[[i]] <- data[[i]][sample(1:n, nrow(newdata), replace = TRUE)]
  }
  if(forest$treetype == "Regression"){
    # predict() on a ranger forest returns an object whose `predictions`
    # element holds the predicted values
    newdata$prediction <- predict(forest, newdata, type = "response")$predictions
    plot <- ggplot(newdata, aes_string(x = variable1, y = variable2, fill = "prediction")) +
      geom_raster() + theme_bw() +
      scale_fill_gradient2(midpoint = min(newdata$prediction) + 0.5 * (max(newdata$prediction) - min(newdata$prediction)),
                           low = "blue", high = "red")
  } else if(forest$treetype == "Probability estimation"){
    id_vars <- colnames(newdata)
    pred <- predict(forest, newdata)$predictions
    # With two classes only the second class probability is plotted
    if(ncol(pred) == 2){
      newdata[, paste0("probability_", colnames(pred)[-1])] <- pred[, -1]
    } else {
      newdata[, paste0("probability_", colnames(pred))] <- pred
    }
    newdata <- reshape2::melt(newdata, id.vars = id_vars)
    newdata$prediction <- newdata$value
    plot <- ggplot(newdata, aes_string(x = variable1, y = variable2, fill = "prediction")) +
      geom_raster() + theme_bw() + facet_wrap(~ variable) +
      scale_fill_gradient2(midpoint = min(newdata$prediction) + 0.5 * (max(newdata$prediction) - min(newdata$prediction)),
                           low = "blue", high = "red")
  } else if(forest$treetype == "Classification") {
    stop("Ranger forest for classification needs to be generated by ranger(..., probability = TRUE).")
  } else {
    stop(sprintf("Ranger forest type '%s' is currently not supported.", forest$treetype))
  }
  if(!is.null(main)){
    plot <- plot + ggtitle(main)
  }
  return(plot)
}

0 commit comments

Comments
 (0)