Commit 1780f5b

Author: Dmitry Razdoburdin (committed)
add row-wise processing to PushRowPage
1 parent 7c69d92 · commit 1780f5b

File tree

1 file changed (+91, −22 lines)


src/common/quantile.h

Lines changed: 91 additions & 22 deletions
@@ -840,47 +840,116 @@ class SketchContainerImpl {
   template <typename Batch, typename IsValid>
   void PushRowPageImpl(Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz,
                        size_t n_features, bool is_dense, IsValid is_valid) {
-    auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);
-
     dmlc::OMPException exc;
-#pragma omp parallel num_threads(n_threads_)
-    {
-      exc.Run([&]() {
-        auto tid = static_cast<uint32_t>(omp_get_thread_num());
-        auto const begin = thread_columns_ptr[tid];
-        auto const end = thread_columns_ptr[tid + 1];
-
-        // do not iterate if no columns are assigned to the thread
-        if (begin < end && end <= n_features) {
-          for (size_t ridx = 0; ridx < batch.Size(); ++ridx) {
+    size_t ridx_block_size = batch.Size() / n_threads_ + (batch.Size() % n_threads_ > 0);
+    size_t min_ridx_block_size = 1024;
+    if ((n_features < n_threads_) && (ridx_block_size > min_ridx_block_size)) {
+      /* Row-wise parallelisation.
+       */
+      std::vector<std::set<float>> categories_buff(n_threads_ * n_features);
+      std::vector<WQSketch> sketches_buff(n_threads_ * n_features);
+
+#pragma omp parallel num_threads(n_threads_)
+      {
+        exc.Run([&]() {
+          auto tid = static_cast<uint32_t>(omp_get_thread_num());
+          WQSketch* sketches_th = sketches_buff.data() + tid * n_features;
+          std::set<float>* categories_th = categories_buff.data() + tid * n_features;
+
+          for (size_t ii = 0; ii < n_features; ii++) {
+            auto n_bins = std::min(static_cast<bst_idx_t>(max_bins_), columns_size_[ii]);
+            auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
+            sketches_th[ii].Init(columns_size_[ii], eps);
+          }
+
+          size_t ridx_begin = tid * ridx_block_size;
+          size_t ridx_end = std::min(ridx_begin + ridx_block_size, batch.Size());
+          for (size_t ridx = ridx_begin; ridx < ridx_end; ++ridx) {
             auto const &line = batch.GetLine(ridx);
             auto w = weights[ridx + base_rowid];
             if (is_dense) {
-              for (size_t ii = begin; ii < end; ii++) {
+              for (size_t ii = 0; ii < n_features; ii++) {
                 auto elem = line.GetElement(ii);
                 if (is_valid(elem)) {
                   if (IsCat(feature_types_, ii)) {
-                    categories_[ii].emplace(elem.value);
+                    categories_th[ii].emplace(elem.value);
                   } else {
-                    sketches_[ii].Push(elem.value, w);
+                    sketches_th[ii].Push(elem.value, w);
                   }
                 }
               }
             } else {
-              for (size_t i = 0; i < line.Size(); ++i) {
-                auto const &elem = line.GetElement(i);
-                if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) {
+              for (size_t ii = 0; ii < line.Size(); ++ii) {
+                auto elem = line.GetElement(ii);
+                if (is_valid(elem)) {
                   if (IsCat(feature_types_, elem.column_idx)) {
-                    categories_[elem.column_idx].emplace(elem.value);
+                    categories_th[elem.column_idx].emplace(elem.value);
                   } else {
-                    sketches_[elem.column_idx].Push(elem.value, w);
+                    sketches_th[elem.column_idx].Push(elem.value, w);
                   }
                 }
               }
             }
           }
-        }
-      });
+#pragma omp barrier
+
+          size_t fidx_block_size = n_features / n_threads_ + (n_features % n_threads_ > 0);
+          size_t fidx_begin = tid * fidx_block_size;
+          size_t fidx_end = std::min(fidx_begin + fidx_block_size, n_features);
+          for (size_t ii = fidx_begin; ii < fidx_end; ++ii) {
+            for (size_t th = 0; th < n_threads_; ++th) {
+              if (IsCat(feature_types_, ii)) {
+                categories_[ii].merge(categories_buff[th * n_features + ii]);
+              } else {
+                typename WQSketch::SummaryContainer summary;
+                sketches_buff[th * n_features + ii].GetSummary(&summary);
+                sketches_[ii].PushSummary(summary);
+              }
+            }
+          }
+        });
+      }
+    } else {
+      auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);
+#pragma omp parallel num_threads(n_threads_)
+      {
+        exc.Run([&]() {
+          auto tid = static_cast<uint32_t>(omp_get_thread_num());
+          auto const begin = thread_columns_ptr[tid];
+          auto const end = thread_columns_ptr[tid + 1];
+
+          // do not iterate if no columns are assigned to the thread
+          if (begin < end && end <= n_features) {
+            for (size_t ridx = 0; ridx < batch.Size(); ++ridx) {
+              auto const &line = batch.GetLine(ridx);
+              auto w = weights[ridx + base_rowid];
+              if (is_dense) {
+                for (size_t ii = begin; ii < end; ii++) {
+                  auto elem = line.GetElement(ii);
+                  if (is_valid(elem)) {
+                    if (IsCat(feature_types_, ii)) {
+                      categories_[ii].emplace(elem.value);
+                    } else {
+                      sketches_[ii].Push(elem.value, w);
+                    }
+                  }
+                }
+              } else {
+                for (size_t i = 0; i < line.Size(); ++i) {
+                  auto const &elem = line.GetElement(i);
+                  if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) {
+                    if (IsCat(feature_types_, elem.column_idx)) {
+                      categories_[elem.column_idx].emplace(elem.value);
+                    } else {
+                      sketches_[elem.column_idx].Push(elem.value, w);
+                    }
+                  }
+                }
+              }
+            }
+          }
+        });
+      }
     }
     exc.Rethrow();
   }
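Why the new branch exists: the column-wise split produced by LoadBalance can occupy at most n_features threads, so on machines with many cores and narrow data some threads sit idle. The patch switches to a row-wise split in that case, but only if each thread would still receive a sizeable block of rows. A minimal sketch of the dispatch heuristic, assuming nothing beyond what the diff shows (the free-function form and the name UseRowWise are illustrative; the commit inlines this check):

#include <cstddef>

// Row-wise splitting pays off only when the columns are too few to keep
// every thread busy AND each thread would still get a block of more than
// 1024 rows, enough to amortise the per-thread sketch buffers.
bool UseRowWise(std::size_t n_rows, std::size_t n_features, std::size_t n_threads) {
  // Ceil division, same arithmetic as the patch: n / p rounded up.
  std::size_t block = n_rows / n_threads + (n_rows % n_threads > 0);
  constexpr std::size_t kMinBlock = 1024;  // min_ridx_block_size in the patch
  return n_features < n_threads && block > kMinBlock;
}

For example, 10 features on 16 threads over 1,000,000 rows gives a 62,500-row block and takes the row-wise path; the same data on 8 threads stays column-wise, since every thread can own at least one feature.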

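The row-wise branch itself is an accumulate-then-merge reduction: every thread fills a private strip of per-feature accumulators, the omp barrier ends the accumulation phase, and the merge phase re-partitions the work by feature so that no two threads ever write the same output slot. The standalone sketch below shows only that pattern; Accumulator, RowWiseAccumulate, and the dense double matrix are illustrative stand-ins for the WQSketch machinery, not XGBoost's API:

#include <omp.h>

#include <algorithm>
#include <cstddef>
#include <vector>

// Stand-in for a mergeable per-feature accumulator (WQSketch in the patch).
struct Accumulator {
  double sum{0.0};
  void Push(double v) { sum += v; }
  void Merge(Accumulator const &other) { sum += other.sum; }
};

void RowWiseAccumulate(std::vector<std::vector<double>> const &rows,
                       std::size_t n_features, int n_threads,
                       std::vector<Accumulator> *out) {
  out->assign(n_features, Accumulator{});
  std::size_t const p = static_cast<std::size_t>(n_threads);
  std::size_t const n_rows = rows.size();
  // Ceil division, as in the patch: each thread gets at most row_block rows.
  std::size_t const row_block = n_rows / p + (n_rows % p > 0);
  std::vector<Accumulator> buff(p * n_features);  // one private strip per thread

#pragma omp parallel num_threads(n_threads)
  {
    auto tid = static_cast<std::size_t>(omp_get_thread_num());
    Accumulator *local = buff.data() + tid * n_features;

    // Phase 1: consume a contiguous block of rows into the private strip;
    // nothing is shared between threads here, so no locks are needed.
    std::size_t const ridx_begin = tid * row_block;
    std::size_t const ridx_end = std::min(ridx_begin + row_block, n_rows);
    for (std::size_t ridx = ridx_begin; ridx < ridx_end; ++ridx) {
      for (std::size_t f = 0; f < n_features; ++f) {
        local[f].Push(rows[ridx][f]);
      }
    }
#pragma omp barrier

    // Phase 2: re-partition by feature; thread tid alone reduces all the
    // strips for its features, so writes to *out never collide.
    std::size_t const f_block = n_features / p + (n_features % p > 0);
    std::size_t const f_begin = tid * f_block;
    std::size_t const f_end = std::min(f_begin + f_block, n_features);
    for (std::size_t f = f_begin; f < f_end; ++f) {
      for (std::size_t t = 0; t < p; ++t) {
        (*out)[f].Merge(buff[t * n_features + f]);
      }
    }
  }
}

The design trades memory for synchronisation: the only sync point is the single barrier between the phases, at the price of n_threads × n_features temporary accumulators, which is exactly why the patch gates this path on a small feature count.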
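One more detail from the diff worth noting: each thread-local sketch is initialised with the same error budget as its global counterpart, eps = 1.0 / (n_bins * WQSketch::kFactor). A worked example of that arithmetic, assuming kFactor = 8 purely for illustration (the real constant is defined elsewhere in quantile.h):

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  std::uint64_t max_bins = 256;           // a typical hist bin budget
  std::uint64_t column_size = 1'000'000;  // valid entries in one column
  auto n_bins = std::min(max_bins, column_size);
  float constexpr kFactor = 8.0f;         // assumed value, for illustration only
  float eps = 1.0f / (static_cast<float>(n_bins) * kFactor);
  std::cout << "eps = " << eps << '\n';   // 1 / 2048, about 4.9e-4
}

Smaller eps means a larger, more precise summary; tying it to n_bins keeps each per-thread sketch just precise enough for the histogram cuts that will eventually be derived from the merged result.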