Skip to content

Commit 13388c5

Browse files
authored
Predict same leaf labels as printed labels with hybrid policytree (#153)
1 parent d9d80a5 commit 13388c5

File tree

3 files changed

+22
-6
lines changed

3 files changed

+22
-6
lines changed

r-package/policytree/R/hybrid_policy_tree.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ unpack_tree <- function(tree) {
183183
# For "largish" depth it could also be sparse.
184184
tree_mat <- function(nodes, depth) {
185185
num.nodes <- 2^(depth + 1) - 1
186-
tree.array <- matrix(0, num.nodes, 4)
186+
tree.array <- matrix(0, num.nodes, 5)
187187
frontier <- 1
188188
i <- 1
189189
j <- 1
@@ -193,11 +193,13 @@ tree_mat <- function(nodes, depth) {
193193
if (nodes[[node]]$is_leaf) {
194194
tree.array[j, 1] <- -1
195195
tree.array[j, 2] <- nodes[[node]]$action
196+
tree.array[j, 5] <- node - 1
196197
} else {
197198
tree.array[j, 1] <- nodes[[node]]$split_variable
198199
tree.array[j, 2] <- nodes[[node]]$split_value
199200
tree.array[j, 3] <- i + 1
200201
tree.array[j, 4] <- i + 2
202+
tree.array[j, 5] <- node - 1
201203
frontier <- c(frontier, nodes[[node]]$left_child, nodes[[node]]$right_child)
202204
i <- i + 2
203205
}

r-package/policytree/src/Rcppbindings.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ Rcpp::List tree_search_rcpp(const Rcpp::NumericMatrix& X,
5858
// We store the tree as the same list data structure (`nodes`) as GRF for seamless integration with
5959
// the plot and print methods. We also store the tree as an array (`tree_array`) for faster lookups.
6060
// This will make a difference for a very large amount of lookups, like n = 1 000 000.
61-
// The columns 0 to 3 are:
62-
// split_variable (-1 if leaf) | split_value (action_id if leaf) | left_child | right_child
61+
// The columns 0 to 4 are:
62+
// split_variable (-1 if leaf) | split_value (action_id if leaf) | left_child | right_child | optional node_id
6363
int num_nodes = pow(2.0, depth + 1.0) - 1;
64-
Rcpp::NumericMatrix tree_array(num_nodes, 4);
64+
Rcpp::NumericMatrix tree_array(num_nodes, 5);
6565
Rcpp::List nodes;
6666
int i = 1;
6767
int j = 0;
@@ -76,6 +76,7 @@ Rcpp::List tree_search_rcpp(const Rcpp::NumericMatrix& X,
7676
nodes.push_back(list_node);
7777
tree_array(j, 0) = -1;
7878
tree_array(j, 1) = node->action_id + 1;
79+
tree_array(j, 4) = j; // only used by hybrid policytree to predict leaf number that matches up with the printed leaf number.
7980
} else {
8081
auto list_node = Rcpp::List::create(Rcpp::Named("is_leaf") = false,
8182
Rcpp::Named("split_variable") = node->index + 1, // C++ index
@@ -87,6 +88,7 @@ Rcpp::List tree_search_rcpp(const Rcpp::NumericMatrix& X,
8788
tree_array(j, 1) = node->value;
8889
tree_array(j, 2) = i + 1; // left child
8990
tree_array(j, 3) = i + 2; // right child
91+
tree_array(j, 4) = j;
9092
frontier.push(std::move(node->left_child));
9193
frontier.push(std::move(node->right_child));
9294
i += 2;
@@ -120,8 +122,9 @@ Rcpp::NumericMatrix tree_search_rcpp_predict(const Rcpp::NumericMatrix& tree_arr
120122
bool is_leaf = tree_array(node, 0) == -1;
121123
if (is_leaf) {
122124
size_t action = tree_array(node, 1);
125+
size_t node_id = tree_array(node, 4);
123126
result(sample, 0) = action;
124-
result(sample, 1) = node;
127+
result(sample, 1) = node_id;
125128
break;
126129
}
127130
size_t split_var = tree_array(node, 0) - 1; // Offset by 1 for C++ indexing

r-package/policytree/tests/testthat/test_hybrid_policy_tree.R

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
test_that("hybrid_policy_tree works as expected", {
22
n <- 500
33
p <- 2
4-
d <- 2
4+
d <- 42
55
X <- round(matrix(rnorm(n * p), n, p), 1)
66
Y <- matrix(runif(n * d), n, d)
77

@@ -25,6 +25,17 @@ test_that("hybrid_policy_tree works as expected", {
2525
}
2626
}
2727
expect_equal(hpp, pp)
28+
# is internally consistent with predicted node ids.
29+
hpp.node <- predict(htree, X, type = "node.id")
30+
values <- aggregate(Y, by = list(leaf.node = hpp.node), FUN = mean)
31+
best <- apply(values[, -1], 1, FUN = which.max)
32+
expect_equal(best[match(hpp.node, values[, 1])], hpp)
33+
34+
# uses node labels that are the same as the printed labels.
35+
printed.node.id <- lapply(seq_along(htree$nodes), function(i) {
36+
if (htree$nodes[[i]]$is_leaf) i
37+
})
38+
expect_true(all(unique(hpp.node) %in% printed.node.id))
2839

2940
# search.depth = 1 when a single split is optimal is identical to a depth 1 policy_tree
3041
n <- 250

0 commit comments

Comments
 (0)