Skip to content

Commit 5bd849f

Browse files
razdoburdindmitry.razdoburdintrivialfis
authored
Unify the partitioner for hist and approx.
Co-authored-by: dmitry.razdoburdin <[email protected]> Co-authored-by: jiamingy <[email protected]>
1 parent c69af90 commit 5bd849f

13 files changed

+358
-450
lines changed

src/common/column_matrix.h

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,18 @@ class SparseColumnIter : public Column<BinIdxT> {
103103

104104
template <typename BinIdxT, bool any_missing>
105105
class DenseColumnIter : public Column<BinIdxT> {
106+
public:
107+
using ByteType = bool;
108+
106109
private:
107110
using Base = Column<BinIdxT>;
108111
/* flags for missing values in dense columns */
109-
std::vector<bool> const& missing_flags_;
112+
std::vector<ByteType> const& missing_flags_;
110113
size_t feature_offset_;
111114

112115
public:
113116
explicit DenseColumnIter(common::Span<const BinIdxT> index, bst_bin_t index_base,
114-
std::vector<bool> const& missing_flags, size_t feature_offset)
117+
std::vector<ByteType> const& missing_flags, size_t feature_offset)
115118
: Base{index, index_base}, missing_flags_{missing_flags}, feature_offset_{feature_offset} {}
116119
DenseColumnIter(DenseColumnIter const&) = delete;
117120
DenseColumnIter(DenseColumnIter&&) = default;
@@ -153,6 +156,7 @@ class ColumnMatrix {
153156
}
154157

155158
public:
159+
using ByteType = bool;
156160
// get number of features
157161
bst_feature_t GetNumFeature() const { return static_cast<bst_feature_t>(type_.size()); }
158162

@@ -195,6 +199,8 @@ class ColumnMatrix {
195199
}
196200
}
197201

202+
bool IsInitialized() const { return !type_.empty(); }
203+
198204
/**
199205
* \brief Push batch of data for Quantile DMatrix support.
200206
*
@@ -352,6 +358,13 @@ class ColumnMatrix {
352358

353359
fi->Read(&row_ind_);
354360
fi->Read(&feature_offsets_);
361+
362+
std::vector<std::uint8_t> missing;
363+
fi->Read(&missing);
364+
missing_flags_.resize(missing.size());
365+
std::transform(missing.cbegin(), missing.cend(), missing_flags_.begin(),
366+
[](std::uint8_t flag) { return !!flag; });
367+
355368
index_base_ = index_base;
356369
#if !DMLC_LITTLE_ENDIAN
357370
std::underlying_type<BinTypeSize>::type v;
@@ -386,6 +399,11 @@ class ColumnMatrix {
386399
#endif // !DMLC_LITTLE_ENDIAN
387400
write_vec(row_ind_);
388401
write_vec(feature_offsets_);
402+
// dmlc can not handle bool vector
403+
std::vector<std::uint8_t> missing(missing_flags_.size());
404+
std::transform(missing_flags_.cbegin(), missing_flags_.cend(), missing.begin(),
405+
[](bool flag) { return static_cast<std::uint8_t>(flag); });
406+
write_vec(missing);
389407

390408
#if !DMLC_LITTLE_ENDIAN
391409
auto v = static_cast<std::underlying_type<BinTypeSize>::type>(bins_type_size_);
@@ -413,7 +431,7 @@ class ColumnMatrix {
413431

414432
// index_base_[fid]: least bin id for feature fid
415433
uint32_t const* index_base_;
416-
std::vector<bool> missing_flags_;
434+
std::vector<ByteType> missing_flags_;
417435
BinTypeSize bins_type_size_;
418436
bool any_missing_;
419437
};

src/common/numeric.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#ifndef XGBOOST_COMMON_NUMERIC_H_
55
#define XGBOOST_COMMON_NUMERIC_H_
66

7+
#include <dmlc/common.h> // OMPException
8+
79
#include <algorithm> // std::max
810
#include <iterator> // std::iterator_traits
911
#include <vector>
@@ -106,6 +108,26 @@ inline double Reduce(Context const*, HostDeviceVector<float> const&) {
106108
* \brief Reduction with summation.
107109
*/
108110
double Reduce(Context const* ctx, HostDeviceVector<float> const& values);
111+
112+
template <typename It>
113+
void Iota(Context const* ctx, It first, It last,
114+
typename std::iterator_traits<It>::value_type const& value) {
115+
auto n = std::distance(first, last);
116+
std::int32_t n_threads = ctx->Threads();
117+
const size_t block_size = n / n_threads + !!(n % n_threads);
118+
dmlc::OMPException exc;
119+
#pragma omp parallel num_threads(n_threads)
120+
{
121+
exc.Run([&]() {
122+
const size_t tid = omp_get_thread_num();
123+
const size_t ibegin = tid * block_size;
124+
const size_t iend = std::min(ibegin + block_size, static_cast<size_t>(n));
125+
for (size_t i = ibegin; i < iend; ++i) {
126+
first[i] = i + value;
127+
}
128+
});
129+
}
130+
}
109131
} // namespace common
110132
} // namespace xgboost
111133

