13
13
#include < limits>
14
14
#include < vector>
15
15
16
+ #include " ../../src/tree/common_row_partitioner.h"
17
+
16
18
#include " ../common/hist_util.h"
17
19
#include " ../../src/collective/allreduce.h"
18
20
@@ -786,34 +788,6 @@ void HistUpdater<GradientSumT>::EnumerateSplit(
786
788
best.SplitIndex () == total_split_index) p_best->Update (best);
787
789
}
788
790
789
- template <typename GradientSumT>
790
- void HistUpdater<GradientSumT>::FindSplitConditions(
791
- const std::vector<ExpandEntry>& nodes,
792
- const RegTree& tree,
793
- const common::GHistIndexMatrix& gmat,
794
- std::vector<int32_t >* split_conditions) {
795
- const size_t n_nodes = nodes.size ();
796
- split_conditions->resize (n_nodes);
797
-
798
- for (size_t i = 0 ; i < nodes.size (); ++i) {
799
- const int32_t nid = nodes[i].nid ;
800
- const bst_uint fid = tree[nid].SplitIndex ();
801
- const bst_float split_pt = tree[nid].SplitCond ();
802
- const uint32_t lower_bound = gmat.cut .Ptrs ()[fid];
803
- const uint32_t upper_bound = gmat.cut .Ptrs ()[fid + 1 ];
804
- int32_t split_cond = -1 ;
805
- // convert floating-point split_pt into corresponding bin_id
806
- // split_cond = -1 indicates that split_pt is less than all known cut points
807
- CHECK_LT (upper_bound,
808
- static_cast <uint32_t >(std::numeric_limits<int32_t >::max ()));
809
- for (uint32_t i = lower_bound; i < upper_bound; ++i) {
810
- if (split_pt == gmat.cut .Values ()[i]) {
811
- split_cond = static_cast <int32_t >(i);
812
- }
813
- }
814
- (*split_conditions)[i] = split_cond;
815
- }
816
- }
817
791
template <typename GradientSumT>
818
792
void HistUpdater<GradientSumT>::AddSplitsToRowSet(
819
793
const std::vector<ExpandEntry>& nodes,
@@ -835,11 +809,12 @@ void HistUpdater<GradientSumT>::ApplySplit(
835
809
const common::GHistIndexMatrix& gmat,
836
810
const common::HistCollection<GradientSumT, MemoryType::on_device>& hist,
837
811
RegTree* p_tree) {
812
+ using CommonRowPartitioner = xgboost::tree::CommonRowPartitioner;
838
813
builder_monitor_.Start (" ApplySplit" );
839
814
840
815
const size_t n_nodes = nodes.size ();
841
- std::vector<int32_t > split_conditions;
842
- FindSplitConditions (nodes, *p_tree, gmat, &split_conditions);
816
+ std::vector<int32_t > split_conditions (n_nodes) ;
817
+ CommonRowPartitioner:: FindSplitConditions (nodes, *p_tree, gmat, &split_conditions);
843
818
844
819
partition_builder_.Init (&qu_, n_nodes, [&](size_t node_in_set) {
845
820
const int32_t nid = nodes[node_in_set].nid ;
0 commit comments