1
1
test_that(" hybrid_policy_tree works as expected" , {
2
2
n <- 500
3
3
p <- 2
4
- d <- 2
4
+ d <- 42
5
5
X <- round(matrix (rnorm(n * p ), n , p ), 1 )
6
6
Y <- matrix (runif(n * d ), n , d )
7
7
@@ -25,6 +25,17 @@ test_that("hybrid_policy_tree works as expected", {
25
25
}
26
26
}
27
27
expect_equal(hpp , pp )
28
+ # is internally consistent with predicted node ids.
29
+ hpp.node <- predict(htree , X , type = " node.id" )
30
+ values <- aggregate(Y , by = list (leaf.node = hpp.node ), FUN = mean )
31
+ best <- apply(values [, - 1 ], 1 , FUN = which.max )
32
+ expect_equal(best [match(hpp.node , values [, 1 ])], hpp )
33
+
34
+ # uses node labels that are the same as the printed labels.
35
+ printed.node.id <- lapply(seq_along(htree $ nodes ), function (i ) {
36
+ if (htree $ nodes [[i ]]$ is_leaf ) i
37
+ })
38
+ expect_true(all(unique(hpp.node ) %in% printed.node.id ))
28
39
29
40
# search.depth = 1 when a single split is optimal is identical to a depth 1 policy_tree
30
41
n <- 250
@@ -102,7 +113,7 @@ test_that("hybrid_policy_tree utils are internally consistent", {
102
113
103
114
depth <- 2
104
115
tree <- policy_tree(X , Y , depth = depth )
105
- expect_equal(tree_mat (tree [[" nodes" ]], depth ), tree [[" _tree_array" ]], tolerance = 1e-16 )
116
+ expect_equal(convert_nodes (tree [[" nodes" ]], depth )[[ 2 ]] , tree [[" _tree_array" ]], tolerance = 1e-16 )
106
117
tree.nodes <- lapply(seq_along(tree $ nodes ), function (i ) {
107
118
node <- tree $ nodes [[i ]]
108
119
node $ has_subtree <- FALSE
@@ -112,7 +123,7 @@ test_that("hybrid_policy_tree utils are internally consistent", {
112
123
113
124
depth <- 3
114
125
tree <- policy_tree(X , Y , depth = depth )
115
- expect_equal(tree_mat (tree [[" nodes" ]], depth ), tree [[" _tree_array" ]], tolerance = 1e-16 )
126
+ expect_equal(convert_nodes (tree [[" nodes" ]], depth )[[ 2 ]] , tree [[" _tree_array" ]], tolerance = 1e-16 )
116
127
tree.nodes <- lapply(seq_along(tree $ nodes ), function (i ) {
117
128
node <- tree $ nodes [[i ]]
118
129
node $ has_subtree <- FALSE
0 commit comments