Skip to content

Commit 755f6c9

Browse files
<algorithm>: Rework internal buffers for stable sorting algorithms (#5807)
Co-authored-by: Stephan T. Lavavej <[email protected]>
1 parent b439695 commit 755f6c9

File tree

5 files changed

+150
-24
lines changed

5 files changed

+150
-24
lines changed

stl/inc/algorithm

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -449,17 +449,17 @@ constexpr ptrdiff_t _Temporary_buffer_size(const _Diff _Value) noexcept {
449449
}
450450

451451
template <class _Ty>
452-
struct _Optimistic_temporary_buffer { // temporary storage with _alloca-like attempt
452+
struct _Optimistic_temporary_buffer2 { // temporary storage with _alloca-like attempt
453453
static constexpr size_t _Optimistic_size = 4096; // default to ~1 page
454454
static constexpr size_t _Optimistic_count = (_STD max) (static_cast<size_t>(1), _Optimistic_size / sizeof(_Ty));
455455

456456
template <class _Diff>
457-
explicit _Optimistic_temporary_buffer(const _Diff _Requested_size) noexcept { // get temporary storage
457+
explicit _Optimistic_temporary_buffer2(const _Diff _Requested_size) noexcept { // get temporary storage
458458
const auto _Attempt = _Temporary_buffer_size(_Requested_size);
459459
// Since _Diff is a count of elements in a forward range, and forward iterators must denote objects in memory,
460460
// it must fit in a size_t.
461461
if (static_cast<size_t>(_Requested_size) <= _Optimistic_count) { // unconditionally engage stack space
462-
_Data = reinterpret_cast<_Ty*>(&_Stack_space[0]);
462+
_Data = reinterpret_cast<_Ty*>(_Stack_space);
463463
_Capacity = static_cast<ptrdiff_t>(_Requested_size); // in bounds due to if condition
464464
return;
465465
}
@@ -473,22 +473,22 @@ struct _Optimistic_temporary_buffer { // temporary storage with _alloca-like att
473473

474474
// less heap space than stack space, give up and use stack instead
475475
_STD _Return_temporary_buffer(_Raw.first);
476-
_Data = reinterpret_cast<_Ty*>(&_Stack_space[0]);
476+
_Data = reinterpret_cast<_Ty*>(_Stack_space);
477477
_Capacity = _Optimistic_count;
478478
}
479479

480-
_Optimistic_temporary_buffer(const _Optimistic_temporary_buffer&) = delete;
481-
_Optimistic_temporary_buffer& operator=(const _Optimistic_temporary_buffer&) = delete;
480+
_Optimistic_temporary_buffer2(const _Optimistic_temporary_buffer2&) = delete;
481+
_Optimistic_temporary_buffer2& operator=(const _Optimistic_temporary_buffer2&) = delete;
482482

483-
~_Optimistic_temporary_buffer() noexcept {
483+
~_Optimistic_temporary_buffer2() noexcept {
484484
if (static_cast<size_t>(_Capacity) > _Optimistic_count) {
485485
_STD _Return_temporary_buffer(_Data);
486486
}
487487
}
488488

489489
_Ty* _Data; // points to heap memory iff _Capacity > _Optimistic_count
490490
ptrdiff_t _Capacity;
491-
_Aligned_storage_t<sizeof(_Ty), alignof(_Ty)> _Stack_space[_Optimistic_count];
491+
alignas(_Ty) unsigned char _Stack_space[sizeof(_Ty) * _Optimistic_count];
492492
};
493493

494494
#if _HAS_CXX20
@@ -7110,7 +7110,7 @@ _BidIt _Stable_partition_unchecked(_BidIt _First, _BidIt _Last, _Pr _Pred) {
71107110
using _Diff = _Iter_diff_t<_BidIt>;
71117111
const _Diff _Temp_count = _STD distance(_First, _Last); // _Total_count - 1 since we never need to store *_Last
71127112
const _Diff _Total_count = _Temp_count + static_cast<_Diff>(1);
7113-
_Optimistic_temporary_buffer<_Iter_value_t<_BidIt>> _Temp_buf{_Temp_count};
7113+
_Optimistic_temporary_buffer2<_Iter_value_t<_BidIt>> _Temp_buf{_Temp_count};
71147114
return _STD _Stable_partition_unchecked1(_First, _Last, _Pred, _Total_count, _Temp_buf._Data, _Temp_buf._Capacity)
71157115
.first;
71167116
}
@@ -7229,7 +7229,7 @@ namespace ranges {
72297229
} while (!_STD invoke(_Pred, _STD invoke(_Proj, *_Last)));
72307230

72317231
const iter_difference_t<_It> _Temp_count = _RANGES distance(_First, _Last);
7232-
_Optimistic_temporary_buffer<iter_value_t<_It>> _Temp_buf{_Temp_count};
7232+
_Optimistic_temporary_buffer2<iter_value_t<_It>> _Temp_buf{_Temp_count};
72337233

72347234
// _Temp_count + 1 since we work on closed ranges
72357235
const auto _Total_count = static_cast<iter_difference_t<_It>>(_Temp_count + 1);
@@ -8442,7 +8442,7 @@ void inplace_merge(_BidIt _First, _BidIt _Mid, _BidIt _Last, _Pr _Pred) {
84428442
}
84438443

84448444
const _Diff _Count2 = _STD distance(_UMid, _ULast);
8445-
_Optimistic_temporary_buffer<_Iter_value_t<_BidIt>> _Temp_buf{(_STD min) (_Count1, _Count2)};
8445+
_Optimistic_temporary_buffer2<_Iter_value_t<_BidIt>> _Temp_buf{(_STD min) (_Count1, _Count2)};
84468446
_STD _Buffered_inplace_merge_unchecked_impl(
84478447
_UFirst, _UMid, _ULast, _Count1, _Count2, _Temp_buf._Data, _Temp_buf._Capacity, _STD _Pass_fn(_Pred));
84488448
}
@@ -8793,7 +8793,7 @@ namespace ranges {
87938793
}
87948794

87958795
const iter_difference_t<_It> _Count2 = _RANGES distance(_Mid, _Last);
8796-
_Optimistic_temporary_buffer<iter_value_t<_It>> _Temp_buf{(_STD min) (_Count1, _Count2)};
8796+
_Optimistic_temporary_buffer2<iter_value_t<_It>> _Temp_buf{(_STD min) (_Count1, _Count2)};
87978797
if (_Count1 <= _Count2 && _Count1 <= _Temp_buf._Capacity) {
87988798
_RANGES _Inplace_merge_buffer_left(_STD move(_First), _STD move(_Mid), _STD move(_Last),
87998799
_Temp_buf._Data, _Temp_buf._Capacity, _Pred, _Proj);
@@ -9403,7 +9403,7 @@ void stable_sort(const _BidIt _First, const _BidIt _Last, _Pr _Pred) {
94039403
return;
94049404
}
94059405

9406-
_Optimistic_temporary_buffer<_Iter_value_t<_BidIt>> _Temp_buf{_Count - _Count / 2};
9406+
_Optimistic_temporary_buffer2<_Iter_value_t<_BidIt>> _Temp_buf{_Count - _Count / 2};
94079407
_STD _Stable_sort_unchecked(_UFirst, _ULast, _Count, _Temp_buf._Data, _Temp_buf._Capacity, _STD _Pass_fn(_Pred));
94089408
}
94099409

@@ -9470,7 +9470,7 @@ namespace ranges {
94709470
return;
94719471
}
94729472

9473-
_Optimistic_temporary_buffer<iter_value_t<_It>> _Temp_buf{_Count - _Count / 2};
9473+
_Optimistic_temporary_buffer2<iter_value_t<_It>> _Temp_buf{_Count - _Count / 2};
94749474
_Stable_sort_common_buffered(
94759475
_STD move(_First), _STD move(_Last), _Count, _Temp_buf._Data, _Temp_buf._Capacity, _Pred, _Proj);
94769476
}

stl/inc/execution

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2785,14 +2785,14 @@ void sort(_ExPo&&, const _RanIt _First, const _RanIt _Last, _Pr _Pred) noexcept
27852785
}
27862786

27872787
template <class _Ty>
2788-
struct _Static_partitioned_temporary_buffer2 {
2789-
_Optimistic_temporary_buffer<_Ty>& _Temp_buf;
2788+
struct _Static_partitioned_temporary_buffer3 {
2789+
_Optimistic_temporary_buffer2<_Ty>& _Temp_buf;
27902790
ptrdiff_t _Chunk_size;
27912791
ptrdiff_t _Unchunked_items;
27922792

27932793
template <class _Diff>
2794-
explicit _Static_partitioned_temporary_buffer2(
2795-
_Optimistic_temporary_buffer<_Ty>& _Temp_buf_raw, _Static_partition_team<_Diff>& _Team)
2794+
explicit _Static_partitioned_temporary_buffer3(
2795+
_Optimistic_temporary_buffer2<_Ty>& _Temp_buf_raw, _Static_partition_team<_Diff>& _Team)
27962796
: _Temp_buf(_Temp_buf_raw), _Chunk_size(static_cast<ptrdiff_t>(_Temp_buf._Capacity / _Team._Chunks)),
27972797
_Unchunked_items(static_cast<ptrdiff_t>(_Temp_buf._Capacity % _Team._Chunks)) {}
27982798

@@ -2899,15 +2899,15 @@ struct _Bottom_up_tree_visitor {
28992899
};
29002900

29012901
template <class _BidIt, class _Pr>
2902-
struct _Static_partitioned_stable_sort3 {
2902+
struct _Static_partitioned_stable_sort4 {
29032903
using _Diff = _Iter_diff_t<_BidIt>;
29042904
_Static_partition_team<_Diff> _Team;
29052905
_Static_partition_range<_BidIt> _Basis;
29062906
_Bottom_up_merge_tree _Merge_tree;
2907-
_Static_partitioned_temporary_buffer2<_Iter_value_t<_BidIt>> _Temp_buf;
2907+
_Static_partitioned_temporary_buffer3<_Iter_value_t<_BidIt>> _Temp_buf;
29082908
_Pr _Pred;
29092909

2910-
_Static_partitioned_stable_sort3(_Optimistic_temporary_buffer<_Iter_value_t<_BidIt>>& _Temp_buf_raw,
2910+
_Static_partitioned_stable_sort4(_Optimistic_temporary_buffer2<_Iter_value_t<_BidIt>>& _Temp_buf_raw,
29112911
const _Diff _Count, const size_t _Merge_tree_height_, const _BidIt _First, _Pr _Pred_)
29122912
: _Team(_Count, static_cast<size_t>(1) << _Merge_tree_height_), _Basis{}, _Merge_tree(_Merge_tree_height_),
29132913
_Temp_buf(_Temp_buf_raw, _Team), _Pred{_Pred_} {
@@ -3001,7 +3001,7 @@ struct _Static_partitioned_stable_sort3 {
30013001

30023002
static void __stdcall _Threadpool_callback(
30033003
__std_PTP_CALLBACK_INSTANCE, void* const _Context, __std_PTP_WORK) noexcept /* terminates */ {
3004-
_STD _Run_available_chunked_work(*static_cast<_Static_partitioned_stable_sort3*>(_Context));
3004+
_STD _Run_available_chunked_work(*static_cast<_Static_partitioned_stable_sort4*>(_Context));
30053005
}
30063006
};
30073007

@@ -3027,14 +3027,14 @@ void stable_sort(_ExPo&&, const _BidIt _First, const _BidIt _Last, _Pr _Pred) no
30273027
_Attempt_parallelism = false;
30283028
}
30293029

3030-
_Optimistic_temporary_buffer<_Iter_value_t<_BidIt>> _Temp_buf{_Attempt_parallelism ? _Count : _Count - _Count / 2};
3030+
_Optimistic_temporary_buffer2<_Iter_value_t<_BidIt>> _Temp_buf{_Attempt_parallelism ? _Count : _Count - _Count / 2};
30313031
if constexpr (remove_reference_t<_ExPo>::_Parallelize) {
30323032
if (_Attempt_parallelism) {
30333033
// forward+ iterator overflow assumption for size_t cast
30343034
const auto _Tree_height = _Get_stable_sort_tree_height(static_cast<size_t>(_Count), _Hw_threads);
30353035
if (_Tree_height != 0) {
30363036
_TRY_BEGIN
3037-
_Static_partitioned_stable_sort3 _Operation{
3037+
_Static_partitioned_stable_sort4 _Operation{
30383038
_Temp_buf, _Count, _Tree_height, _UFirst, _STD _Pass_fn(_Pred)};
30393039
_STD _Run_chunked_parallel_work(_Hw_threads, _Operation);
30403040
return;

tests/std/test.lst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ tests\GH_005472_do_not_overlap
272272
tests\GH_005546_containers_size_type_cast
273273
tests\GH_005553_regex_character_translation
274274
tests\GH_005768_pow_accuracy
275+
tests\GH_005800_stable_sort_large_alignment
275276
tests\LWG2381_num_get_floating_point
276277
tests\LWG2510_tag_classes
277278
tests\LWG2597_complex_branch_cut
tests/std/tests/GH_005800_stable_sort_large_alignment/env.lst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
RUNALL_INCLUDE ..\impure_matrix.lst
tests/std/tests/GH_005800_stable_sort_large_alignment/test.cpp

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#pragma warning(disable : 6262) // Function uses '16388' bytes of stack.
5+
6+
#include <algorithm>
7+
#include <array>
8+
#include <cassert>
9+
#include <cstddef>
10+
#include <cstdint>
11+
#include <iterator>
12+
13+
#if _HAS_CXX17
14+
#include <execution>
15+
#endif // _HAS_CXX17
16+
17+
using namespace std;
18+
19+
// Test element type that is over-aligned to N bytes (via alignas(N)) and carries N bytes of payload.
// All comparisons are defined purely in terms of the `elems` array, so value-initialized
// elements compare equal to each other.
template <size_t N>
20+
struct alignas(N) large_element {
21+
array<unsigned char, N> elems;
22+
23+
#if _HAS_CXX20
24+
// C++20: a single defaulted three-way comparison replaces the hand-written operators below.
friend auto operator<=>(const large_element&, const large_element&) = default;
25+
#else // ^^^ _HAS_CXX20 / !_HAS_CXX20 vvv
26+
// Pre-C++20: spell out each relational operator, forwarding to the array's own comparison.
friend bool operator==(const large_element& lhs, const large_element& rhs) {
27+
return lhs.elems == rhs.elems;
28+
}
29+
30+
friend bool operator!=(const large_element& lhs, const large_element& rhs) {
31+
return lhs.elems != rhs.elems;
32+
}
33+
34+
friend bool operator<(const large_element& lhs, const large_element& rhs) {
35+
return lhs.elems < rhs.elems;
36+
}
37+
38+
friend bool operator>(const large_element& lhs, const large_element& rhs) {
39+
return lhs.elems > rhs.elems;
40+
}
41+
42+
friend bool operator<=(const large_element& lhs, const large_element& rhs) {
43+
return lhs.elems <= rhs.elems;
44+
}
45+
46+
friend bool operator>=(const large_element& lhs, const large_element& rhs) {
47+
return lhs.elems >= rhs.elems;
48+
}
49+
#endif // ^^^ !_HAS_CXX20 ^^^
50+
};
51+
52+
// Binary comparison functor: behaves like `t < u`, but first asserts that both arguments
// live at addresses that are multiples of their types' alignment. Because the algorithms under
// test may invoke the comparator on elements held in their internal buffers, a misaligned
// buffer trips these asserts.
struct alignment_verifying_less {
53+
template <class T, class U>
54+
bool operator()(const T& t, const U& u) const {
55+
// The address, viewed as an integer, must be divisible by alignof(T) / alignof(U).
assert(reinterpret_cast<uintptr_t>(&t) % alignof(T) == 0);
56+
assert(reinterpret_cast<uintptr_t>(&u) % alignof(U) == 0);
57+
return t < u;
58+
}
59+
};
60+
61+
// Unary predicate that always returns true, after asserting that its argument's address
// is a multiple of alignof(T). Used as the predicate for the stable_partition calls.
struct alignment_verifying_truth {
62+
template <class T>
63+
bool operator()(const T& t) const {
64+
// The address, viewed as an integer, must be divisible by alignof(T).
assert(reinterpret_cast<uintptr_t>(&t) % alignof(T) == 0);
65+
return true;
66+
}
67+
};
68+
69+
// Runs stable_sort, stable_partition, and inplace_merge over a two-element array of
// large_element<N>, using the alignment-verifying callables so that any misaligned element
// passed to them fails an assert. Covers the classic overloads always, the execution-policy
// overloads when _HAS_CXX17, and the ranges overloads when _HAS_CXX20.
template <size_t N>
70+
void test() {
71+
{
72+
// Value-initialized, so both elements compare equal.
large_element<N> arr[2]{};
73+
74+
stable_sort(begin(arr), end(arr), alignment_verifying_less{});
75+
stable_partition(begin(arr), end(arr), alignment_verifying_truth{});
76+
// The middle iterator is begin(arr), i.e. the first sorted subrange is empty.
inplace_merge(begin(arr), begin(arr), end(arr), alignment_verifying_less{});
77+
}
78+
79+
#if _HAS_CXX17
80+
// Same three algorithms, parameterized on the execution policy object.
auto test_execution = [](const auto& execpol) {
81+
large_element<N> arr[2]{};
82+
83+
stable_sort(execpol, begin(arr), end(arr), alignment_verifying_less{});
84+
stable_partition(execpol, begin(arr), end(arr), alignment_verifying_truth{});
85+
inplace_merge(execpol, begin(arr), begin(arr), end(arr), alignment_verifying_less{});
86+
};
87+
test_execution(execution::seq);
88+
test_execution(execution::par);
89+
test_execution(execution::par_unseq);
90+
#if _HAS_CXX20
91+
// execution::unseq is only exercised when _HAS_CXX20.
test_execution(execution::unseq);
92+
#endif // _HAS_CXX20
93+
#endif // _HAS_CXX17
94+
95+
#if _HAS_CXX20
96+
{
97+
large_element<N> arr[2]{};
98+
99+
// Range-based overloads; inplace_merge again receives the beginning as the middle.
ranges::stable_sort(arr, alignment_verifying_less{});
100+
ranges::stable_partition(arr, alignment_verifying_truth{});
101+
ranges::inplace_merge(arr, ranges::begin(arr), alignment_verifying_less{});
102+
}
103+
#endif // _HAS_CXX20
104+
}
105+
106+
// Instantiates the test for every power-of-two element size/alignment from 1 through 8192 bytes.
// NOTE(review): the range presumably straddles the library's 4096-byte optimistic stack buffer
// on purpose (elements smaller than, equal to, and larger than it) — confirm against the
// buffer implementation.
int main() {
107+
test<1>();
108+
test<2>();
109+
test<4>();
110+
test<8>();
111+
test<16>();
112+
test<32>();
113+
test<64>();
114+
test<128>();
115+
test<256>();
116+
test<512>();
117+
test<1024>();
118+
test<2048>();
119+
test<4096>();
120+
test<8192>();
121+
}

0 commit comments

Comments
 (0)