77
88#include " tree_search.h"
99
10+ // Tiny heartbeat "spinner" to show solver is working.
11+ inline void print_heartbeat (std::ostream* verbose_stream) {
12+ if (verbose_stream == nullptr ) return ;
13+
14+ static const char frames[] = {' |' , ' /' , ' -' , ' \\ ' };
15+ static size_t frame = 0 ;
16+
17+ (*verbose_stream) << " \r " << frames[frame & 3 ] << std::flush;
18+ frame++;
19+ }
20+
21+ // Clears heartbeat when solver exits.
22+ struct ScopedHeartbeat {
23+ std::ostream* stream;
24+ ~ScopedHeartbeat () {
25+ if (stream) (*stream) << " \r \r " << std::flush;
26+ }
27+ };
28+
1029typedef boost::container::flat_set<Point, std::function<bool (const Point&, const Point&)>> flat_set;
1130
1231/* *
@@ -269,7 +288,8 @@ std::unique_ptr<Node> find_best_split(const std::vector<flat_set>& sorted_sets,
269288 size_t min_node_size,
270289 const Data* data,
271290 std::vector<std::vector<double >>& sum_array,
272- const std::function<void ()>& interrupt_handler) {
291+ const std::function<void ()>& interrupt_handler,
292+ std::ostream* verbose_stream) {
273293 if (level == 0 ) {
274294 // this base case will only be hit if `find_best_split` is called directly with level = 0
275295 return level_zero_learning (sorted_sets, data);
@@ -294,6 +314,7 @@ std::unique_ptr<Node> find_best_split(const std::vector<flat_set>& sorted_sets,
294314 for (size_t n = 0 ; n < num_points - 1 ; n++) {
295315 if ((n & 1023 ) == 0 ) {
296316 interrupt_handler (); // check for ctrl-c interrupt every 1024 iterations
317+ print_heartbeat (verbose_stream);
297318 }
298319 auto point = right_sorted_sets[p].cbegin (); // O(1)
299320 Point point_bk = *point; // store the Point instance since the iterator will be invalid after erase
@@ -321,8 +342,8 @@ std::unique_ptr<Node> find_best_split(const std::vector<flat_set>& sorted_sets,
321342 } else {
322343 continue ;
323344 }
324- auto left_child = find_best_split (left_sorted_sets, level - 1 , split_step, min_node_size, data, sum_array, interrupt_handler);
325- auto right_child = find_best_split (right_sorted_sets, level - 1 , split_step, min_node_size, data, sum_array, interrupt_handler);
345+ auto left_child = find_best_split (left_sorted_sets, level - 1 , split_step, min_node_size, data, sum_array, interrupt_handler, verbose_stream );
346+ auto right_child = find_best_split (right_sorted_sets, level - 1 , split_step, min_node_size, data, sum_array, interrupt_handler, verbose_stream );
326347 if ((best_left_child == nullptr ) ||
327348 (left_child->reward + right_child->reward >
328349 best_left_child->reward + best_right_child->reward )) {
@@ -355,7 +376,7 @@ std::unique_ptr<Node> find_best_split(const std::vector<flat_set>& sorted_sets,
355376}
356377
357378std::unique_ptr<Node> tree_search (int depth, int split_step, size_t min_node_size, const Data* data,
358- const std::function<void ()>& interrupt_handler) {
379+ const std::function<void ()>& interrupt_handler, std::ostream* verbose_stream ) {
359380 size_t num_rewards = data->num_rewards ();
360381 size_t num_points = data->num_rows ;
361382 auto sorted_sets = create_sorted_sets (data);
@@ -367,5 +388,7 @@ std::unique_ptr<Node> tree_search(int depth, int split_step, size_t min_node_siz
367388 v.resize (num_points + 1 , 0.0 );
368389 }
369390
370- return find_best_split (sorted_sets, depth, split_step, min_node_size, data, sum_array, interrupt_handler);
391+ ScopedHeartbeat heartbeat {verbose_stream};
392+
393+ return find_best_split (sorted_sets, depth, split_step, min_node_size, data, sum_array, interrupt_handler, verbose_stream);
371394}
0 commit comments