Skip to content

Commit f5751c4

Browse files
authored
Add heartbeat flicker if verbose = TRUE (#180)
1 parent c7ca614 commit f5751c4

File tree

6 files changed

+42
-15
lines changed

6 files changed

+42
-15
lines changed

r-package/policytree/R/RcppExports.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
22
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
33

4-
tree_search_rcpp <- function(X, Y, depth, split_step, min_node_size) {
5-
.Call('_policytree_tree_search_rcpp', PACKAGE = 'policytree', X, Y, depth, split_step, min_node_size)
4+
tree_search_rcpp <- function(X, Y, depth, split_step, min_node_size, verbose) {
5+
.Call('_policytree_tree_search_rcpp', PACKAGE = 'policytree', X, Y, depth, split_step, min_node_size, verbose)
66
}
77

88
tree_search_rcpp_predict <- function(tree_array, X) {

r-package/policytree/R/policy_tree.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ policy_tree <- function(X, Gamma, depth = 2, split.step = 1, min.node.size = 1,
166166
columns <- make.names(1:ncol(X))
167167
}
168168

169-
result <- tree_search_rcpp(as.matrix(X), as.matrix(Gamma), depth, split.step, min.node.size)
169+
result <- tree_search_rcpp(as.matrix(X), as.matrix(Gamma), depth, split.step, min.node.size, verbose && interactive())
170170
tree <- list(nodes = result[[1]])
171171

172172
tree[["_tree_array"]] <- result[[2]]

r-package/policytree/src/RcppExports.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
1111
#endif
1212

1313
// tree_search_rcpp
14-
Rcpp::List tree_search_rcpp(const Rcpp::NumericMatrix& X, const Rcpp::NumericMatrix& Y, int depth, int split_step, int min_node_size);
15-
RcppExport SEXP _policytree_tree_search_rcpp(SEXP XSEXP, SEXP YSEXP, SEXP depthSEXP, SEXP split_stepSEXP, SEXP min_node_sizeSEXP) {
14+
Rcpp::List tree_search_rcpp(const Rcpp::NumericMatrix& X, const Rcpp::NumericMatrix& Y, int depth, int split_step, int min_node_size, bool verbose);
15+
RcppExport SEXP _policytree_tree_search_rcpp(SEXP XSEXP, SEXP YSEXP, SEXP depthSEXP, SEXP split_stepSEXP, SEXP min_node_sizeSEXP, SEXP verboseSEXP) {
1616
BEGIN_RCPP
1717
Rcpp::RObject rcpp_result_gen;
1818
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -21,7 +21,8 @@ BEGIN_RCPP
2121
Rcpp::traits::input_parameter< int >::type depth(depthSEXP);
2222
Rcpp::traits::input_parameter< int >::type split_step(split_stepSEXP);
2323
Rcpp::traits::input_parameter< int >::type min_node_size(min_node_sizeSEXP);
24-
rcpp_result_gen = Rcpp::wrap(tree_search_rcpp(X, Y, depth, split_step, min_node_size));
24+
Rcpp::traits::input_parameter< bool >::type verbose(verboseSEXP);
25+
rcpp_result_gen = Rcpp::wrap(tree_search_rcpp(X, Y, depth, split_step, min_node_size, verbose));
2526
return rcpp_result_gen;
2627
END_RCPP
2728
}
@@ -39,7 +40,7 @@ END_RCPP
3940
}
4041

4142
static const R_CallMethodDef CallEntries[] = {
42-
{"_policytree_tree_search_rcpp", (DL_FUNC) &_policytree_tree_search_rcpp, 5},
43+
{"_policytree_tree_search_rcpp", (DL_FUNC) &_policytree_tree_search_rcpp, 6},
4344
{"_policytree_tree_search_rcpp_predict", (DL_FUNC) &_policytree_tree_search_rcpp_predict, 2},
4445
{NULL, NULL, 0}
4546
};

r-package/policytree/src/Rcppbindings.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,16 @@ Rcpp::List tree_search_rcpp(const Rcpp::NumericMatrix& X,
3131
const Rcpp::NumericMatrix& Y,
3232
int depth,
3333
int split_step,
34-
int min_node_size) {
34+
int min_node_size,
35+
bool verbose) {
3536
size_t num_rows = X.rows();
3637
size_t num_cols_x = X.cols();
3738
size_t num_cols_y = Y.cols();
3839
const Data* data = new Data(X.begin(), Y.begin(), num_rows, num_cols_x, num_cols_y);
3940
const auto interrupt_handler = []() { Rcpp::checkUserInterrupt(); };
41+
std::ostream* verbose_stream = verbose ? &Rcpp::Rcout : nullptr;
4042

41-
std::unique_ptr<Node> root = tree_search(depth, split_step, min_node_size, data, interrupt_handler);
43+
std::unique_ptr<Node> root = tree_search(depth, split_step, min_node_size, data, interrupt_handler, verbose_stream);
4244

4345
// We store the tree as the same list data structure (`nodes`) as GRF for seamless integration with
4446
// the plot and print methods. We also store the tree as an array (`tree_array`) for faster lookups.

r-package/policytree/src/tree_search.cpp

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,25 @@
77

88
#include "tree_search.h"
99

10+
// Tiny heartbeat "spinner" to show solver is working.
11+
inline void print_heartbeat(std::ostream* verbose_stream) {
12+
if (verbose_stream == nullptr) return;
13+
14+
static const char frames[] = {'|', '/', '-', '\\'};
15+
static size_t frame = 0;
16+
17+
(*verbose_stream) << "\r" << frames[frame & 3] << std::flush;
18+
frame++;
19+
}
20+
21+
// Clears heartbeat when solver exits.
22+
struct ScopedHeartbeat {
23+
std::ostream* stream;
24+
~ScopedHeartbeat() {
25+
if (stream) (*stream) << "\r \r" << std::flush;
26+
}
27+
};
28+
1029
typedef boost::container::flat_set<Point, std::function<bool(const Point&, const Point&)>> flat_set;
1130

1231
/**
@@ -269,7 +288,8 @@ std::unique_ptr<Node> find_best_split(const std::vector<flat_set>& sorted_sets,
269288
size_t min_node_size,
270289
const Data* data,
271290
std::vector<std::vector<double>>& sum_array,
272-
const std::function<void()>& interrupt_handler) {
291+
const std::function<void()>& interrupt_handler,
292+
std::ostream* verbose_stream) {
273293
if (level == 0) {
274294
// this base case will only be hit if `find_best_split` is called directly with level = 0
275295
return level_zero_learning(sorted_sets, data);
@@ -294,6 +314,7 @@ std::unique_ptr<Node> find_best_split(const std::vector<flat_set>& sorted_sets,
294314
for (size_t n = 0; n < num_points - 1; n++) {
295315
if ((n & 1023) == 0) {
296316
interrupt_handler(); // check for ctrl-c interrupt every 1024 iterations
317+
print_heartbeat(verbose_stream);
297318
}
298319
auto point = right_sorted_sets[p].cbegin(); // O(1)
299320
Point point_bk = *point; // store the Point instance since the iterator will be invalid after erase
@@ -321,8 +342,8 @@ std::unique_ptr<Node> find_best_split(const std::vector<flat_set>& sorted_sets,
321342
} else {
322343
continue;
323344
}
324-
auto left_child = find_best_split(left_sorted_sets, level - 1, split_step, min_node_size, data, sum_array, interrupt_handler);
325-
auto right_child = find_best_split(right_sorted_sets, level - 1, split_step, min_node_size, data, sum_array, interrupt_handler);
345+
auto left_child = find_best_split(left_sorted_sets, level - 1, split_step, min_node_size, data, sum_array, interrupt_handler, verbose_stream);
346+
auto right_child = find_best_split(right_sorted_sets, level - 1, split_step, min_node_size, data, sum_array, interrupt_handler, verbose_stream);
326347
if ((best_left_child == nullptr) ||
327348
(left_child->reward + right_child->reward >
328349
best_left_child->reward + best_right_child->reward)) {
@@ -355,7 +376,7 @@ std::unique_ptr<Node> find_best_split(const std::vector<flat_set>& sorted_sets,
355376
}
356377

357378
std::unique_ptr<Node> tree_search(int depth, int split_step, size_t min_node_size, const Data* data,
358-
const std::function<void()>& interrupt_handler) {
379+
const std::function<void()>& interrupt_handler, std::ostream* verbose_stream) {
359380
size_t num_rewards = data->num_rewards();
360381
size_t num_points = data->num_rows;
361382
auto sorted_sets = create_sorted_sets(data);
@@ -367,5 +388,7 @@ std::unique_ptr<Node> tree_search(int depth, int split_step, size_t min_node_siz
367388
v.resize(num_points + 1, 0.0);
368389
}
369390

370-
return find_best_split(sorted_sets, depth, split_step, min_node_size, data, sum_array, interrupt_handler);
391+
ScopedHeartbeat heartbeat {verbose_stream};
392+
393+
return find_best_split(sorted_sets, depth, split_step, min_node_size, data, sum_array, interrupt_handler, verbose_stream);
371394
}

r-package/policytree/src/tree_search.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <limits>
55
#include <memory>
66
#include <stdexcept>
7+
#include <ostream>
78
#include <vector>
89

910
const double INF = std::numeric_limits<double>::infinity();
@@ -91,6 +92,6 @@ struct Node {
9192
};
9293

9394

94-
std::unique_ptr<Node> tree_search(int, int, size_t, const Data*, const std::function<void()>&);
95+
std::unique_ptr<Node> tree_search(int, int, size_t, const Data*, const std::function<void()>&, std::ostream*);
9596

9697
#endif // TREE_SEARCH_H

0 commit comments

Comments
 (0)