|
1 | 1 | /**
|
2 |
| - * Copyright 2021-2023 by XGBoost Contributors |
| 2 | + * Copyright 2021-2024, XGBoost Contributors |
3 | 3 | */
|
4 | 4 | #include "../test_evaluate_splits.h"
|
5 | 5 |
|
|
10 | 10 | #include <xgboost/logging.h> // for CHECK_EQ
|
11 | 11 | #include <xgboost/tree_model.h> // for RegTree, RTreeNodeStat
|
12 | 12 |
|
13 |
| -#include <memory> // for make_shared, shared_ptr, addressof |
| 13 | +#include <memory> // for make_shared, shared_ptr, addressof |
| 14 | +#include <numeric> // for iota |
| 15 | +#include <tuple> // for make_tuple |
14 | 16 |
|
15 | 17 | #include "../../../../src/common/hist_util.h" // for HistCollection, HistogramCuts
|
16 | 18 | #include "../../../../src/common/random.h" // for ColumnSampler
|
17 | 19 | #include "../../../../src/common/row_set.h" // for RowSetCollection
|
18 | 20 | #include "../../../../src/data/gradient_index.h" // for GHistIndexMatrix
|
19 |
| -#include "../../../../src/tree/hist/evaluate_splits.h" // for HistEvaluator |
| 21 | +#include "../../../../src/tree/hist/evaluate_splits.h" // for HistEvaluator, TreeEvaluator |
20 | 22 | #include "../../../../src/tree/hist/expand_entry.h" // for CPUExpandEntry
|
21 | 23 | #include "../../../../src/tree/hist/hist_cache.h" // for BoundedHistCollection
|
22 | 24 | #include "../../../../src/tree/hist/param.h" // for HistMakerTrainParam
|
23 | 25 | #include "../../../../src/tree/param.h" // for GradStats, TrainParam
|
24 | 26 | #include "../../helpers.h" // for RandomDataGenerator, AllThreadsFo...
|
25 | 27 |
|
26 | 28 | namespace xgboost::tree {
|
| 29 | +void TestPartitionBasedSplit::SetUp() { |
| 30 | + param_.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}}); |
| 31 | + sorted_idx_.resize(n_bins_); |
| 32 | + std::iota(sorted_idx_.begin(), sorted_idx_.end(), 0); |
| 33 | + |
| 34 | + info_.num_col_ = 1; |
| 35 | + |
| 36 | + cuts_.cut_ptrs_.Resize(2); |
| 37 | + cuts_.SetCategorical(true, n_bins_); |
| 38 | + auto &h_cuts = cuts_.cut_ptrs_.HostVector(); |
| 39 | + h_cuts[0] = 0; |
| 40 | + h_cuts[1] = n_bins_; |
| 41 | + auto &h_vals = cuts_.cut_values_.HostVector(); |
| 42 | + h_vals.resize(n_bins_); |
| 43 | + std::iota(h_vals.begin(), h_vals.end(), 0.0); |
| 44 | + |
| 45 | + cuts_.min_vals_.Resize(1); |
| 46 | + |
| 47 | + Context ctx; |
| 48 | + HistMakerTrainParam hist_param; |
| 49 | + hist_.Reset(cuts_.TotalBins(), hist_param.MaxCachedHistNodes(ctx.Device())); |
| 50 | + hist_.AllocateHistograms({0}); |
| 51 | + auto node_hist = hist_[0]; |
| 52 | + |
| 53 | + SimpleLCG lcg; |
| 54 | + SimpleRealUniformDistribution<double> grad_dist{-4.0, 4.0}; |
| 55 | + SimpleRealUniformDistribution<double> hess_dist{0.0, 4.0}; |
| 56 | + |
| 57 | + for (auto &e : node_hist) { |
| 58 | + e = GradientPairPrecise{grad_dist(&lcg), hess_dist(&lcg)}; |
| 59 | + total_gpair_ += e; |
| 60 | + } |
| 61 | + |
| 62 | + auto enumerate = [this, n_feat = info_.num_col_](common::GHistRow hist, |
| 63 | + GradientPairPrecise parent_sum) { |
| 64 | + int32_t best_thresh = -1; |
| 65 | + float best_score{-std::numeric_limits<float>::infinity()}; |
| 66 | + TreeEvaluator evaluator{param_, static_cast<bst_feature_t>(n_feat), DeviceOrd::CPU()}; |
| 67 | + auto tree_evaluator = evaluator.GetEvaluator<TrainParam>(); |
| 68 | + GradientPairPrecise left_sum; |
| 69 | + auto parent_gain = tree_evaluator.CalcGain(0, param_, GradStats{total_gpair_}); |
| 70 | + for (size_t i = 0; i < hist.size() - 1; ++i) { |
| 71 | + left_sum += hist[i]; |
| 72 | + auto right_sum = parent_sum - left_sum; |
| 73 | + auto gain = |
| 74 | + tree_evaluator.CalcSplitGain(param_, 0, 0, GradStats{left_sum}, GradStats{right_sum}) - |
| 75 | + parent_gain; |
| 76 | + if (gain > best_score) { |
| 77 | + best_score = gain; |
| 78 | + best_thresh = i; |
| 79 | + } |
| 80 | + } |
| 81 | + return std::make_tuple(best_thresh, best_score); |
| 82 | + }; |
| 83 | + |
| 84 | + // enumerate all possible partitions to find the optimal split |
| 85 | + do { |
| 86 | + std::vector<GradientPairPrecise> sorted_hist(node_hist.size()); |
| 87 | + for (size_t i = 0; i < sorted_hist.size(); ++i) { |
| 88 | + sorted_hist[i] = node_hist[sorted_idx_[i]]; |
| 89 | + } |
| 90 | + auto [thresh, score] = enumerate({sorted_hist}, total_gpair_); |
| 91 | + if (score > best_score_) { |
| 92 | + best_score_ = score; |
| 93 | + } |
| 94 | + } while (std::next_permutation(sorted_idx_.begin(), sorted_idx_.end())); |
| 95 | +} |
| 96 | + |
27 | 97 | void TestEvaluateSplits(bool force_read_by_column) {
|
28 | 98 | Context ctx;
|
29 | 99 | ctx.nthread = 4;
|
|
0 commit comments