Skip to content

Commit 5620dce

Browse files
committed
[libc++] Add input validation for set_intersection() in debug mode.
The use of one-sided binary search introduced by a066217 changes behaviour on invalid, unsorted input (see llvm#75230 (comment)). Add input validation on `_LIBCPP_HARDENING_MODE_DEBUG` to help users. * Change interface of `__is_sorted_until()` so that it accepts a sentinel that's of a different type than the beginning iterator, and to ensure it won't try to copy the comparison function object. * Add one assertion for each input range confirming that they are sorted. * Stop validating complexity of `set_intersection()` in debug mode, it's hopeless and also not meaningful: there are no complexity guarantees in debug mode, we're happy to trade performance for diagnosability. * Fix bugs in `ranges_robust_against_differing_projections.pass`: we were using an input range as output for `std::ranges::partial_sort_copy()`, and using projections which return the opposite value means that algorithms requiring a sorted range can't use ranges sorted with ascending values if the comparator is `std::ranges::less`. Added `const` where appropriate to make sure we weren't using inputs as outputs in other places.
1 parent ef67664 commit 5620dce

File tree

4 files changed

+61
-49
lines changed

4 files changed

+61
-49
lines changed

libcxx/include/__algorithm/is_sorted_until.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,18 @@
2020

2121
_LIBCPP_BEGIN_NAMESPACE_STD
2222

23-
template <class _Compare, class _ForwardIterator>
23+
template <class _Compare, class _ForwardIterator, class _Sent>
2424
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _ForwardIterator
25-
__is_sorted_until(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) {
25+
__is_sorted_until(_ForwardIterator __first, _Sent __last, _Compare&& __comp) {
2626
if (__first != __last) {
2727
_ForwardIterator __i = __first;
28-
while (++__i != __last) {
29-
if (__comp(*__i, *__first))
30-
return __i;
31-
__first = __i;
28+
while (++__first != __last) {
29+
if (__comp(*__first, *__i))
30+
return __first;
31+
__i = __first;
3232
}
3333
}
34-
return __last;
34+
return __first;
3535
}
3636

3737
template <class _ForwardIterator, class _Compare>

libcxx/include/__algorithm/set_intersection.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111

1212
#include <__algorithm/comp.h>
1313
#include <__algorithm/comp_ref_type.h>
14+
#include <__algorithm/is_sorted_until.h>
1415
#include <__algorithm/iterator_operations.h>
1516
#include <__algorithm/lower_bound.h>
17+
#include <__assert>
1618
#include <__config>
1719
#include <__functional/identity.h>
1820
#include <__iterator/iterator_traits.h>
1921
#include <__iterator/next.h>
22+
#include <__type_traits/is_constant_evaluated.h>
2023
#include <__type_traits/is_same.h>
2124
#include <__utility/exchange.h>
2225
#include <__utility/move.h>
@@ -95,6 +98,14 @@ __set_intersection(
9598
_Compare&& __comp,
9699
std::forward_iterator_tag,
97100
std::forward_iterator_tag) {
101+
#if _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG
102+
if (!__libcpp_is_constant_evaluated()) {
103+
_LIBCPP_ASSERT_INTERNAL(
104+
std::__is_sorted_until(__first1, __last1, __comp) == __last1, "set_intersection: input range 1 must be sorted");
105+
_LIBCPP_ASSERT_INTERNAL(
106+
std::__is_sorted_until(__first2, __last2, __comp) == __last2, "set_intersection: input range 2 must be sorted");
107+
}
108+
#endif
98109
_LIBCPP_CONSTEXPR std::__identity __proj;
99110
bool __prev_may_be_equal = false;
100111

libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/set_intersection_complexity.pass.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -43,33 +43,32 @@
4343

4444
#include "test_iterators.h"
4545

46-
namespace {
47-
48-
// __debug_less will perform an additional comparison in an assertion
49-
static constexpr unsigned std_less_comparison_count_multiplier() noexcept {
50-
#if _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG
51-
return 2;
46+
// debug mode provides no complexity guarantees, testing them would be a waste of effort
47+
// but we still want to run this test, to ensure we don't trigger any assertions
48+
#ifdef _LIBCPP_HARDENING_MODE_DEBUG
49+
# define ASSERT_COMPLEXITY(expression)
5250
#else
53-
return 1;
51+
# define ASSERT_COMPLEXITY(expression) assert(expression)
5452
#endif
55-
}
53+
54+
namespace {
5655

5756
struct [[nodiscard]] OperationCounts {
5857
std::size_t comparisons{};
5958
struct PerInput {
6059
std::size_t proj{};
6160
IteratorOpCounts iterops;
6261

63-
[[nodiscard]] constexpr bool isNotBetterThan(const PerInput& other) {
62+
[[nodiscard]] constexpr bool isNotBetterThan(const PerInput& other) const noexcept {
6463
return proj >= other.proj && iterops.increments + iterops.decrements + iterops.zero_moves >=
6564
other.iterops.increments + other.iterops.decrements + other.iterops.zero_moves;
6665
}
6766
};
6867
std::array<PerInput, 2> in;
6968

70-
[[nodiscard]] constexpr bool isNotBetterThan(const OperationCounts& expect) {
71-
return std_less_comparison_count_multiplier() * comparisons >= expect.comparisons &&
72-
in[0].isNotBetterThan(expect.in[0]) && in[1].isNotBetterThan(expect.in[1]);
69+
[[nodiscard]] constexpr bool isNotBetterThan(const OperationCounts& expect) const noexcept {
70+
return comparisons >= expect.comparisons && in[0].isNotBetterThan(expect.in[0]) &&
71+
in[1].isNotBetterThan(expect.in[1]);
7372
}
7473
};
7574

@@ -80,16 +79,17 @@ struct counted_set_intersection_result {
8079

8180
constexpr counted_set_intersection_result() = default;
8281

83-
constexpr explicit counted_set_intersection_result(std::array<int, ResultSize>&& contents) : result{contents} {}
82+
constexpr explicit counted_set_intersection_result(std::array<int, ResultSize>&& contents) noexcept
83+
: result{contents} {}
8484

85-
constexpr void assertNotBetterThan(const counted_set_intersection_result& other) {
85+
constexpr void assertNotBetterThan(const counted_set_intersection_result& other) const noexcept {
8686
assert(result == other.result);
87-
assert(opcounts.isNotBetterThan(other.opcounts));
87+
ASSERT_COMPLEXITY(opcounts.isNotBetterThan(other.opcounts));
8888
}
8989
};
9090

9191
template <std::size_t ResultSize>
92-
counted_set_intersection_result(std::array<int, ResultSize>) -> counted_set_intersection_result<ResultSize>;
92+
counted_set_intersection_result(std::array<int, ResultSize>) noexcept -> counted_set_intersection_result<ResultSize>;
9393

9494
template <template <class...> class InIterType1,
9595
template <class...>
@@ -306,7 +306,7 @@ constexpr bool testComplexityBasic() {
306306
std::array<int, 5> r2{2, 4, 6, 8, 10};
307307
std::array<int, 0> expected{};
308308

309-
const std::size_t maxOperation = std_less_comparison_count_multiplier() * (2 * (r1.size() + r2.size()) - 1);
309+
[[maybe_unused]] const std::size_t maxOperation = 2 * (r1.size() + r2.size()) - 1;
310310

311311
// std::set_intersection
312312
{
@@ -321,7 +321,7 @@ constexpr bool testComplexityBasic() {
321321
std::set_intersection(r1.begin(), r1.end(), r2.begin(), r2.end(), out.data(), comp);
322322

323323
assert(std::ranges::equal(out, expected));
324-
assert(numberOfComp <= maxOperation);
324+
ASSERT_COMPLEXITY(numberOfComp <= maxOperation);
325325
}
326326

327327
// ranges::set_intersection iterator overload
@@ -349,9 +349,9 @@ constexpr bool testComplexityBasic() {
349349
std::ranges::set_intersection(r1.begin(), r1.end(), r2.begin(), r2.end(), out.data(), comp, proj1, proj2);
350350

351351
assert(std::ranges::equal(out, expected));
352-
assert(numberOfComp <= maxOperation);
353-
assert(numberOfProj1 <= maxOperation);
354-
assert(numberOfProj2 <= maxOperation);
352+
ASSERT_COMPLEXITY(numberOfComp <= maxOperation);
353+
ASSERT_COMPLEXITY(numberOfProj1 <= maxOperation);
354+
ASSERT_COMPLEXITY(numberOfProj2 <= maxOperation);
355355
}
356356

357357
// ranges::set_intersection range overload
@@ -379,9 +379,9 @@ constexpr bool testComplexityBasic() {
379379
std::ranges::set_intersection(r1, r2, out.data(), comp, proj1, proj2);
380380

381381
assert(std::ranges::equal(out, expected));
382-
assert(numberOfComp < maxOperation);
383-
assert(numberOfProj1 < maxOperation);
384-
assert(numberOfProj2 < maxOperation);
382+
ASSERT_COMPLEXITY(numberOfComp < maxOperation);
383+
ASSERT_COMPLEXITY(numberOfProj1 < maxOperation);
384+
ASSERT_COMPLEXITY(numberOfProj2 < maxOperation);
385385
}
386386
return true;
387387
}

libcxx/test/std/algorithms/ranges_robust_against_differing_projections.pass.cpp

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,20 @@ constexpr bool test_all() {
4040
constexpr auto operator<=>(const A&) const = default;
4141
};
4242

43-
std::array in = {1, 2, 3};
44-
std::array in2 = {A{4}, A{5}, A{6}};
43+
const std::array in = {1, 2, 3};
44+
const std::array in2 = {A{4}, A{5}, A{6}};
4545

4646
std::array output = {7, 8, 9, 10, 11, 12};
4747
auto out = output.begin();
4848
std::array output2 = {A{7}, A{8}, A{9}, A{10}, A{11}, A{12}};
4949
auto out2 = output2.begin();
5050

51-
std::ranges::equal_to eq;
52-
std::ranges::less less;
53-
auto sum = [](int lhs, A rhs) { return lhs + rhs.x; };
54-
auto proj1 = [](int x) { return x * -1; };
55-
auto proj2 = [](A a) { return a.x * -1; };
51+
const std::ranges::equal_to eq;
52+
const std::ranges::less less;
53+
const std::ranges::greater greater;
54+
const auto sum = [](int lhs, A rhs) { return lhs + rhs.x; };
55+
const auto proj1 = [](int x) { return x * -1; };
56+
const auto proj2 = [](A a) { return a.x * -1; };
5657

5758
#if TEST_STD_VER >= 23
5859
test(std::ranges::ends_with, in, in2, eq, proj1, proj2);
@@ -67,17 +68,17 @@ constexpr bool test_all() {
6768
test(std::ranges::find_end, in, in2, eq, proj1, proj2);
6869
test(std::ranges::transform, in, in2, out, sum, proj1, proj2);
6970
test(std::ranges::transform, in, in2, out2, sum, proj1, proj2);
70-
test(std::ranges::partial_sort_copy, in, in2, less, proj1, proj2);
71-
test(std::ranges::merge, in, in2, out, less, proj1, proj2);
72-
test(std::ranges::merge, in, in2, out2, less, proj1, proj2);
73-
test(std::ranges::set_intersection, in, in2, out, less, proj1, proj2);
74-
test(std::ranges::set_intersection, in, in2, out2, less, proj1, proj2);
75-
test(std::ranges::set_difference, in, in2, out, less, proj1, proj2);
76-
test(std::ranges::set_difference, in, in2, out2, less, proj1, proj2);
77-
test(std::ranges::set_symmetric_difference, in, in2, out, less, proj1, proj2);
78-
test(std::ranges::set_symmetric_difference, in, in2, out2, less, proj1, proj2);
79-
test(std::ranges::set_union, in, in2, out, less, proj1, proj2);
80-
test(std::ranges::set_union, in, in2, out2, less, proj1, proj2);
71+
test(std::ranges::partial_sort_copy, in, output, less, proj1, proj2);
72+
test(std::ranges::merge, in, in2, out, greater, proj1, proj2);
73+
test(std::ranges::merge, in, in2, out2, greater, proj1, proj2);
74+
test(std::ranges::set_intersection, in, in2, out, greater, proj1, proj2);
75+
test(std::ranges::set_intersection, in, in2, out2, greater, proj1, proj2);
76+
test(std::ranges::set_difference, in, in2, out, greater, proj1, proj2);
77+
test(std::ranges::set_difference, in, in2, out2, greater, proj1, proj2);
78+
test(std::ranges::set_symmetric_difference, in, in2, out, greater, proj1, proj2);
79+
test(std::ranges::set_symmetric_difference, in, in2, out2, greater, proj1, proj2);
80+
test(std::ranges::set_union, in, in2, out, greater, proj1, proj2);
81+
test(std::ranges::set_union, in, in2, out2, greater, proj1, proj2);
8182
#if TEST_STD_VER > 20
8283
test(std::ranges::starts_with, in, in2, eq, proj1, proj2);
8384
#endif

0 commit comments

Comments
 (0)