@@ -183,7 +183,7 @@ void HistUpdater<GradientSumT>::EvaluateAndApplySplits(
183
183
int *num_leaves,
184
184
int depth,
185
185
std::vector<ExpandEntry> *temp_qexpand_depth) {
186
- EvaluateSplits (qexpand_depth_wise_, gmat, hist_, *p_tree);
186
+ EvaluateSplits (qexpand_depth_wise_, gmat, *p_tree);
187
187
188
188
std::vector<ExpandEntry> nodes_for_apply_split;
189
189
AddSplitsToTree (gmat, p_tree, num_leaves, depth,
@@ -280,7 +280,7 @@ void HistUpdater<GradientSumT>::ExpandWithLossGuide(
280
280
281
281
this ->InitNewNode (ExpandEntry::kRootNid , gmat, gpair, *p_fmat, *p_tree);
282
282
283
- this ->EvaluateSplits ({node}, gmat, hist_, *p_tree);
283
+ this ->EvaluateSplits ({node}, gmat, *p_tree);
284
284
node.split .loss_chg = snode_host_[ExpandEntry::kRootNid ].best .loss_chg ;
285
285
286
286
qexpand_loss_guided_->push (node);
@@ -325,7 +325,7 @@ void HistUpdater<GradientSumT>::ExpandWithLossGuide(
325
325
snode_host_[cleft].weight , snode_host_[cright].weight );
326
326
interaction_constraints_.Split (nid, featureid, cleft, cright);
327
327
328
- this ->EvaluateSplits ({left_node, right_node}, gmat, hist_, *p_tree);
328
+ this ->EvaluateSplits ({left_node, right_node}, gmat, *p_tree);
329
329
left_node.split .loss_chg = snode_host_[cleft].best .loss_chg ;
330
330
right_node.split .loss_chg = snode_host_[cright].best .loss_chg ;
331
331
@@ -472,7 +472,7 @@ void HistUpdater<GradientSumT>::InitSampling(
472
472
});
473
473
});
474
474
} else {
475
- // Use oneDPL uniform for better perf , as far as bernoulli_distribution uses fp64
475
+ // Use oneDPL uniform, as far as bernoulli_distribution uses fp64
476
476
event = qu_.submit ([&](::sycl::handler& cgh) {
477
477
auto flag_buf_acc = flag_buf.get_access <::sycl::access::mode::read_write>(cgh);
478
478
cgh.parallel_for <>(::sycl::range<1 >(::sycl::range<1 >(num_rows)),
@@ -649,45 +649,32 @@ template<typename GradientSumT>
649
649
void HistUpdater<GradientSumT>::EvaluateSplits(
650
650
const std::vector<ExpandEntry>& nodes_set,
651
651
const common::GHistIndexMatrix& gmat,
652
- const common::HistCollection<GradientSumT, MemoryType::on_device>& hist,
653
652
const RegTree& tree) {
654
653
builder_monitor_.Start (" EvaluateSplits" );
655
654
656
655
const size_t n_nodes_in_set = nodes_set.size ();
657
656
658
657
using FeatureSetType = std::shared_ptr<HostDeviceVector<bst_feature_t >>;
659
- std::vector<FeatureSetType> features_sets (n_nodes_in_set);
660
658
661
659
// Generate feature set for each tree node
662
- size_t total_features = 0 ;
663
- for (size_t nid_in_set = 0 ; nid_in_set < n_nodes_in_set; ++nid_in_set) {
664
- const int32_t nid = nodes_set[nid_in_set].nid ;
665
- features_sets[nid_in_set] = column_sampler_->GetFeatureSet (tree.GetDepth (nid));
666
- for (size_t idx = 0 ; idx < features_sets[nid_in_set]->Size (); idx++) {
667
- const auto fid = features_sets[nid_in_set]->ConstHostVector ()[idx];
668
- if (interaction_constraints_.Query (nid, fid)) {
669
- total_features++;
670
- }
671
- }
672
- }
673
-
674
- split_queries_host_.resize (total_features);
675
660
size_t pos = 0 ;
676
-
677
661
for (size_t nid_in_set = 0 ; nid_in_set < n_nodes_in_set; ++nid_in_set) {
678
- const size_t nid = nodes_set[nid_in_set].nid ;
679
-
680
- for (size_t idx = 0 ; idx < features_sets[nid_in_set] ->Size (); idx++) {
681
- const auto fid = features_sets[nid_in_set] ->ConstHostVector ()[idx];
662
+ const bst_node_t nid = nodes_set[nid_in_set].nid ;
663
+ FeatureSetType features_set = column_sampler_-> GetFeatureSet (tree. GetDepth (nid));
664
+ for (size_t idx = 0 ; idx < features_set ->Size (); idx++) {
665
+ const size_t fid = features_set ->ConstHostVector ()[idx];
682
666
if (interaction_constraints_.Query (nid, fid)) {
683
- split_queries_host_[pos].nid = nid;
684
- split_queries_host_[pos].fid = fid;
685
- split_queries_host_[pos].hist = hist[nid].DataConst ();
686
- split_queries_host_[pos].best = snode_host_[nid].best ;
687
- pos++;
667
+ auto this_hist = hist_[nid].DataConst ();
668
+ if (pos < split_queries_host_.size ()) {
669
+ split_queries_host_[pos] = SplitQuery{nid, fid, this_hist};
670
+ } else {
671
+ split_queries_host_.push_back ({nid, fid, this_hist});
672
+ }
673
+ ++pos;
688
674
}
689
675
}
690
676
}
677
+ const size_t total_features = pos;
691
678
692
679
split_queries_device_.Resize (&qu_, total_features);
693
680
auto event = qu_.memcpy (split_queries_device_.Data (), split_queries_host_.data (),
@@ -702,10 +689,14 @@ void HistUpdater<GradientSumT>::EvaluateSplits(
702
689
snode_device_.ResizeNoCopy (&qu_, snode_host_.size ());
703
690
event = qu_.memcpy (snode_device_.Data (), snode_host_.data (),
704
691
snode_host_.size () * sizeof (NodeEntry<GradientSumT>), event);
705
- const NodeEntry<GradientSumT>* snode = snode_device_.DataConst ();
692
+ const NodeEntry<GradientSumT>* snode = snode_device_.Data ();
706
693
707
694
const float min_child_weight = param_.min_child_weight ;
708
695
696
+ best_splits_device_.ResizeNoCopy (&qu_, total_features);
697
+ if (best_splits_host_.size () < total_features) best_splits_host_.resize (total_features);
698
+ SplitEntry<GradientSumT>* best_splits = best_splits_device_.Data ();
699
+
709
700
event = qu_.submit ([&](::sycl::handler& cgh) {
710
701
cgh.depends_on (event);
711
702
cgh.parallel_for <>(::sycl::nd_range<2 >(::sycl::range<2 >(total_features, sub_group_size_),
@@ -717,17 +708,18 @@ void HistUpdater<GradientSumT>::EvaluateSplits(
717
708
int fid = split_queries_device[i].fid ;
718
709
const GradientPairT* hist_data = split_queries_device[i].hist ;
719
710
711
+ best_splits[i] = snode[nid].best ;
720
712
EnumerateSplit (sg, cut_ptr, cut_val, hist_data, snode[nid],
721
- &(split_queries_device [i]. best ), fid, nid, evaluator, min_child_weight);
713
+ &(best_splits [i]), fid, nid, evaluator, min_child_weight);
722
714
});
723
715
});
724
- event = qu_.memcpy (split_queries_host_ .data (), split_queries_device_. Data () ,
725
- total_features * sizeof (SplitQuery ), event);
716
+ event = qu_.memcpy (best_splits_host_ .data (), best_splits ,
717
+ total_features * sizeof (SplitEntry<GradientSumT> ), event);
726
718
727
719
qu_.wait ();
728
720
for (size_t i = 0 ; i < total_features; i++) {
729
721
int nid = split_queries_host_[i].nid ;
730
- snode_host_[nid].best .Update (split_queries_host_ [i]. best );
722
+ snode_host_[nid].best .Update (best_splits_host_ [i]);
731
723
}
732
724
733
725
builder_monitor_.Stop (" EvaluateSplits" );
0 commit comments