Skip to content

Commit 57594f9

Browse files
committed
first pass ranfomForest tests
1 parent bed77b6 commit 57594f9

File tree

3 files changed

+86
-17
lines changed

3 files changed

+86
-17
lines changed

R/measure_importance.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ measure_importance.randomForest <- function(forest, mean_sample = "top_trees", m
143143
importance_frame <- data.frame(variable = rownames(forest$importance), stringsAsFactors = FALSE)
144144
# Get objects necessary to calculate importance measures based on the tree structure
145145
if(any(c("mean_min_depth", "no_of_nodes", "no_of_trees", "times_a_root", "p_value") %in% measures)){
146+
if (is.null(forest$forest)) {
147+
stop("Make sure forest has been saved when calling randomForest by randomForest(..., keep.forest = TRUE).")
148+
}
146149
forest_table <-
147150
lapply(1:forest$ntree, function(i) randomForest::getTree(forest, k = i, labelVar = T) %>%
148151
calculate_tree_depth() %>% cbind(tree = i)) %>% rbindlist()
9.21 KB
Binary file not shown.

tests/testthat/test_randomForest.R

Lines changed: 83 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,100 @@
11
library(randomForest)
22
library(dplyr)
3-
set.seed(12345)
3+
load(system.file("testdata/test_randomForest.rda", package="randomForestExplainer", mustWork=TRUE))
4+
# Test input generated by:
5+
# library(randomForest)
6+
# set.seed(12345)
7+
# rf_c <- randomForest(Species ~ ., data = iris, localImp = TRUE, ntree = 2)
8+
# rf_r <- randomForest(mpg ~ ., data = mtcars, localImp = TRUE, ntree = 2)
9+
# rf_u <- randomForest(x = iris, keep.forest = TRUE, localImp = TRUE, ntree = 2)
10+
# save(rf_c, rf_r, rf_u, file = "inst/testdata/test_randomForest.rda")
11+
412

513
context("Test randomForest classification forests")
6-
forest <- randomForest(Species ~ ., data = iris, localImp = TRUE, ntree = 2)
714

815
test_that("measure_importance works", {
9-
imp_df <- measure_importance(forest, mean_sample = "all_trees",
10-
measures = c("mean_min_depth","accuracy_decrease",
11-
"gini_decrease", "no_of_nodes", "times_a_root"))
12-
expect_equal(imp_df$variable, c("Petal.Length", "Petal.Width", "Sepal.Length", "Sepal.Width"))
16+
imp_df <- measure_importance(rf_c, mean_sample = "all_trees",
17+
measures = c("mean_min_depth", "accuracy_decrease", "gini_decrease",
18+
"no_of_nodes", "times_a_root", "p_value"))
19+
expect_equal(as.character(imp_df$variable),
20+
c("Petal.Length", "Petal.Width", "Sepal.Length", "Sepal.Width"))
1321
})
1422

1523
test_that("important_variables works", {
16-
imp_vars <- important_variables(forest, k = 3,
17-
measures = c("mean_min_depth","accuracy_decrease",
18-
"gini_decrease", "no_of_nodes", "times_a_root"))
24+
imp_vars <- important_variables(rf_c, k = 3,
25+
measures = c("mean_min_depth", "accuracy_decrease", "gini_decrease",
26+
"no_of_nodes", "times_a_root", "p_value"))
1927
expect_equal(imp_vars, c("Petal.Width", "Petal.Length", "Sepal.Length"))
2028
})
2129

2230
test_that("min_depth_distribution works", {
23-
min_depth_dist <- min_depth_distribution(forest)
24-
print(min_depth_dist)
25-
expect_equivalent(arrange(min_depth_dist, tree, minimal_depth, variable),
26-
data.frame("tree" = c(1, 1, 1, 1, 2, 2, 2),
27-
"variable"=c("Petal.Width", "Sepal.Length", "Petal.Length", "Sepal.Width", "Petal.Width", "Sepal.Length", "Petal.Length"),
28-
"minimal_depth"=c(0, 1, 2, 4, 0, 1, 3)))
31+
min_depth_dist <- min_depth_distribution(rf_c)
32+
expect_equivalent(min_depth_dist[min_depth_dist$tree == 1 & min_depth_dist$variable == "Petal.Width", ]$minimal_depth,
33+
0)
34+
})
35+
36+
test_that("min_depth_interactions works", {
37+
min_depth_int <- min_depth_interactions(rf_c, vars = c("Petal.Width"))
38+
expect_equivalent(min_depth_int[min_depth_int$interaction == "Petal.Width:Sepal.Length", ]$mean_min_depth,
39+
0)
40+
})
41+
42+
43+
context("Test randomForest regression forests")
44+
45+
test_that("measure_importance works", {
46+
imp_df <- measure_importance(rf_r, mean_sample = "all_trees",
47+
measures = c("mean_min_depth", "mse_increase", "node_purity_increase",
48+
"no_of_nodes", "times_a_root", "p_value"))
49+
expect_equal(as.character(imp_df$variable),
50+
c("am", "carb", "cyl", "disp", "drat", "gear", "hp", "qsec", "vs", "wt"))
51+
})
52+
53+
test_that("important_variables works", {
54+
imp_vars <- important_variables(rf_r, k = 3,
55+
measures = c("mean_min_depth", "mse_increase", "node_purity_increase",
56+
"no_of_nodes", "times_a_root", "p_value"))
57+
expect_equal(imp_vars, c("cyl", "disp", "hp", "wt"))
58+
})
59+
60+
test_that("min_depth_distribution works", {
61+
min_depth_dist <- min_depth_distribution(rf_r)
62+
expect_equivalent(min_depth_dist[min_depth_dist$tree == 1 & min_depth_dist$variable == "cyl", ]$minimal_depth,
63+
0)
64+
})
65+
66+
test_that("min_depth_interactions works", {
67+
min_depth_int <- min_depth_interactions(rf_r, vars = c("cyl"))
68+
expect_equivalent(min_depth_int[min_depth_int$interaction == "cyl:wt", ]$mean_min_depth,
69+
1)
70+
})
71+
72+
73+
context("Test randomForest unsupervised forests")
74+
75+
test_that("measure_importance works", {
76+
imp_df <- measure_importance(rf_u, mean_sample = "all_trees",
77+
measures = c("mean_min_depth", "accuracy_decrease", "gini_decrease",
78+
"no_of_nodes", "times_a_root", "p_value"))
79+
expect_equal(as.character(imp_df$variable),
80+
c("Petal.Length", "Petal.Width", "Sepal.Length", "Sepal.Width", "Species"))
81+
})
82+
83+
test_that("important_variables works", {
84+
imp_vars <- important_variables(rf_u, k = 3,
85+
measures = c("mean_min_depth", "accuracy_decrease", "gini_decrease",
86+
"no_of_nodes", "times_a_root", "p_value"))
87+
expect_equal(imp_vars, c("Petal.Length", "Sepal.Length", "Species"))
88+
})
89+
90+
test_that("min_depth_distribution works", {
91+
min_depth_dist <- min_depth_distribution(rf_u)
92+
expect_equivalent(min_depth_dist[min_depth_dist$tree == 1 & min_depth_dist$variable == "Sepal.Width", ]$minimal_depth,
93+
0)
2994
})
3095

3196
test_that("min_depth_interactions works", {
32-
min_depth_int <- min_depth_interactions(forest, vars = c("Petal.Width"))
33-
expect_equal(as.character(min_depth_int$variable), c("Petal.Length", "Petal.Width", "Sepal.Length", "Sepal.Width"))
97+
min_depth_int <- min_depth_interactions(rf_u, vars = c("Petal.Width"))
98+
expect_equivalent(min_depth_int[min_depth_int$interaction == "Petal.Width:Sepal.Length", ]$mean_min_depth,
99+
1)
34100
})

0 commit comments

Comments
 (0)