src/common/partition_builder.h

Lines changed: 59 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "categorical.h"
1919
#include "column_matrix.h"
20+
#include "../tree/hist/expand_entry.h"
2021
#include "xgboost/generic_parameters.h"
2122
#include "xgboost/tree_model.h"
2223

@@ -107,34 +108,42 @@ class PartitionBuilder {
107108
}
108109

109110
template <typename BinIdxType, bool any_missing, bool any_cat>
110-
void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range,
111+
void Partition(const size_t node_in_set, std::vector<xgboost::tree::CPUExpandEntry> const &nodes,
112+
const common::Range1d range,
111113
const bst_bin_t split_cond, GHistIndexMatrix const& gmat,
112-
const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) {
114+
const common::ColumnMatrix& column_matrix,
115+
const RegTree& tree, const size_t* rid) {
113116
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
114117
common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
115118
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
116-
const bst_uint fid = tree[nid].SplitIndex();
117-
const bool default_left = tree[nid].DefaultLeft();
119+
std::size_t nid = nodes[node_in_set].nid;
120+
bst_feature_t fid = tree[nid].SplitIndex();
121+
bool default_left = tree[nid].DefaultLeft();
118122
bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
119123
auto node_cats = tree.NodeCats(nid);
120124

121125
auto const& index = gmat.index;
122126
auto const& cut_values = gmat.cut.Values();
123127
auto const& cut_ptrs = gmat.cut.Ptrs();
124128

125-
auto pred = [&](auto ridx, auto bin_id) {
129+
auto gidx_calc = [&](auto ridx) {
130+
auto begin = gmat.RowIdx(ridx);
131+
if (gmat.IsDense()) {
132+
return static_cast<bst_bin_t>(index[begin + fid]);
133+
}
134+
auto end = gmat.RowIdx(ridx + 1);
135+
auto f_begin = cut_ptrs[fid];
136+
auto f_end = cut_ptrs[fid + 1];
137+
// bypassing the column matrix as we need the cut value instead of bin idx for categorical
138+
// features.
139+
return BinarySearchBin(begin, end, index, f_begin, f_end);
140+
};
141+
142+
auto pred_hist = [&](auto ridx, auto bin_id) {
126143
if (any_cat && is_cat) {
127-
auto begin = gmat.RowIdx(ridx);
128-
auto end = gmat.RowIdx(ridx + 1);
129-
auto f_begin = cut_ptrs[fid];
130-
auto f_end = cut_ptrs[fid + 1];
131-
// bypassing the column matrix as we need the cut value instead of bin idx for categorical
132-
// features.
133-
auto gidx = BinarySearchBin(begin, end, index, f_begin, f_end);
134-
bool go_left;
135-
if (gidx == -1) {
136-
go_left = default_left;
137-
} else {
144+
auto gidx = gidx_calc(ridx);
145+
bool go_left = default_left;
146+
if (gidx > -1) {
138147
go_left = Decision(node_cats, cut_values[gidx], default_left);
139148
}
140149
return go_left;
@@ -143,25 +152,43 @@ class PartitionBuilder {
143152
}
144153
};
145154

146-
std::pair<size_t, size_t> child_nodes_sizes;
147-
if (column_matrix.GetColumnType(fid) == xgboost::common::kDenseColumn) {
148-
auto column = column_matrix.DenseColumn<BinIdxType, any_missing>(fid);
149-
if (default_left) {
150-
child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
151-
gmat.base_rowid, pred);
152-
} else {
153-
child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
154-
gmat.base_rowid, pred);
155+
auto pred_approx = [&](auto ridx) {
156+
auto gidx = gidx_calc(ridx);
157+
bool go_left = default_left;
158+
if (gidx > -1) {
159+
if (is_cat) {
160+
go_left = Decision(node_cats, cut_values[gidx], default_left);
161+
} else {
162+
go_left = cut_values[gidx] <= nodes[node_in_set].split.split_value;
163+
}
155164
}
165+
return go_left;
166+
};
167+
168+
std::pair<size_t, size_t> child_nodes_sizes;
169+
if (!column_matrix.IsInitialized()) {
170+
child_nodes_sizes = PartitionRangeKernel(rid_span, left, right, pred_approx);
156171
} else {
157-
CHECK_EQ(any_missing, true);
158-
auto column = column_matrix.SparseColumn<BinIdxType>(fid, rid_span.front() - gmat.base_rowid);
159-
if (default_left) {
160-
child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
161-
gmat.base_rowid, pred);
172+
if (column_matrix.GetColumnType(fid) == xgboost::common::kDenseColumn) {
173+
auto column = column_matrix.DenseColumn<BinIdxType, any_missing>(fid);
174+
if (default_left) {
175+
child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
176+
gmat.base_rowid, pred_hist);
177+
} else {
178+
child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
179+
gmat.base_rowid, pred_hist);
180+
}
162181
} else {
163-
child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
164-
gmat.base_rowid, pred);
182+
CHECK_EQ(any_missing, true);
183+
auto column =
184+
column_matrix.SparseColumn<BinIdxType>(fid, rid_span.front() - gmat.base_rowid);
185+
if (default_left) {
186+
child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
187+
gmat.base_rowid, pred_hist);
188+
} else {
189+
child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
190+
gmat.base_rowid, pred_hist);
191+
}
165192
}
166193
}
167194

