Skip to content

Commit 5d10506

Browse files
authored
Dev/cpu/push row page optimisation (#67)
* add row-wise processing to PushRowPage * fix * clang tildy --------- Co-authored-by: Dmitry Razdoburdin <>
1 parent 766bfcc commit 5d10506

File tree

1 file changed

+92
-22
lines changed

1 file changed

+92
-22
lines changed

src/common/quantile.h

Lines changed: 92 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -840,47 +840,117 @@ class SketchContainerImpl {
840840
template <typename Batch, typename IsValid>
841841
void PushRowPageImpl(Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz,
842842
size_t n_features, bool is_dense, IsValid is_valid) {
843-
auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);
844-
845843
dmlc::OMPException exc;
846-
#pragma omp parallel num_threads(n_threads_)
847-
{
848-
exc.Run([&]() {
849-
auto tid = static_cast<uint32_t>(omp_get_thread_num());
850-
auto const begin = thread_columns_ptr[tid];
851-
auto const end = thread_columns_ptr[tid + 1];
852-
853-
// do not iterate if no columns are assigned to the thread
854-
if (begin < end && end <= n_features) {
855-
for (size_t ridx = 0; ridx < batch.Size(); ++ridx) {
844+
size_t ridx_block_size = batch.Size() / n_threads_ + (batch.Size() % n_threads_ > 0);
845+
size_t min_ridx_block_size = 1024;
846+
if ((n_features < static_cast<size_t>(n_threads_)) &&
847+
(ridx_block_size > min_ridx_block_size)) {
848+
/* Row-wise parallelisation.
849+
*/
850+
std::vector<std::set<float>> categories_buff(n_threads_ * n_features);
851+
std::vector<WQSketch> sketches_buff(n_threads_ * n_features);
852+
853+
#pragma omp parallel num_threads(n_threads_)
854+
{
855+
exc.Run([&]() {
856+
auto tid = static_cast<uint32_t>(omp_get_thread_num());
857+
WQSketch* sketches_th = sketches_buff.data() + tid * n_features;
858+
std::set<float>* categories_th = categories_buff.data() + tid * n_features;
859+
860+
for (size_t ii = 0; ii < n_features; ii++) {
861+
auto n_bins = std::min(static_cast<bst_idx_t>(max_bins_), columns_size_[ii]);
862+
auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
863+
sketches_th[ii].Init(columns_size_[ii], eps);
864+
}
865+
866+
size_t ridx_begin = tid * ridx_block_size;
867+
size_t ridx_end = std::min(ridx_begin + ridx_block_size, batch.Size());
868+
for (size_t ridx = ridx_begin; ridx < ridx_end; ++ridx) {
856869
auto const &line = batch.GetLine(ridx);
857870
auto w = weights[ridx + base_rowid];
858871
if (is_dense) {
859-
for (size_t ii = begin; ii < end; ii++) {
872+
for (size_t ii = 0; ii < n_features; ii++) {
860873
auto elem = line.GetElement(ii);
861874
if (is_valid(elem)) {
862875
if (IsCat(feature_types_, ii)) {
863-
categories_[ii].emplace(elem.value);
876+
categories_th[ii].emplace(elem.value);
864877
} else {
865-
sketches_[ii].Push(elem.value, w);
878+
sketches_th[ii].Push(elem.value, w);
866879
}
867880
}
868881
}
869882
} else {
870-
for (size_t i = 0; i < line.Size(); ++i) {
871-
auto const &elem = line.GetElement(i);
872-
if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) {
883+
for (size_t ii = 0; ii < line.Size(); ++ii) {
884+
auto elem = line.GetElement(ii);
885+
if (is_valid(elem)) {
873886
if (IsCat(feature_types_, elem.column_idx)) {
874-
categories_[elem.column_idx].emplace(elem.value);
887+
categories_th[elem.column_idx].emplace(elem.value);
875888
} else {
876-
sketches_[elem.column_idx].Push(elem.value, w);
889+
sketches_th[elem.column_idx].Push(elem.value, w);
877890
}
878891
}
879892
}
880893
}
881894
}
882-
}
883-
});
895+
#pragma omp barrier
896+
897+
size_t fidx_block_size = n_features / n_threads_ + (n_features % n_threads_ > 0);
898+
size_t fidx_begin = tid * fidx_block_size;
899+
size_t fidx_end = std::min(fidx_begin + fidx_block_size, n_features);
900+
for (size_t ii = fidx_begin; ii < fidx_end; ++ii) {
901+
for (int th = 0; th < n_threads_; ++th) {
902+
if (IsCat(feature_types_, ii)) {
903+
categories_[ii].merge(categories_buff[th * n_features + ii]);
904+
} else {
905+
typename WQSketch::SummaryContainer summary;
906+
sketches_buff[th * n_features + ii].GetSummary(&summary);
907+
sketches_[ii].PushSummary(summary);
908+
}
909+
}
910+
}
911+
});
912+
}
913+
} else {
914+
auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);
915+
#pragma omp parallel num_threads(n_threads_)
916+
{
917+
exc.Run([&]() {
918+
auto tid = static_cast<uint32_t>(omp_get_thread_num());
919+
auto const begin = thread_columns_ptr[tid];
920+
auto const end = thread_columns_ptr[tid + 1];
921+
922+
// do not iterate if no columns are assigned to the thread
923+
if (begin < end && end <= n_features) {
924+
for (size_t ridx = 0; ridx < batch.Size(); ++ridx) {
925+
auto const &line = batch.GetLine(ridx);
926+
auto w = weights[ridx + base_rowid];
927+
if (is_dense) {
928+
for (size_t ii = begin; ii < end; ii++) {
929+
auto elem = line.GetElement(ii);
930+
if (is_valid(elem)) {
931+
if (IsCat(feature_types_, ii)) {
932+
categories_[ii].emplace(elem.value);
933+
} else {
934+
sketches_[ii].Push(elem.value, w);
935+
}
936+
}
937+
}
938+
} else {
939+
for (size_t i = 0; i < line.Size(); ++i) {
940+
auto const &elem = line.GetElement(i);
941+
if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) {
942+
if (IsCat(feature_types_, elem.column_idx)) {
943+
categories_[elem.column_idx].emplace(elem.value);
944+
} else {
945+
sketches_[elem.column_idx].Push(elem.value, w);
946+
}
947+
}
948+
}
949+
}
950+
}
951+
}
952+
});
953+
}
884954
}
885955
exc.Rethrow();
886956
}

0 commit comments

Comments
 (0)