Skip to content

Commit 45d2688

Browse files
authored
Update printed hybrid tree node labels (#156)
Update node labels
1 parent 841fdf8 commit 45d2688

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

r-package/policytree/R/hybrid_policy_tree.R

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,10 @@ hybrid_policy_tree <- function(X, Gamma,
120120
node <- node + 1
121121
}
122122

123-
tree[["nodes"]] <- unpack_tree(tree.nodes)
124-
tree[["_tree_array"]] <- tree_mat(tree[["nodes"]], depth)
123+
unpacked.nodes <- unpack_tree(tree.nodes)
124+
converted.nodes <- convert_nodes(unpacked.nodes, depth)
125+
tree[["nodes"]] <- converted.nodes[[1]]
126+
tree[["_tree_array"]] <- converted.nodes[[2]]
125127
tree[["depth"]] <- depth
126128

127129
tree
@@ -177,9 +179,10 @@ unpack_tree <- function(tree) {
177179
nodes
178180
}
179181

180-
# Convert an adjacency list to array for predictions (see Rcppbindigs.cpp for details).
181-
# The 5th column is just the node label according to the print() order a hybrid tree has.
182-
tree_mat <- function(nodes, depth) {
182+
# 1) Convert tree to array for predictions (see Rcppbindings.cpp for details)
183+
# and 2) use the same breadth-first node ordering (new.nodes) as the rest of policytree.
184+
convert_nodes <- function(nodes, depth) {
185+
new.nodes <- list()
183186
num.nodes <- 2^(depth + 1) - 1
184187
tree.array <- matrix(0, num.nodes, 4)
185188
frontier <- 1
@@ -191,16 +194,22 @@ tree_mat <- function(nodes, depth) {
191194
if (nodes[[node]]$is_leaf) {
192195
tree.array[j, 1] <- -1
193196
tree.array[j, 2] <- nodes[[node]]$action
197+
new.nodes[[j]] <- list(is_leaf = TRUE, action = nodes[[node]]$action)
194198
} else {
195199
tree.array[j, 1] <- nodes[[node]]$split_variable
196200
tree.array[j, 2] <- nodes[[node]]$split_value
197201
tree.array[j, 3] <- i + 1
198202
tree.array[j, 4] <- i + 2
203+
new.nodes[[j]] <- list(is_leaf = FALSE,
204+
split_variable = nodes[[node]]$split_variable,
205+
split_value = nodes[[node]]$split_value,
206+
left_child = i + 1,
207+
right_child = i + 2)
199208
frontier <- c(frontier, nodes[[node]]$left_child, nodes[[node]]$right_child)
200209
i <- i + 2
201210
}
202211
j <- j + 1
203212
}
204213

205-
tree.array
214+
list(nodes = new.nodes, tree.array = tree.array)
206215
}

r-package/policytree/tests/testthat/test_hybrid_policy_tree.R

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
test_that("hybrid_policy_tree works as expected", {
22
n <- 500
33
p <- 2
4-
d <- 2
4+
d <- 42
55
X <- round(matrix(rnorm(n * p), n, p), 1)
66
Y <- matrix(runif(n * d), n, d)
77

@@ -25,6 +25,17 @@ test_that("hybrid_policy_tree works as expected", {
2525
}
2626
}
2727
expect_equal(hpp, pp)
28+
# is internally consistent with predicted node ids.
29+
hpp.node <- predict(htree, X, type = "node.id")
30+
values <- aggregate(Y, by = list(leaf.node = hpp.node), FUN = mean)
31+
best <- apply(values[, -1], 1, FUN = which.max)
32+
expect_equal(best[match(hpp.node, values[, 1])], hpp)
33+
34+
# uses node labels that are the same as the printed labels.
35+
printed.node.id <- lapply(seq_along(htree$nodes), function(i) {
36+
if (htree$nodes[[i]]$is_leaf) i
37+
})
38+
expect_true(all(unique(hpp.node) %in% printed.node.id))
2839

2940
# search.depth = 1 when a single split is optimal is identical to a depth 1 policy_tree
3041
n <- 250
@@ -102,7 +113,7 @@ test_that("hybrid_policy_tree utils are internally consistent", {
102113

103114
depth <- 2
104115
tree <- policy_tree(X, Y, depth = depth)
105-
expect_equal(tree_mat(tree[["nodes"]], depth), tree[["_tree_array"]], tolerance = 1e-16)
116+
expect_equal(convert_nodes(tree[["nodes"]], depth)[[2]], tree[["_tree_array"]], tolerance = 1e-16)
106117
tree.nodes <- lapply(seq_along(tree$nodes), function(i) {
107118
node <- tree$nodes[[i]]
108119
node$has_subtree <- FALSE
@@ -112,7 +123,7 @@ test_that("hybrid_policy_tree utils are internally consistent", {
112123

113124
depth <- 3
114125
tree <- policy_tree(X, Y, depth = depth)
115-
expect_equal(tree_mat(tree[["nodes"]], depth), tree[["_tree_array"]], tolerance = 1e-16)
126+
expect_equal(convert_nodes(tree[["nodes"]], depth)[[2]], tree[["_tree_array"]], tolerance = 1e-16)
116127
tree.nodes <- lapply(seq_along(tree$nodes), function(i) {
117128
node <- tree$nodes[[i]]
118129
node$has_subtree <- FALSE

0 commit comments

Comments
 (0)