@@ -172,37 +199,6 @@ class PartitionBuilder {
172199
SetNRightElems(node_in_set, range.begin(), n_right);
173200
}
174201

175-
/**
176-
* \brief Partition tree nodes with specific range of row indices.
177-
*
178-
* \tparam Pred Predicate for whether a row should be partitioned to the left node.
179-
*
180-
* \param node_in_set The index of node in current batch of nodes.
181-
* \param nid The canonical node index (node index in the tree).
182-
* \param range The range of input row index.
183-
* \param fidx Feature index.
184-
* \param p_row_set_collection Pointer to rows that are being partitioned.
185-
* \param pred A callback function that returns whether current row should be
186-
* partitioned to the left node, it should accept the row index as
187-
* input and returns a boolean value.
188-
*/
189-
template <typename Pred>
190-
void PartitionRange(const size_t node_in_set, const size_t nid, common::Range1d range,
191-
common::RowSetCollection* p_row_set_collection, Pred pred) {
192-
auto& row_set_collection = *p_row_set_collection;
193-
const size_t* p_ridx = row_set_collection[nid].begin;
194-
common::Span<const size_t> ridx(p_ridx + range.begin(), p_ridx + range.end());
195-
common::Span<size_t> left = this->GetLeftBuffer(node_in_set, range.begin(), range.end());
196-
common::Span<size_t> right = this->GetRightBuffer(node_in_set, range.begin(), range.end());
197-
std::pair<size_t, size_t> child_nodes_sizes = PartitionRangeKernel(ridx, left, right, pred);
198-
199-
const size_t n_left = child_nodes_sizes.first;
200-
const size_t n_right = child_nodes_sizes.second;
201-
202-
this->SetNLeftElems(node_in_set, range.begin(), n_left);
203-
this->SetNRightElems(node_in_set, range.begin(), n_right);
204-
}
205-
206202
// allocate thread local memory, should be called for each specific task
207203
void AllocateForTask(size_t id) {
208204
if (mem_blocks_[id].get() == nullptr) {

0 commit comments

Comments
 (0)