@@ -58,10 +58,10 @@ Rcpp::List tree_search_rcpp(const Rcpp::NumericMatrix& X,
58
58
// We store the tree as the same list data structure (`nodes`) as GRF for seamless integration with
59
59
// the plot and print methods. We also store the tree as an array (`tree_array`) for faster lookups.
60
60
// This will make a difference for a very large amount of lookups, like n = 1 000 000.
61
- // The columns 0 to 4 are:
62
- // split_variable (-1 if leaf) | split_value (action_id if leaf) | left_child | right_child | optional node_id
61
+ // The columns 0 to 3 are:
62
+ // split_variable (-1 if leaf) | split_value (action_id if leaf) | left_child | right_child
63
63
int num_nodes = pow (2.0 , depth + 1.0 ) - 1 ;
64
- Rcpp::NumericMatrix tree_array (num_nodes, 5 );
64
+ Rcpp::NumericMatrix tree_array (num_nodes, 4 );
65
65
Rcpp::List nodes;
66
66
int i = 1 ;
67
67
int j = 0 ;
@@ -76,7 +76,6 @@ Rcpp::List tree_search_rcpp(const Rcpp::NumericMatrix& X,
76
76
nodes.push_back (list_node);
77
77
tree_array (j, 0 ) = -1 ;
78
78
tree_array (j, 1 ) = node->action_id + 1 ;
79
- tree_array (j, 4 ) = j; // only used by hybrid policytree to predict leaf number that matches up with the printed leaf number.
80
79
} else {
81
80
auto list_node = Rcpp::List::create (Rcpp::Named (" is_leaf" ) = false ,
82
81
Rcpp::Named (" split_variable" ) = node->index + 1 , // C++ index
@@ -88,7 +87,6 @@ Rcpp::List tree_search_rcpp(const Rcpp::NumericMatrix& X,
88
87
tree_array (j, 1 ) = node->value ;
89
88
tree_array (j, 2 ) = i + 1 ; // left child
90
89
tree_array (j, 3 ) = i + 2 ; // right child
91
- tree_array (j, 4 ) = j;
92
90
frontier.push (std::move (node->left_child ));
93
91
frontier.push (std::move (node->right_child ));
94
92
i += 2 ;
@@ -122,9 +120,8 @@ Rcpp::NumericMatrix tree_search_rcpp_predict(const Rcpp::NumericMatrix& tree_arr
122
120
bool is_leaf = tree_array (node, 0 ) == -1 ;
123
121
if (is_leaf) {
124
122
size_t action = tree_array (node, 1 );
125
- size_t node_id = tree_array (node, 4 );
126
123
result (sample, 0 ) = action;
127
- result (sample, 1 ) = node_id ;
124
+ result (sample, 1 ) = node ;
128
125
break ;
129
126
}
130
127
size_t split_var = tree_array (node, 0 ) - 1 ; // Offset by 1 for C++ indexing
0 commit comments