Skip to content

Commit 3a1b07f

Browse files
committed
Optimize {std,ranges}::distance for segmented iterators
1 parent 9a913a3 commit 3a1b07f

File tree

3 files changed

+166
-31
lines changed

3 files changed

+166
-31
lines changed

libcxx/include/__iterator/distance.h

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,17 @@
1010
#ifndef _LIBCPP___ITERATOR_DISTANCE_H
1111
#define _LIBCPP___ITERATOR_DISTANCE_H
1212

13+
#include <__algorithm/for_each_segment.h>
1314
#include <__config>
1415
#include <__iterator/concepts.h>
1516
#include <__iterator/incrementable_traits.h>
1617
#include <__iterator/iterator_traits.h>
18+
#include <__iterator/segmented_iterator.h>
1719
#include <__ranges/access.h>
1820
#include <__ranges/concepts.h>
1921
#include <__ranges/size.h>
2022
#include <__type_traits/decay.h>
23+
#include <__type_traits/enable_if.h>
2124
#include <__type_traits/remove_cvref.h>
2225

2326
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -26,25 +29,40 @@
2629

2730
_LIBCPP_BEGIN_NAMESPACE_STD
2831

29-
template <class _InputIter>
32+
template <class _InputIter, class _Sent>
3033
inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX17 typename iterator_traits<_InputIter>::difference_type
31-
__distance(_InputIter __first, _InputIter __last, input_iterator_tag) {
34+
__distance(_InputIter __first, _Sent __last) {
3235
typename iterator_traits<_InputIter>::difference_type __r(0);
3336
for (; __first != __last; ++__first)
3437
++__r;
3538
return __r;
3639
}
3740

38-
template <class _RandIter>
41+
template <class _RandIter, __enable_if_t<__has_random_access_iterator_category<_RandIter>::value, int> = 0>
3942
inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX17 typename iterator_traits<_RandIter>::difference_type
40-
__distance(_RandIter __first, _RandIter __last, random_access_iterator_tag) {
43+
__distance(_RandIter __first, _RandIter __last) {
4144
return __last - __first;
4245
}
4346

47+
template <class _SegmentedIter,
48+
__enable_if_t<!__has_random_access_iterator_category<_SegmentedIter>::value &&
49+
__is_segmented_iterator<_SegmentedIter>::value,
50+
int> = 0>
51+
inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX17 typename iterator_traits<_SegmentedIter>::difference_type
52+
__distance(_SegmentedIter __first, _SegmentedIter __last) {
53+
typename iterator_traits<_SegmentedIter>::difference_type __r(0);
54+
using _Traits = __segmented_iterator_traits<_SegmentedIter>;
55+
std::__for_each_segment(
56+
__first, __last, [&__r](typename _Traits::__local_iterator __lfirst, typename _Traits::__local_iterator __llast) {
57+
__r += std::__distance(__lfirst, __llast);
58+
});
59+
return __r;
60+
}
61+
4462
template <class _InputIter>
4563
inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX17 typename iterator_traits<_InputIter>::difference_type
4664
distance(_InputIter __first, _InputIter __last) {
47-
return std::__distance(__first, __last, typename iterator_traits<_InputIter>::iterator_category());
65+
return std::__distance(__first, __last);
4866
}
4967

5068
#if _LIBCPP_STD_VER >= 20
@@ -56,12 +74,11 @@ struct __distance {
5674
template <class _Ip, sentinel_for<_Ip> _Sp>
5775
requires(!sized_sentinel_for<_Sp, _Ip>)
5876
_LIBCPP_HIDE_FROM_ABI constexpr iter_difference_t<_Ip> operator()(_Ip __first, _Sp __last) const {
59-
iter_difference_t<_Ip> __n = 0;
60-
while (__first != __last) {
61-
++__first;
62-
++__n;
77+
if constexpr (assignable_from<_Ip&, _Sp> && __is_segmented_iterator<_Ip>::value) {
78+
return std::__distance(__first, std::move(__last));
79+
} else {
80+
return std::__distance(__first, __last);
6381
}
64-
return __n;
6582
}
6683

6784
template <class _Ip, sized_sentinel_for<decay_t<_Ip>> _Sp>
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// UNSUPPORTED: c++03, c++11, c++14, c++17
10+
11+
#include <cstddef>
12+
#include <deque>
13+
#include <iterator>
14+
#include <ranges>
15+
#include <vector>
16+
17+
#include <benchmark/benchmark.h>
18+
19+
int main(int argc, char** argv) {
20+
auto std_distance = [](auto first, auto last) { return std::distance(first, last); };
21+
22+
// {std,ranges}::distance
23+
{
24+
auto bm = []<class Container>(std::string name, auto distance, std::size_t seg_size) {
25+
benchmark::RegisterBenchmark(
26+
name,
27+
[distance, seg_size](auto& st) {
28+
std::size_t const size = st.range(0);
29+
std::size_t const segments = (size + seg_size - 1) / seg_size;
30+
Container c(segments);
31+
for (std::size_t i = 0, n = size; i < segments; ++i, n -= seg_size) {
32+
c[i].resize(std::min(seg_size, n));
33+
}
34+
35+
auto view = c | std::views::join;
36+
auto first = view.begin();
37+
auto last = view.end();
38+
39+
for ([[maybe_unused]] auto _ : st) {
40+
benchmark::DoNotOptimize(c);
41+
auto result = distance(first, last);
42+
benchmark::DoNotOptimize(result);
43+
}
44+
})
45+
->Arg(50) // non power-of-two
46+
->Arg(1024)
47+
->Arg(4096)
48+
->Arg(8192)
49+
->Arg(1 << 14)
50+
->Arg(1 << 16)
51+
->Arg(1 << 18)
52+
->Arg(1 << 20);
53+
};
54+
bm.operator()<std::vector<std::vector<int>>>("std::distance(join_view(vector<vector<int>>))", std_distance, 256);
55+
bm.operator()<std::deque<std::deque<int>>>("std::distance(join_view(deque<deque<int>>))", std_distance, 256);
56+
bm.operator()<std::vector<std::vector<int>>>(
57+
"rng::distance(join_view(vector<vector<int>>)", std::ranges::distance, 256);
58+
bm.operator()<std::deque<std::deque<int>>>(
59+
"rng::distance(join_view(deque<deque<int>>)", std::ranges::distance, 256);
60+
61+
// bm.operator()<std::vector<std::vector<int>>>("std::distance(join_view(vector<vector<int>>))", std_distance, 1024);
62+
// bm.operator()<std::deque<std::deque<int>>>("std::distance(join_view(deque<deque<int>>))", std_distance, 1024);
63+
// bm.operator()<std::vector<std::vector<int>>>("rng::distance(join_view(vector<vector<int>>)", std::ranges::distance, 1024);
64+
// bm.operator()<std::deque<std::deque<int>>>("rng::distance(join_view(deque<deque<int>>)", std::ranges::distance, 1024);
65+
}
66+
67+
benchmark::Initialize(&argc, argv);
68+
benchmark::RunSpecifiedBenchmarks();
69+
benchmark::Shutdown();
70+
return 0;
71+
}

libcxx/test/std/iterators/iterator.primitives/iterator.operations/distance.pass.cpp

Lines changed: 68 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,85 @@
1616
// Iter::difference_type
1717
// distance(Iter first, Iter last); // constexpr in C++17
1818

19-
#include <iterator>
19+
#include <array>
2020
#include <cassert>
21+
#include <deque>
22+
#include <iterator>
23+
#include <vector>
2124
#include <type_traits>
2225

2326
#include "test_macros.h"
2427
#include "test_iterators.h"
2528

2629
template <class It>
27-
TEST_CONSTEXPR_CXX17
28-
void check_distance(It first, It last, typename std::iterator_traits<It>::difference_type dist)
29-
{
30-
typedef typename std::iterator_traits<It>::difference_type Difference;
31-
static_assert(std::is_same<decltype(std::distance(first, last)), Difference>::value, "");
32-
assert(std::distance(first, last) == dist);
30+
TEST_CONSTEXPR_CXX17 void check_distance(It first, It last, typename std::iterator_traits<It>::difference_type dist) {
31+
typedef typename std::iterator_traits<It>::difference_type Difference;
32+
static_assert(std::is_same<decltype(std::distance(first, last)), Difference>::value, "");
33+
assert(std::distance(first, last) == dist);
3334
}
3435

35-
TEST_CONSTEXPR_CXX17 bool tests()
36-
{
37-
const char* s = "1234567890";
38-
check_distance(cpp17_input_iterator<const char*>(s), cpp17_input_iterator<const char*>(s+10), 10);
39-
check_distance(forward_iterator<const char*>(s), forward_iterator<const char*>(s+10), 10);
40-
check_distance(bidirectional_iterator<const char*>(s), bidirectional_iterator<const char*>(s+10), 10);
41-
check_distance(random_access_iterator<const char*>(s), random_access_iterator<const char*>(s+10), 10);
42-
check_distance(s, s+10, 10);
43-
return true;
36+
#if TEST_STD_VER >= 20
37+
template <class It>
38+
TEST_CONSTEXPR_CXX20 void check_ranges_distance(It first, It last, std::iter_difference_t<It> dist) {
39+
using Difference = std::iter_difference_t<It>;
40+
static_assert(std::is_same<decltype(std::ranges::distance(first, last)), Difference>::value, "");
41+
assert(std::ranges::distance(first, last) == dist);
42+
}
43+
#endif
44+
45+
TEST_CONSTEXPR_CXX17 bool tests() {
46+
const char* s = "1234567890";
47+
check_distance(cpp17_input_iterator<const char*>(s), cpp17_input_iterator<const char*>(s + 10), 10);
48+
check_distance(forward_iterator<const char*>(s), forward_iterator<const char*>(s + 10), 10);
49+
check_distance(bidirectional_iterator<const char*>(s), bidirectional_iterator<const char*>(s + 10), 10);
50+
check_distance(random_access_iterator<const char*>(s), random_access_iterator<const char*>(s + 10), 10);
51+
check_distance(s, s + 10, 10);
52+
53+
#if TEST_STD_VER >= 20
54+
check_ranges_distance(forward_iterator(s), forward_iterator(s + 10), 10);
55+
check_ranges_distance(bidirectional_iterator(s), bidirectional_iterator(s + 10), 10);
56+
check_ranges_distance(random_access_iterator(s), random_access_iterator(s + 10), 10);
57+
check_ranges_distance(s, s + 10, 10);
58+
59+
{
60+
using Container = std::vector<std::vector<int>>;
61+
Container c;
62+
auto view = c | std::views::join;
63+
Container::difference_type n = 0;
64+
for (std::size_t i = 0; i < 10; ++i) {
65+
n += i;
66+
c.push_back(Container::value_type(i));
67+
}
68+
assert(std::distance(view.begin(), view.end()) == n);
69+
assert(std::ranges::distance(view.begin(), view.end()) == n);
70+
}
71+
{
72+
using Container = std::array<std::array<char, 3>, 10>;
73+
Container c;
74+
auto view = c | std::views::join;
75+
assert(std::distance(view.begin(), view.end()) == 30);
76+
assert(std::ranges::distance(view.begin(), view.end()) == 30);
77+
}
78+
if (!TEST_IS_CONSTANT_EVALUATED) {
79+
using Container = std::deque<std::deque<double>>;
80+
Container c;
81+
auto view = c | std::views::join;
82+
Container::difference_type n = 0;
83+
for (std::size_t i = 0; i < 10; ++i) {
84+
n += i;
85+
c.push_back(Container::value_type(i));
86+
}
87+
assert(std::distance(view.begin(), view.end()) == n);
88+
assert(std::ranges::distance(view.begin(), view.end()) == n);
89+
}
90+
#endif
91+
return true;
4492
}
4593

46-
int main(int, char**)
47-
{
48-
tests();
94+
int main(int, char**) {
95+
tests();
4996
#if TEST_STD_VER >= 17
50-
static_assert(tests(), "");
97+
static_assert(tests(), "");
5198
#endif
52-
return 0;
99+
return 0;
53100
}

0 commit comments

Comments
 (0)