@@ -58,10 +58,10 @@ Rcpp::List tree_search_rcpp(const Rcpp::NumericMatrix& X,
58
58
// We store the tree as the same list data structure (`nodes`) as GRF for seamless integration with
59
59
// the plot and print methods. We also store the tree as an array (`tree_array`) for faster lookups.
60
60
// This will make a difference for a very large amount of lookups, like n = 1 000 000.
61
- // The columns 0 to 3 are:
62
- // split_variable (-1 if leaf) | split_value (action_id if leaf) | left_child | right_child
61
+ // The columns 0 to 4 are:
62
+ // split_variable (-1 if leaf) | split_value (action_id if leaf) | left_child | right_child | optional node_id
63
63
int num_nodes = pow (2.0 , depth + 1.0 ) - 1 ;
64
- Rcpp::NumericMatrix tree_array (num_nodes, 4 );
64
+ Rcpp::NumericMatrix tree_array (num_nodes, 5 );
65
65
Rcpp::List nodes;
66
66
int i = 1 ;
67
67
int j = 0 ;
@@ -76,6 +76,7 @@ Rcpp::List tree_search_rcpp(const Rcpp::NumericMatrix& X,
76
76
nodes.push_back (list_node);
77
77
tree_array (j, 0 ) = -1 ;
78
78
tree_array (j, 1 ) = node->action_id + 1 ;
79
+ tree_array (j, 4 ) = j; // only used by hybrid policytree to predict leaf number that matches up with the printed leaf number.
79
80
} else {
80
81
auto list_node = Rcpp::List::create (Rcpp::Named (" is_leaf" ) = false ,
81
82
Rcpp::Named (" split_variable" ) = node->index + 1 , // C++ index
@@ -87,6 +88,7 @@ Rcpp::List tree_search_rcpp(const Rcpp::NumericMatrix& X,
87
88
tree_array (j, 1 ) = node->value ;
88
89
tree_array (j, 2 ) = i + 1 ; // left child
89
90
tree_array (j, 3 ) = i + 2 ; // right child
91
+ tree_array (j, 4 ) = j;
90
92
frontier.push (std::move (node->left_child ));
91
93
frontier.push (std::move (node->right_child ));
92
94
i += 2 ;
@@ -120,8 +122,9 @@ Rcpp::NumericMatrix tree_search_rcpp_predict(const Rcpp::NumericMatrix& tree_arr
120
122
bool is_leaf = tree_array (node, 0 ) == -1 ;
121
123
if (is_leaf) {
122
124
size_t action = tree_array (node, 1 );
125
+ size_t node_id = tree_array (node, 4 );
123
126
result (sample, 0 ) = action;
124
- result (sample, 1 ) = node ;
127
+ result (sample, 1 ) = node_id ;
125
128
break ;
126
129
}
127
130
size_t split_var = tree_array (node, 0 ) - 1 ; // Offset by 1 for C++ indexing
0 commit comments