Skip to content

Commit 17f0a70

Browse files
Yue-Jiang authored and pbiecek committed
add tests, passing along maintainer role (#11)
* make `plot_min_depth_distribution` work with ranger forests
* first pass adding ranger support for min_depth_interactions, untested
* first pass adding ranger support for multi way importance, untested
* further clean up to the point explain_forest works for ranger
* unsupervised randomForest should be supported by all functions except plot_predict_interaction
* add my email address
* unfinished work adding tests
* first pass randomForest tests
* add tests for ranger and fix an issue with interactions for ranger
* add .travis.yml
* pass along maintainer role
1 parent 32ca649 commit 17f0a70

File tree

10 files changed

+244
-18
lines changed

10 files changed

+244
-18
lines changed

.travis.yml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# R for travis: see documentation at https://docs.travis-ci.com/user/languages/r
2+
3+
language: r
4+
r:
5+
- release
6+
- devel
7+
sudo: false
8+
cache: packages
9+
10+
r_packages:
11+
- covr
12+
13+
after_success:
14+
- Rscript -e 'library(covr); codecov()'

DESCRIPTION

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ Package: randomForestExplainer
22
Title: Explaining and Visualizing Random Forests in Terms of Variable Importance
33
Version: 0.9
44
Authors@R: c(
5-
person("Aleksandra", "Paluszynska", email = "[email protected]", role = c("aut", "cre")),
5+
person("Aleksandra", "Paluszynska", email = "[email protected]", role = c("aut")),
66
person("Przemyslaw", "Biecek", email = "[email protected]", role = c("aut","ths")),
7-
person("Yue", "Jiang", role = "aut")
7+
person("Yue", "Jiang", email = "[email protected]", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-9798-5517"))
88
)
99
Description: A set of tools to help explain which variables are most important in a random forest. Various variable importance measures are calculated and visualized in different settings in order to get an idea of how their importance changes depending on our criteria (Hemant Ishwaran and Udaya B. Kogalur and Eiran Z. Gorodeski and Andy J. Minn and Michael S. Lauer (2010) <doi:10.1198/jasa.2009.tm08622>, Leo Breiman (2001) <doi:10.1023/A:1010933404324>).
1010
Depends: R (>= 3.0)
@@ -25,7 +25,8 @@ Imports:
2525
reshape2 (>= 1.4.2),
2626
rmarkdown (>= 1.5)
2727
Suggests:
28-
knitr
28+
knitr,
29+
testthat
2930
VignetteBuilder: knitr
3031
RoxygenNote: 6.1.1
3132
URL: https://github.com/ModelOriented/randomForestExplainer

R/explain_forest.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ explain_forest <- function(forest, interactions = FALSE, data = NULL, vars = NUL
2525
measures = NULL){
2626
if(is.null(measures)){
2727
if("randomForest" %in% class(forest)){
28-
if(forest$type == "classification"){
28+
if(forest$type %in% c("classification", "unsupervised")){
2929
measures <- c("mean_min_depth", "accuracy_decrease", "gini_decrease", "no_of_nodes", "times_a_root")
3030
} else{
3131
measures <- c("mean_min_depth", "mse_increase", "node_purity_increase", "no_of_nodes", "times_a_root")
@@ -36,7 +36,7 @@ explain_forest <- function(forest, interactions = FALSE, data = NULL, vars = NUL
3636
}
3737
if("randomForest" %in% class(forest) && dim(forest$importance)[2] == 1){
3838
stop(paste("Your forest does not contain information on local importance so",
39-
ifelse(forest$type == "classification", "accuracy_decrease", "mse_increase"),
39+
ifelse(forest$type %in% c("classification", "unsupervised"), "accuracy_decrease", "mse_increase"),
4040
"measure cannot be extracted.",
4141
"To add it regrow the forest with the option localImp = TRUE and run this function again."))
4242
}

R/measure_importance.R

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ measure_no_of_nodes_ranger <- function(forest_table){
3131
# Extract randomForest variable importance measures
3232
# randomForest
3333
measure_vimp <- function(forest, only_nonlocal = FALSE){
34-
if(forest$type == "classification"){
34+
if(forest$type %in% c("classification", "unsupervised")){
3535
if(dim(forest$importance)[2] == 1){
3636
if(only_nonlocal == FALSE){
3737
print("Warning: your forest does not contain information on local importance so 'accuracy_decrease' measure cannot be extracted. To add it regrow the forest with the option localImp = TRUE and run this function again.")
@@ -129,7 +129,7 @@ measure_importance <- function(forest, mean_sample = "top_trees", measures = NUL
129129
measure_importance.randomForest <- function(forest, mean_sample = "top_trees", measures = NULL){
130130
tree <- NULL; `split var` <- NULL; depth <- NULL
131131
if(is.null(measures)){
132-
if(forest$type == "classification"){
132+
if(forest$type %in% c("classification", "unsupervised")){
133133
measures <- c("mean_min_depth", "no_of_nodes", "accuracy_decrease",
134134
"gini_decrease", "no_of_trees", "times_a_root", "p_value")
135135
} else if(forest$type =="regression"){
@@ -143,6 +143,9 @@ measure_importance.randomForest <- function(forest, mean_sample = "top_trees", m
143143
importance_frame <- data.frame(variable = rownames(forest$importance), stringsAsFactors = FALSE)
144144
# Get objects necessary to calculate importance measures based on the tree structure
145145
if(any(c("mean_min_depth", "no_of_nodes", "no_of_trees", "times_a_root", "p_value") %in% measures)){
146+
if (is.null(forest$forest)) {
147+
stop("Make sure forest has been saved when calling randomForest by randomForest(..., keep.forest = TRUE).")
148+
}
146149
forest_table <-
147150
lapply(1:forest$ntree, function(i) randomForest::getTree(forest, k = i, labelVar = T) %>%
148151
calculate_tree_depth() %>% cbind(tree = i)) %>% rbindlist()
@@ -159,7 +162,7 @@ measure_importance.randomForest <- function(forest, mean_sample = "top_trees", m
159162
importance_frame <- merge(importance_frame, measure_no_of_nodes(forest_table), all = TRUE)
160163
importance_frame[is.na(importance_frame$no_of_nodes), "no_of_nodes"] <- 0
161164
}
162-
if(forest$type == "classification"){
165+
if(forest$type %in% c("classification", "unsupervised")){
163166
vimp <- c("accuracy_decrease", "gini_decrease")
164167
} else if(forest$type =="regression"){
165168
vimp <- c("mse_increase", "node_purity_increase")

R/min_depth_interactions.R

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ conditional_depth_ranger <- function(frame, vars){
3939
df <- frame[begin:nrow(frame), setdiff(names(frame), setdiff(vars, j))]
4040
df[[j]][1] <- 0
4141
for(k in 2:nrow(df)){
42-
if(length(df[(!is.na(df[, "leftChild"]) & df[, "leftChild"] == as.numeric(df[k, "number"])) |
43-
(!is.na(df[, "rightChild"]) & df[, "rightChild"] == as.numeric(df[k, "number"])), j]) != 0){
42+
if(length(df[(!is.na(df[, "leftChild"]) & df[, "leftChild"] == as.numeric(df[k, "number"]) - 1) |
43+
(!is.na(df[, "rightChild"]) & df[, "rightChild"] == as.numeric(df[k, "number"]) - 1), j]) != 0){
4444
df[k, j] <-
45-
df[(!is.na(df[, "leftChild"]) & df[, "leftChild"] == as.numeric(df[k, "number"])) |
46-
(!is.na(df[, "rightChild"]) & df[, "rightChild"] == as.numeric(df[k, "number"])), j] + 1
45+
df[(!is.na(df[, "leftChild"]) & df[, "leftChild"] == as.numeric(df[k, "number"]) - 1) |
46+
(!is.na(df[, "rightChild"]) & df[, "rightChild"] == as.numeric(df[k, "number"]) - 1), j] + 1
4747
}
4848
}
4949
frame[begin:nrow(frame), setdiff(names(frame), setdiff(vars, j))] <- df
@@ -68,7 +68,7 @@ min_depth_interactions_values <- function(forest, vars){
6868
mean_tree_depth <- dplyr::group_by(interactions_frame[, c("tree", vars)], tree) %>%
6969
dplyr::summarize_at(vars, funs(max(., na.rm = TRUE))) %>% as.data.frame()
7070
mean_tree_depth[mean_tree_depth == -Inf] <- NA
71-
mean_tree_depth <- colMeans(mean_tree_depth[, vars], na.rm = TRUE)
71+
mean_tree_depth <- colMeans(mean_tree_depth[, vars, drop = FALSE], na.rm = TRUE)
7272
min_depth_interactions_frame <-
7373
interactions_frame %>% dplyr::group_by(tree, `split var`) %>%
7474
dplyr::summarize_at(vars, funs(min(., na.rm = TRUE))) %>% as.data.frame()
@@ -93,7 +93,7 @@ min_depth_interactions_values_ranger <- function(forest, vars){
9393
mean_tree_depth <- dplyr::group_by(interactions_frame[, c("tree", vars)], tree) %>%
9494
dplyr::summarize_at(vars, funs(max(., na.rm = TRUE))) %>% as.data.frame()
9595
mean_tree_depth[mean_tree_depth == -Inf] <- NA
96-
mean_tree_depth <- colMeans(mean_tree_depth[, vars], na.rm = TRUE)
96+
mean_tree_depth <- colMeans(mean_tree_depth[, vars, drop = FALSE], na.rm = TRUE)
9797
min_depth_interactions_frame <-
9898
interactions_frame %>% dplyr::group_by(tree, splitvarName) %>%
9999
dplyr::summarize_at(vars, funs(min(., na.rm = TRUE))) %>% as.data.frame()
@@ -146,15 +146,15 @@ min_depth_interactions.randomForest <- function(forest, vars = important_variabl
146146
non_occurrences[, -1] <- forest$ntree - occurrences[, -1]
147147
interactions_frame[is.na(as.matrix(interactions_frame))] <- 0
148148
interactions_frame[, -1] <- (interactions_frame[, -1] * occurrences[, -1] +
149-
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth))/forest$ntree
149+
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth, nrow = length(mean_tree_depth)))/forest$ntree
150150
} else if(mean_sample == "top_trees"){
151151
non_occurrences <- occurrences
152152
non_occurrences[, -1] <- forest$ntree - occurrences[, -1]
153153
minimum_non_occurrences <- min(non_occurrences[, -1])
154154
non_occurrences[, -1] <- non_occurrences[, -1] - minimum_non_occurrences
155155
interactions_frame[is.na(as.matrix(interactions_frame))] <- 0
156156
interactions_frame[, -1] <- (interactions_frame[, -1] * occurrences[, -1] +
157-
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth))/(forest$ntree - minimum_non_occurrences)
157+
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth, nrow = length(mean_tree_depth)))/(forest$ntree - minimum_non_occurrences)
158158
}
159159
interactions_frame <- reshape2::melt(interactions_frame, id.vars = "variable")
160160
colnames(interactions_frame)[2:3] <- c("root_variable", "mean_min_depth")
@@ -195,15 +195,15 @@ min_depth_interactions.ranger <- function(forest, vars = important_variables(mea
195195
non_occurrences[, -1] <- forest$num.trees - occurrences[, -1]
196196
interactions_frame[is.na(as.matrix(interactions_frame))] <- 0
197197
interactions_frame[, -1] <- (interactions_frame[, -1] * occurrences[, -1] +
198-
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth))/forest$num.trees
198+
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth, nrow = length(mean_tree_depth)))/forest$num.trees
199199
} else if(mean_sample == "top_trees"){
200200
non_occurrences <- occurrences
201201
non_occurrences[, -1] <- forest$num.trees - occurrences[, -1]
202202
minimum_non_occurrences <- min(non_occurrences[, -1])
203203
non_occurrences[, -1] <- non_occurrences[, -1] - minimum_non_occurrences
204204
interactions_frame[is.na(as.matrix(interactions_frame))] <- 0
205205
interactions_frame[, -1] <- (interactions_frame[, -1] * occurrences[, -1] +
206-
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth))/(forest$num.trees - minimum_non_occurrences)
206+
as.matrix(non_occurrences[, -1]) %*% diag(mean_tree_depth, nrow = length(mean_tree_depth)))/(forest$num.trees - minimum_non_occurrences)
207207
}
208208
interactions_frame <- reshape2::melt(interactions_frame, id.vars = "variable")
209209
colnames(interactions_frame)[2:3] <- c("root_variable", "mean_min_depth")
@@ -303,6 +303,10 @@ plot_predict_interaction.randomForest <- function(forest, data, variable1, varia
303303
main = paste0("Prediction of the forest for different values of ",
304304
paste0(variable1, paste0(" and ", variable2))),
305305
time = NULL){
306+
if (forest$type == "unsupervised") {
307+
warning("plot_predict_interaction cannot be performed on unsupervised random forests.")
308+
return(NULL)
309+
}
306310
newdata <- expand.grid(seq(min(data[[variable1]]), max(data[[variable1]]), length.out = grid),
307311
seq(min(data[[variable2]]), max(data[[variable2]]), length.out = grid))
308312
colnames(newdata) <- c(variable1, variable2)

inst/testdata/test_randomForest.rda

9.21 KB
Binary file not shown.

inst/testdata/test_ranger.rda

2.45 KB
Binary file not shown.

tests/testthat.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
library(testthat)
2+
library(randomForestExplainer)
3+
4+
test_check("randomForestExplainer")

tests/testthat/test_randomForest.R

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
library(randomForest)
2+
library(dplyr)
3+
load(system.file("testdata/test_randomForest.rda", package="randomForestExplainer", mustWork=TRUE))
4+
# Test input generated by:
5+
# library(randomForest)
6+
# set.seed(12345)
7+
# rf_c <- randomForest(Species ~ ., data = iris, localImp = TRUE, ntree = 2)
8+
# rf_r <- randomForest(mpg ~ ., data = mtcars, localImp = TRUE, ntree = 2)
9+
# rf_u <- randomForest(x = iris, keep.forest = TRUE, localImp = TRUE, ntree = 2)
10+
# save(rf_c, rf_r, rf_u, file = "inst/testdata/test_randomForest.rda")
11+
12+
13+
context("Test randomForest classification forests")
14+
15+
test_that("measure_importance works", {
16+
imp_df <- measure_importance(rf_c, mean_sample = "all_trees",
17+
measures = c("mean_min_depth", "accuracy_decrease", "gini_decrease",
18+
"no_of_nodes", "times_a_root", "p_value"))
19+
expect_equal(as.character(imp_df$variable),
20+
c("Petal.Length", "Petal.Width", "Sepal.Length", "Sepal.Width"))
21+
})
22+
23+
test_that("important_variables works", {
24+
imp_vars <- important_variables(rf_c, k = 3,
25+
measures = c("mean_min_depth", "accuracy_decrease", "gini_decrease",
26+
"no_of_nodes", "times_a_root", "p_value"))
27+
expect_equal(imp_vars, c("Petal.Width", "Petal.Length", "Sepal.Length"))
28+
})
29+
30+
test_that("min_depth_distribution works", {
31+
min_depth_dist <- min_depth_distribution(rf_c)
32+
expect_equivalent(min_depth_dist[min_depth_dist$tree == 1 & min_depth_dist$variable == "Petal.Width", ]$minimal_depth,
33+
0)
34+
})
35+
36+
test_that("min_depth_interactions works", {
37+
min_depth_int <- min_depth_interactions(rf_c, vars = c("Petal.Width"))
38+
expect_equivalent(min_depth_int[min_depth_int$interaction == "Petal.Width:Sepal.Length", ]$mean_min_depth,
39+
0)
40+
})
41+
42+
43+
context("Test randomForest regression forests")
44+
45+
test_that("measure_importance works", {
46+
imp_df <- measure_importance(rf_r, mean_sample = "all_trees",
47+
measures = c("mean_min_depth", "mse_increase", "node_purity_increase",
48+
"no_of_nodes", "times_a_root", "p_value"))
49+
expect_equal(as.character(imp_df$variable),
50+
c("am", "carb", "cyl", "disp", "drat", "gear", "hp", "qsec", "vs", "wt"))
51+
})
52+
53+
test_that("important_variables works", {
54+
imp_vars <- important_variables(rf_r, k = 3,
55+
measures = c("mean_min_depth", "mse_increase", "node_purity_increase",
56+
"no_of_nodes", "times_a_root", "p_value"))
57+
expect_equal(imp_vars, c("cyl", "disp", "hp", "wt"))
58+
})
59+
60+
test_that("min_depth_distribution works", {
61+
min_depth_dist <- min_depth_distribution(rf_r)
62+
expect_equivalent(min_depth_dist[min_depth_dist$tree == 1 & min_depth_dist$variable == "cyl", ]$minimal_depth,
63+
0)
64+
})
65+
66+
test_that("min_depth_interactions works", {
67+
min_depth_int <- min_depth_interactions(rf_r, vars = c("cyl"))
68+
expect_equivalent(min_depth_int[min_depth_int$interaction == "cyl:wt", ]$mean_min_depth,
69+
1)
70+
})
71+
72+
73+
context("Test randomForest unsupervised forests")
74+
75+
test_that("measure_importance works", {
76+
imp_df <- measure_importance(rf_u, mean_sample = "all_trees",
77+
measures = c("mean_min_depth", "accuracy_decrease", "gini_decrease",
78+
"no_of_nodes", "times_a_root", "p_value"))
79+
expect_equal(as.character(imp_df$variable),
80+
c("Petal.Length", "Petal.Width", "Sepal.Length", "Sepal.Width", "Species"))
81+
})
82+
83+
test_that("important_variables works", {
84+
imp_vars <- important_variables(rf_u, k = 3,
85+
measures = c("mean_min_depth", "accuracy_decrease", "gini_decrease",
86+
"no_of_nodes", "times_a_root", "p_value"))
87+
expect_equal(imp_vars, c("Petal.Length", "Sepal.Length", "Species"))
88+
})
89+
90+
test_that("min_depth_distribution works", {
91+
min_depth_dist <- min_depth_distribution(rf_u)
92+
expect_equivalent(min_depth_dist[min_depth_dist$tree == 1 & min_depth_dist$variable == "Sepal.Width", ]$minimal_depth,
93+
0)
94+
})
95+
96+
test_that("min_depth_interactions works", {
97+
min_depth_int <- min_depth_interactions(rf_u, vars = c("Petal.Width"))
98+
expect_equivalent(min_depth_int[min_depth_int$interaction == "Petal.Width:Sepal.Length", ]$mean_min_depth,
99+
1)
100+
})

0 commit comments

Comments
 (0)