Skip to content

Commit 36bb63e

Browse files
committed
[libc++][test] Add set_intersection complexity validation tests prior to introducing use of one-sided binary search to fast-forward over ranges of elements.
1 parent f6bcf27 commit 36bb63e

File tree

1 file changed

+234
-6
lines changed

1 file changed

+234
-6
lines changed

libcxx/test/std/algorithms/alg.sorting/alg.set.operations/set.intersection/ranges_set_intersection.pass.cpp

Lines changed: 234 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
#include <algorithm>
2929
#include <array>
3030
#include <concepts>
31+
#include <cstddef>
32+
#include <iterator>
33+
#include <type_traits>
3134

3235
#include "almost_satisfies_types.h"
3336
#include "MoveOnly.h"
@@ -93,14 +96,17 @@ static_assert(!HasSetIntersectionRange<UncheckedRange<MoveOnly*>, UncheckedRange
9396

9497
using std::ranges::set_intersection_result;
9598

99+
// TODO: std::ranges::set_intersection calls std::ranges::copy
100+
// std::ranges::copy(contiguous_iterator<int*>, sentinel_wrapper<contiguous_iterator<int*>>, contiguous_iterator<int*>) doesn't seem to work.
101+
// It seems that std::ranges::copy calls std::copy, which unwraps contiguous_iterator<int*> into int*,
102+
// and then it failed because there is no == between int* and sentinel_wrapper<contiguous_iterator<int*>>
103+
template <typename Iter>
104+
using SentinelWorkaround = std::conditional_t<std::contiguous_iterator<Iter>, Iter, sentinel_wrapper<Iter>>;
105+
96106
template <class In1, class In2, class Out, std::size_t N1, std::size_t N2, std::size_t N3>
97107
constexpr void testSetIntersectionImpl(std::array<int, N1> in1, std::array<int, N2> in2, std::array<int, N3> expected) {
98-
// TODO: std::ranges::set_intersection calls std::ranges::copy
99-
// std::ranges::copy(contiguous_iterator<int*>, sentinel_wrapper<contiguous_iterator<int*>>, contiguous_iterator<int*>) doesn't seem to work.
100-
// It seems that std::ranges::copy calls std::copy, which unwraps contiguous_iterator<int*> into int*,
101-
// and then it failed because there is no == between int* and sentinel_wrapper<contiguous_iterator<int*>>
102-
using Sent1 = std::conditional_t<std::contiguous_iterator<In1>, In1, sentinel_wrapper<In1>>;
103-
using Sent2 = std::conditional_t<std::contiguous_iterator<In2>, In2, sentinel_wrapper<In2>>;
108+
using Sent1 = SentinelWorkaround<In1>;
109+
using Sent2 = SentinelWorkaround<In2>;
104110

105111
// iterator overload
106112
{
@@ -272,6 +278,225 @@ constexpr void runAllIteratorPermutationsTests() {
272278
static_assert(withAllPermutationsOfInIter1AndInIter2<contiguous_iterator<int*>>());
273279
}
274280

281+
namespace {
282+
struct [[nodiscard]] OperationCounts {
283+
std::size_t comparisons{};
284+
struct PerInput {
285+
std::size_t proj{};
286+
std::size_t iterator_strides{};
287+
std::ptrdiff_t iterator_displacement{};
288+
289+
// IGNORES proj!
290+
[[nodiscard]] constexpr bool operator==(const PerInput& o) const {
291+
return iterator_strides == o.iterator_strides && iterator_displacement == o.iterator_displacement;
292+
}
293+
294+
[[nodiscard]] constexpr bool matchesExpectation(const PerInput& expect) {
295+
return proj <= expect.proj && iterator_strides <= expect.iterator_strides &&
296+
iterator_displacement <= expect.iterator_displacement;
297+
}
298+
};
299+
std::array<PerInput, 2> in;
300+
301+
[[nodiscard]] constexpr bool matchesExpectation(const OperationCounts& expect) {
302+
return comparisons <= expect.comparisons && in[0].matchesExpectation(expect.in[0]) &&
303+
in[1].matchesExpectation(expect.in[1]);
304+
}
305+
306+
[[nodiscard]] constexpr bool operator==(const OperationCounts& o) const {
307+
return comparisons == o.comparisons && std::ranges::equal(in, o.in);
308+
}
309+
};
310+
} // namespace
311+
312+
#include <iostream>
313+
template <template <class...> class In1,
314+
template <class...>
315+
class In2,
316+
class Out,
317+
std::size_t N1,
318+
std::size_t N2,
319+
std::size_t N3>
320+
constexpr void testSetIntersectionAndReturnOpCounts(
321+
std::array<int, N1> in1,
322+
std::array<int, N2> in2,
323+
std::array<int, N3> expected,
324+
const OperationCounts& expectedOpCounts) {
325+
OperationCounts ops;
326+
327+
const auto comp = [&ops](int x, int y) {
328+
++ops.comparisons;
329+
return x < y;
330+
};
331+
332+
std::array<int, N3> out;
333+
334+
stride_counting_iterator b1(
335+
In1<decltype(in1.begin())>(in1.begin()), &ops.in[0].iterator_strides, &ops.in[0].iterator_displacement);
336+
stride_counting_iterator e1(
337+
In1<decltype(in1.end()) >(in1.end()), &ops.in[0].iterator_strides, &ops.in[0].iterator_displacement);
338+
stride_counting_iterator b2(
339+
In2<decltype(in2.begin())>(in2.begin()), &ops.in[1].iterator_strides, &ops.in[1].iterator_displacement);
340+
stride_counting_iterator e2(
341+
In2<decltype(in2.end()) >(in2.end()), &ops.in[1].iterator_strides, &ops.in[1].iterator_displacement);
342+
343+
std::set_intersection(b1, e1, b2, e2, Out(out.data()), comp);
344+
345+
assert(std::ranges::equal(out, expected));
346+
assert(ops.matchesExpectation(expectedOpCounts));
347+
}
348+
349+
template <template <class...> class In1,
350+
template <class...>
351+
class In2,
352+
class Out,
353+
std::size_t N1,
354+
std::size_t N2,
355+
std::size_t N3>
356+
constexpr void testRangesSetIntersectionAndReturnOpCounts(
357+
std::array<int, N1> in1,
358+
std::array<int, N2> in2,
359+
std::array<int, N3> expected,
360+
const OperationCounts& expectedOpCounts) {
361+
OperationCounts ops;
362+
363+
const auto comp = [&ops](int x, int y) {
364+
++ops.comparisons;
365+
return x < y;
366+
};
367+
368+
const auto proj1 = [&ops](const int& i) {
369+
++ops.in[0].proj;
370+
return i;
371+
};
372+
373+
const auto proj2 = [&ops](const int& i) {
374+
++ops.in[1].proj;
375+
return i;
376+
};
377+
378+
std::array<int, N3> out;
379+
380+
stride_counting_iterator b1(
381+
In1<decltype(in1.begin())>(in1.begin()), &ops.in[0].iterator_strides, &ops.in[0].iterator_displacement);
382+
stride_counting_iterator e1(
383+
In1<decltype(in1.end()) >(in1.end()), &ops.in[0].iterator_strides, &ops.in[0].iterator_displacement);
384+
stride_counting_iterator b2(
385+
In2<decltype(in2.begin())>(in2.begin()), &ops.in[1].iterator_strides, &ops.in[1].iterator_displacement);
386+
stride_counting_iterator e2(
387+
In2<decltype(in2.end()) >(in2.end()), &ops.in[1].iterator_strides, &ops.in[1].iterator_displacement);
388+
389+
std::ranges::subrange r1{b1, SentinelWorkaround<decltype(e1)>{e1}};
390+
std::ranges::subrange r2{b2, SentinelWorkaround<decltype(e2)>{e2}};
391+
std::same_as<set_intersection_result<decltype(e1), decltype(e2), Out>> decltype(auto) result =
392+
std::ranges::set_intersection(r1, r2, Out{out.data()}, comp, proj1, proj2);
393+
assert(std::ranges::equal(out, expected));
394+
assert(base(result.in1) == base(e1));
395+
assert(base(result.in2) == base(e2));
396+
assert(base(result.out) == out.data() + out.size());
397+
assert(ops.matchesExpectation(expectedOpCounts));
398+
}
399+
400+
template <template <typename...> class In1, template <typename...> class In2, class Out>
401+
constexpr void testComplexityParameterizedIter() {
402+
// Worst-case complexity:
403+
// Let N=(last1 - first1) and M=(last2 - first2)
404+
// At most 2*(N+M) - 1 comparisons and applications of each projection.
405+
// At most 2*(N+M) iterator mutations.
406+
{
407+
std::array r1{1, 3, 5, 7, 9, 11, 13, 15, 17, 19};
408+
std::array r2{2, 4, 6, 8, 10, 12, 14, 16, 18, 20};
409+
std::array<int, 0> expected{};
410+
411+
OperationCounts expectedCounts;
412+
expectedCounts.comparisons = 37;
413+
expectedCounts.in[0].proj = 37;
414+
expectedCounts.in[0].iterator_strides = 30;
415+
expectedCounts.in[0].iterator_displacement = 30;
416+
expectedCounts.in[1] = expectedCounts.in[0];
417+
418+
testSetIntersectionAndReturnOpCounts<In1, In2, Out>(r1, r2, expected, expectedCounts);
419+
testRangesSetIntersectionAndReturnOpCounts<In1, In2, Out>(r1, r2, expected, expectedCounts);
420+
}
421+
422+
{
423+
std::array r1{1, 3, 5, 7, 9, 11, 13, 15, 17, 19};
424+
std::array r2{1, 3, 5, 7, 9, 11, 13, 15, 17, 19};
425+
std::array expected{1, 3, 5, 7, 9, 11, 13, 15, 17, 19};
426+
427+
OperationCounts expectedCounts;
428+
expectedCounts.comparisons = 38;
429+
expectedCounts.in[0].proj = 38;
430+
expectedCounts.in[0].iterator_strides = 30;
431+
expectedCounts.in[0].iterator_displacement = 30;
432+
expectedCounts.in[1] = expectedCounts.in[0];
433+
434+
testSetIntersectionAndReturnOpCounts<In1, In2, Out>(r1, r2, expected, expectedCounts);
435+
testRangesSetIntersectionAndReturnOpCounts<In1, In2, Out>(r1, r2, expected, expectedCounts);
436+
}
437+
438+
// Lower complexity when there is low overlap between ranges: we can make 2*log(X) comparisons when one range
439+
// has X elements that can be skipped over.
440+
{
441+
std::array r1{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
442+
std::array r2{15};
443+
std::array expected{15};
444+
445+
OperationCounts expectedCounts;
446+
expectedCounts.comparisons = 8;
447+
expectedCounts.in[0].proj = 8;
448+
expectedCounts.in[0].iterator_strides = 24;
449+
expectedCounts.in[0].iterator_displacement = 24;
450+
expectedCounts.in[1].proj = 8;
451+
expectedCounts.in[1].iterator_strides = 3;
452+
expectedCounts.in[1].iterator_displacement = 3;
453+
454+
testSetIntersectionAndReturnOpCounts<In1, In2, Out>(r1, r2, expected, expectedCounts);
455+
testRangesSetIntersectionAndReturnOpCounts<In1, In2, Out>(r1, r2, expected, expectedCounts);
456+
}
457+
458+
{
459+
std::array r1{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
460+
std::array r2{0, 16};
461+
std::array<int, 0> expected{};
462+
463+
OperationCounts expectedCounts;
464+
expectedCounts.comparisons = 10;
465+
expectedCounts.in[0].proj = 10;
466+
expectedCounts.in[0].iterator_strides = 24;
467+
expectedCounts.in[0].iterator_displacement = 24;
468+
expectedCounts.in[1].proj = 10;
469+
expectedCounts.in[1].iterator_strides = 4;
470+
expectedCounts.in[1].iterator_displacement = 4;
471+
472+
testSetIntersectionAndReturnOpCounts<In1, In2, Out>(r1, r2, expected, expectedCounts);
473+
testRangesSetIntersectionAndReturnOpCounts<In1, In2, Out>(r1, r2, expected, expectedCounts);
474+
}
475+
}
476+
477+
template <template <typename...> class In2, class Out>
478+
constexpr void testComplexityParameterizedIterPermutateIn1() {
479+
//common_input_iterator
480+
testComplexityParameterizedIter<forward_iterator, In2, Out>();
481+
testComplexityParameterizedIter<bidirectional_iterator, In2, Out>();
482+
testComplexityParameterizedIter<random_access_iterator, In2, Out>();
483+
}
484+
485+
template <class Out>
486+
constexpr void testComplexityParameterizedIterPermutateIn1In2() {
487+
testComplexityParameterizedIterPermutateIn1<forward_iterator, Out>();
488+
testComplexityParameterizedIterPermutateIn1<bidirectional_iterator, Out>();
489+
testComplexityParameterizedIterPermutateIn1<random_access_iterator, Out>();
490+
}
491+
492+
constexpr bool testComplexityMultipleTypes() {
493+
//testComplexityParameterizedIter<cpp20_input_iterator, random_access_iterator, OutIter>();
494+
testComplexityParameterizedIterPermutateIn1In2<forward_iterator<int*>>();
495+
testComplexityParameterizedIterPermutateIn1In2<bidirectional_iterator<int*>>();
496+
testComplexityParameterizedIterPermutateIn1In2<random_access_iterator<int*>>();
497+
return true;
498+
}
499+
275500
constexpr bool test() {
276501
// check that every element is copied exactly once
277502
{
@@ -572,5 +797,8 @@ int main(int, char**) {
572797
// than the step limit.
573798
runAllIteratorPermutationsTests();
574799

800+
testComplexityMultipleTypes();
801+
static_assert(testComplexityMultipleTypes());
802+
575803
return 0;
576804
}

0 commit comments

Comments
 (0)