Skip to content

Commit 688c2f5

Browse files
authored
Use parallel sort for quantile calculation when appropriate. (dmlc#11275)
- Drop the duplicated `omp_in_parallel` check. - Use parallel sort instead of parallel leaf values based on a heuristic.
1 parent 5006fe7 commit 688c2f5

File tree

5 files changed

+60
-31
lines changed

5 files changed

+60
-31
lines changed

src/common/algorithm.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
/**
2-
* Copyright 2022-2023 by XGBoost Contributors
2+
* Copyright 2022-2025, XGBoost Contributors
33
*/
44
#ifndef XGBOOST_COMMON_ALGORITHM_H_
55
#define XGBOOST_COMMON_ALGORITHM_H_
66
#include <algorithm> // upper_bound, stable_sort, sort, max
7-
#include <cinttypes> // size_t
7+
#include <cstddef> // size_t
88
#include <functional> // less
99
#include <iterator> // iterator_traits, distance
1010
#include <vector> // vector
@@ -16,6 +16,9 @@
1616
#if defined(__GNUC__) && (__GNUC__ >= 4) && !defined(__sun) && !defined(sun) && \
1717
!defined(__APPLE__) && __has_include(<omp.h>) && __has_include(<parallel/algorithm>)
1818
#define GCC_HAS_PARALLEL 1
19+
constexpr bool kHasParallelStableSort = true;
20+
#else
21+
constexpr bool kHasParallelStableSort = false;
1922
#endif // GLIC_VERSION
2023

2124
#if defined(_MSC_VER) && !defined(__INTEL_COMPILER)

src/common/stats.h

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2022-2024, XGBoost Contributors
2+
* Copyright 2022-2025, XGBoost Contributors
33
*/
44
#ifndef XGBOOST_COMMON_STATS_H_
55
#define XGBOOST_COMMON_STATS_H_
@@ -32,8 +32,9 @@ namespace common {
3232
*
3333
* \return The result of interpolation.
3434
*/
35-
template <typename Iter>
36-
float Quantile(Context const* ctx, double alpha, Iter const& begin, Iter const& end) {
35+
template <typename Iter,
36+
typename R = std::remove_reference_t<typename std::iterator_traits<Iter>::value_type>>
37+
[[nodiscard]] R Quantile(Context const* ctx, double alpha, Iter const& begin, Iter const& end) {
3738
CHECK(alpha >= 0 && alpha <= 1);
3839
auto n = static_cast<double>(std::distance(begin, end));
3940
if (n == 0) {
@@ -42,15 +43,12 @@ float Quantile(Context const* ctx, double alpha, Iter const& begin, Iter const&
4243

4344
std::vector<std::size_t> sorted_idx(n);
4445
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
45-
if (omp_in_parallel()) {
46-
std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
47-
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
48-
} else {
49-
StableSort(ctx, sorted_idx.begin(), sorted_idx.end(),
50-
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
51-
}
46+
StableSort(ctx, sorted_idx.begin(), sorted_idx.end(),
47+
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
5248

53-
auto val = [&](size_t i) { return *(begin + sorted_idx[i]); };
49+
auto val = [&](size_t i) {
50+
return *(begin + sorted_idx[i]);
51+
};
5452
static_assert(std::is_same_v<decltype(val(0)), float>);
5553

5654
if (alpha <= (1 / (n + 1))) {
@@ -77,23 +75,22 @@ float Quantile(Context const* ctx, double alpha, Iter const& begin, Iter const&
7775
* See https://aakinshin.net/posts/weighted-quantiles/ for some discussions on computing
7876
* weighted quantile with interpolation.
7977
*/
80-
template <typename Iter, typename WeightIter>
81-
float WeightedQuantile(Context const* ctx, double alpha, Iter begin, Iter end, WeightIter w_begin) {
78+
template <typename Iter, typename WeightIter,
79+
typename R = std::remove_reference_t<typename std::iterator_traits<Iter>::value_type>>
80+
[[nodiscard]] R WeightedQuantile(Context const* ctx, double alpha, Iter begin, Iter end,
81+
WeightIter w_begin) {
8282
auto n = static_cast<double>(std::distance(begin, end));
8383
if (n == 0) {
8484
return std::numeric_limits<float>::quiet_NaN();
8585
}
8686
std::vector<size_t> sorted_idx(n);
8787
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
88-
if (omp_in_parallel()) {
89-
std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
90-
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
91-
} else {
92-
StableSort(ctx, sorted_idx.begin(), sorted_idx.end(),
93-
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
94-
}
88+
StableSort(ctx, sorted_idx.begin(), sorted_idx.end(),
89+
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
9590

96-
auto val = [&](size_t i) { return *(begin + sorted_idx[i]); };
91+
auto val = [&](size_t i) {
92+
return *(begin + sorted_idx[i]);
93+
};
9794

9895
std::vector<float> weight_cdf(n); // S_n
9996
// weighted cdf is sorted during construction

src/common/threading_utils.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2019-2024, XGBoost Contributors
2+
* Copyright 2019-2025, XGBoost Contributors
33
*/
44
#ifndef XGBOOST_COMMON_THREADING_UTILS_H_
55
#define XGBOOST_COMMON_THREADING_UTILS_H_
@@ -14,6 +14,7 @@
1414
#include <new> // for bad_alloc
1515
#include <thread> // for thread
1616
#include <type_traits> // for is_signed, conditional_t, is_integral_v, invoke_result_t
17+
#include <utility> // for forward
1718
#include <vector> // for vector
1819

1920
#include "xgboost/logging.h"
@@ -181,7 +182,15 @@ struct Sched {
181182
};
182183

183184
template <typename Index, typename Func>
184-
void ParallelFor(Index size, int32_t n_threads, Sched sched, Func fn) {
185+
void ParallelFor(Index size, std::int32_t n_threads, Sched sched, Func&& fn) {
186+
if (n_threads == 1) {
187+
// early exit
188+
for (Index i = 0; i < size; ++i) {
189+
fn(i);
190+
}
191+
return;
192+
}
193+
185194
#if defined(_MSC_VER)
186195
// msvc doesn't support unsigned integer as openmp index.
187196
using OmpInd = std::conditional_t<std::is_signed<Index>::value, Index, omp_ulong>;
@@ -240,8 +249,8 @@ void ParallelFor(Index size, int32_t n_threads, Sched sched, Func fn) {
240249
}
241250

242251
template <typename Index, typename Func>
243-
void ParallelFor(Index size, int32_t n_threads, Func fn) {
244-
ParallelFor(size, n_threads, Sched::Static(), fn);
252+
void ParallelFor(Index size, std::int32_t n_threads, Func&& fn) {
253+
ParallelFor(size, n_threads, Sched::Static(), std::forward<Func>(fn));
245254
}
246255

247256
inline std::int32_t OmpGetThreadLimit() {

src/objective/adaptive.cc

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2022-2024, XGBoost Contributors
2+
* Copyright 2022-2025, XGBoost Contributors
33
*/
44
#include "adaptive.h"
55

@@ -104,10 +104,29 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
104104
auto h_predt = linalg::MakeTensorView(ctx, predt.ConstHostSpan(), info.num_row_,
105105
predt.Size() / info.num_row_);
106106

107+
// A heuristic to use parallel sort. If we use multiple threads here, the sorting is
108+
// performed using a single thread as openmp cannot allocate new threads inside a
109+
// parallel region.
110+
std::int32_t n_threads;
111+
if constexpr (kHasParallelStableSort) {
112+
CHECK_GE(h_node_ptr.size(), 1);
113+
auto it = common::MakeIndexTransformIter(
114+
[&](std::size_t i) { return h_node_ptr[i + 1] - h_node_ptr[i]; });
115+
n_threads = std::any_of(it, it + h_node_ptr.size() - 1,
116+
[](auto n) {
117+
constexpr std::size_t kNeedParallelSort = 1ul << 19;
118+
return n > kNeedParallelSort;
119+
})
120+
? 1
121+
: ctx->Threads();
122+
} else {
123+
n_threads = ctx->Threads();
124+
}
125+
107126
collective::ApplyWithLabels(
108127
ctx, info, static_cast<void*>(quantiles.data()), quantiles.size() * sizeof(float), [&] {
109128
// loop over each leaf
110-
common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
129+
common::ParallelFor(quantiles.size(), n_threads, [&](size_t k) {
111130
auto nidx = h_node_idx[k];
112131
CHECK(tree[nidx].IsLeaf());
113132
CHECK_LT(k + 1, h_node_ptr.size());

tests/cpp/common/test_threading_utils.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
/**
22
* Copyright 2019-2024, XGBoost Contributors
33
*/
4+
#include <dmlc/omp.h> // for omp_in_parallel
45
#include <gtest/gtest.h>
56

6-
#include <cstddef> // std::size_t
7+
#include <cstddef> // for std::size_t
78

89
#include "../../../src/common/threading_utils.h" // BlockedSpace2d,ParallelFor2d,ParallelFor
9-
#include "dmlc/omp.h" // omp_in_parallel
1010
#include "xgboost/context.h" // Context
1111

1212
namespace xgboost::common {
@@ -99,6 +99,7 @@ TEST(ParallelFor, Basic) {
9999
ASSERT_LT(i, n);
100100
});
101101
ASSERT_FALSE(omp_in_parallel());
102+
ParallelFor(n, 1, [&](auto) { ASSERT_FALSE(omp_in_parallel()); });
102103
}
103104

104105
TEST(OmpGetNumThreads, Max) {

0 commit comments

Comments
 (0)