@@ -62,18 +62,19 @@ class HistUpdater {
62
62
tree_evaluator_(qu, param, fmat->Info ().num_col_),
63
63
pruner_(std::move(pruner)),
64
64
interaction_constraints_{std::move (int_constraints_)},
65
- p_last_tree_ (nullptr ), p_last_fmat_(fmat),
66
- column_sampler_(seed_) {
65
+ p_last_tree_ (nullptr ), p_last_fmat_(fmat) {
67
66
builder_monitor_.Init (" SYCL::Quantile::HistUpdater" );
68
67
kernel_monitor_.Init (" SYCL::Quantile::HistUpdater" );
68
+ if (param.max_depth > 0 ) {
69
+ snode_device_.Resize (&qu, 1u << (param.max_depth + 1 ));
70
+ }
69
71
const auto sub_group_sizes =
70
72
qu_.get_device ().get_info <::sycl::info::device::sub_group_sizes>();
71
73
sub_group_size_ = sub_group_sizes.back ();
72
74
}
73
75
74
76
// update one tree, growing
75
- void Update (Context const * ctx,
76
- xgboost::tree::TrainParam const *param,
77
+ void Update (xgboost::tree::TrainParam const *param,
77
78
const common::GHistIndexMatrix &gmat,
78
79
linalg::Matrix<GradientPair> *gpair,
79
80
const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
@@ -181,11 +182,6 @@ class HistUpdater {
181
182
RegTree *p_tree,
182
183
const USMVector<GradientPair, MemoryType::on_device> &gpair);
183
184
184
- void ExpandWithLossGuide (const common::GHistIndexMatrix& gmat,
185
- DMatrix* p_fmat,
186
- RegTree* p_tree,
187
- const USMVector<GradientPair, MemoryType::on_device>& gpair);
188
-
189
185
void EvaluateAndApplySplits (const common::GHistIndexMatrix &gmat,
190
186
RegTree *p_tree,
191
187
int *num_leaves,
@@ -200,6 +196,11 @@ class HistUpdater {
200
196
std::vector<ExpandEntry>* nodes_for_apply_split,
201
197
std::vector<ExpandEntry>* temp_qexpand_depth);
202
198
199
+ void ExpandWithLossGuide (const common::GHistIndexMatrix& gmat,
200
+ DMatrix* p_fmat,
201
+ RegTree* p_tree,
202
+ const USMVector<GradientPair, MemoryType::on_device>& gpair);
203
+
203
204
void ReduceHists (const std::vector<int >& sync_ids, size_t nbins);
204
205
205
206
inline static bool LossGuide (ExpandEntry lhs, ExpandEntry rhs) {
@@ -213,13 +214,14 @@ class HistUpdater {
213
214
// --data fields--
214
215
const Context* ctx_;
215
216
size_t sub_group_size_;
217
+ const xgboost::tree::TrainParam& param_;
218
+ std::shared_ptr<xgboost::common::ColumnSampler> column_sampler_;
216
219
217
220
// the internal row sets
218
221
common::RowSetCollection row_set_collection_;
219
222
std::vector<SplitQuery> split_queries_host_;
220
223
USMVector<SplitQuery, MemoryType::on_device> split_queries_device_;
221
224
222
- const xgboost::tree::TrainParam& param_;
223
225
TreeEvaluator<GradientSumT> tree_evaluator_;
224
226
std::unique_ptr<TreeUpdater> pruner_;
225
227
FeatureInteractionConstraintHost interaction_constraints_;
@@ -258,7 +260,6 @@ class HistUpdater {
258
260
uint32_t fid_least_bins_;
259
261
260
262
uint64_t seed_ = 0 ;
261
- xgboost::common::ColumnSampler column_sampler_;
262
263
263
264
common::PartitionBuilder partition_builder_;
264
265
0 commit comments