Skip to content

Commit c8f7ebc

Browse files
philnik777mahesh-attarde
authored andcommitted
[libc++] Vectorize std::find (llvm#156431)
``` Apple M4: ----------------------------------------------------------------------------- Benchmark old new ----------------------------------------------------------------------------- std::find(vector<char>) (bail 25%)/8 1.43 ns 1.44 ns std::find(vector<char>) (bail 25%)/1024 5.54 ns 5.59 ns std::find(vector<char>) (bail 25%)/8192 38.4 ns 39.1 ns std::find(vector<char>) (bail 25%)/32768 134 ns 136 ns std::find(vector<int>) (bail 25%)/8 1.56 ns 1.57 ns std::find(vector<int>) (bail 25%)/1024 65.3 ns 65.4 ns std::find(vector<int>) (bail 25%)/8192 465 ns 464 ns std::find(vector<int>) (bail 25%)/32768 1832 ns 1832 ns std::find(vector<long long>) (bail 25%)/8 0.920 ns 1.20 ns std::find(vector<long long>) (bail 25%)/1024 65.2 ns 31.2 ns std::find(vector<long long>) (bail 25%)/8192 464 ns 255 ns std::find(vector<long long>) (bail 25%)/32768 1833 ns 992 ns std::find(vector<char>) (process all)/8 1.21 ns 1.22 ns std::find(vector<char>) (process all)/50 1.92 ns 1.93 ns std::find(vector<char>) (process all)/1024 16.6 ns 16.9 ns std::find(vector<char>) (process all)/8192 134 ns 136 ns std::find(vector<char>) (process all)/32768 488 ns 503 ns std::find(vector<int>) (process all)/8 2.45 ns 2.48 ns std::find(vector<int>) (process all)/50 12.7 ns 12.7 ns std::find(vector<int>) (process all)/1024 236 ns 236 ns std::find(vector<int>) (process all)/8192 1830 ns 1834 ns std::find(vector<int>) (process all)/32768 7351 ns 7346 ns std::find(vector<long long>) (process all)/8 2.02 ns 1.45 ns std::find(vector<long long>) (process all)/50 12.0 ns 6.12 ns std::find(vector<long long>) (process all)/1024 235 ns 123 ns std::find(vector<long long>) (process all)/8192 1830 ns 983 ns std::find(vector<long long>) (process all)/32768 7306 ns 3969 ns std::find(vector<bool>) (process all)/8 1.14 ns 1.15 ns std::find(vector<bool>) (process all)/50 1.16 ns 1.17 ns std::find(vector<bool>) (process all)/1024 4.51 ns 4.53 ns std::find(vector<bool>) (process all)/8192 33.6 ns 33.5 ns std::find(vector<bool>) (process all)/1048576 3660 ns 3660 ns ```
1 parent eacd3ba commit c8f7ebc

File tree

5 files changed

+97
-23
lines changed

5 files changed

+97
-23
lines changed

libcxx/docs/ReleaseNotes/22.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ Improvements and New Features
6464
- Multiple internal types have been refactored to use ``[[no_unique_address]]``, resulting in faster compile times and
6565
reduced debug information.
6666

67+
- The performance of ``std::find`` has been improved by up to 2x for integral types
68+
6769
Deprecations and Removals
6870
-------------------------
6971

libcxx/include/__algorithm/find.h

Lines changed: 87 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <__algorithm/find_segment_if.h>
1414
#include <__algorithm/min.h>
15+
#include <__algorithm/simd_utils.h>
1516
#include <__algorithm/unwrap_iter.h>
1617
#include <__bit/countr.h>
1718
#include <__bit/invert_if.h>
@@ -44,39 +45,102 @@ _LIBCPP_BEGIN_NAMESPACE_STD
4445
// generic implementation
4546
template <class _Iter, class _Sent, class _Tp, class _Proj>
4647
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Iter
47-
__find(_Iter __first, _Sent __last, const _Tp& __value, _Proj& __proj) {
48+
__find_loop(_Iter __first, _Sent __last, const _Tp& __value, _Proj& __proj) {
4849
for (; __first != __last; ++__first)
4950
if (std::__invoke(__proj, *__first) == __value)
5051
break;
5152
return __first;
5253
}
5354

54-
// trivially equality comparable implementations
55-
template <class _Tp,
56-
class _Up,
57-
class _Proj,
58-
__enable_if_t<__is_identity<_Proj>::value && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value &&
59-
sizeof(_Tp) == 1,
60-
int> = 0>
61-
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Tp* __find(_Tp* __first, _Tp* __last, const _Up& __value, _Proj&) {
62-
if (auto __ret = std::__constexpr_memchr(__first, __value, __last - __first))
63-
return __ret;
64-
return __last;
55+
template <class _Iter, class _Sent, class _Tp, class _Proj>
56+
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Iter
57+
__find(_Iter __first, _Sent __last, const _Tp& __value, _Proj& __proj) {
58+
return std::__find_loop(std::move(__first), std::move(__last), __value, __proj);
6559
}
6660

67-
#if _LIBCPP_HAS_WIDE_CHARACTERS
68-
template <class _Tp,
69-
class _Up,
70-
class _Proj,
71-
__enable_if_t<__is_identity<_Proj>::value && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value &&
72-
sizeof(_Tp) == sizeof(wchar_t) && _LIBCPP_ALIGNOF(_Tp) >= _LIBCPP_ALIGNOF(wchar_t),
73-
int> = 0>
61+
#if _LIBCPP_VECTORIZE_ALGORITHMS
62+
template <class _Tp, class _Up>
63+
[[__nodiscard__]] _LIBCPP_HIDE_FROM_ABI
64+
_LIBCPP_CONSTEXPR_SINCE_CXX14 _Tp* __find_vectorized(_Tp* __first, _Tp* __last, _Up __value) {
65+
if (!__libcpp_is_constant_evaluated()) {
66+
constexpr size_t __unroll_count = 4;
67+
constexpr size_t __vec_size = __native_vector_size<_Tp>;
68+
using __vec = __simd_vector<_Tp, __vec_size>;
69+
70+
auto __orig_first = __first;
71+
72+
auto __values = static_cast<__simd_vector<_Up, __vec_size>>(__value); // broadcast the value
73+
while (static_cast<size_t>(__last - __first) >= __unroll_count * __vec_size) [[__unlikely__]] {
74+
__vec __lhs[__unroll_count];
75+
76+
for (size_t __i = 0; __i != __unroll_count; ++__i)
77+
__lhs[__i] = std::__load_vector<__vec>(__first + __i * __vec_size);
78+
79+
for (size_t __i = 0; __i != __unroll_count; ++__i) {
80+
if (auto __cmp_res = __lhs[__i] == __values; std::__any_of(__cmp_res)) {
81+
auto __offset = __i * __vec_size + std::__find_first_set(__cmp_res);
82+
return __first + __offset;
83+
}
84+
}
85+
86+
__first += __unroll_count * __vec_size;
87+
}
88+
89+
// check the remaining 0-3 vectors
90+
while (static_cast<size_t>(__last - __first) >= __vec_size) {
91+
if (auto __cmp_res = std::__load_vector<__vec>(__first) == __values; std::__any_of(__cmp_res)) {
92+
return __first + std::__find_first_set(__cmp_res);
93+
}
94+
__first += __vec_size;
95+
}
96+
97+
if (__last - __first == 0)
98+
return __first;
99+
100+
// Check if we can load elements in front of the current pointer. If that's the case load a vector at
101+
// (last - vector_size) to check the remaining elements
102+
if (static_cast<size_t>(__first - __orig_first) >= __vec_size) {
103+
__first = __last - __vec_size;
104+
return __first + std::__find_first_set(std::__load_vector<__vec>(__first) == __values);
105+
}
106+
}
107+
108+
__identity __proj;
109+
return std::__find_loop(__first, __last, __value, __proj);
110+
}
111+
#endif
112+
113+
#ifndef _LIBCPP_CXX03_LANG
114+
// trivially equality comparable implementations
115+
template <
116+
class _Tp,
117+
class _Up,
118+
class _Proj,
119+
__enable_if_t<__is_identity<_Proj>::value && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value, int> = 0>
74120
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Tp* __find(_Tp* __first, _Tp* __last, const _Up& __value, _Proj&) {
75-
if (auto __ret = std::__constexpr_wmemchr(__first, __value, __last - __first))
76-
return __ret;
77-
return __last;
121+
if constexpr (sizeof(_Tp) == 1) {
122+
if (auto __ret = std::__constexpr_memchr(__first, __value, __last - __first))
123+
return __ret;
124+
return __last;
125+
}
126+
# if _LIBCPP_HAS_WIDE_CHARACTERS
127+
else if constexpr (sizeof(_Tp) == sizeof(wchar_t) && _LIBCPP_ALIGNOF(_Tp) >= _LIBCPP_ALIGNOF(wchar_t)) {
128+
if (auto __ret = std::__constexpr_wmemchr(__first, __value, __last - __first))
129+
return __ret;
130+
return __last;
131+
}
132+
# endif
133+
# if _LIBCPP_VECTORIZE_ALGORITHMS
134+
else if constexpr (is_integral<_Tp>::value) {
135+
return std::__find_vectorized(__first, __last, __value);
136+
}
137+
# endif
138+
else {
139+
__identity __proj;
140+
return std::__find_loop(__first, __last, __value, __proj);
141+
}
78142
}
79-
#endif // _LIBCPP_HAS_WIDE_CHARACTERS
143+
#endif
80144

81145
// TODO: This should also be possible to get right with different signedness
82146
// cast integral types to allow vectorization

libcxx/include/__algorithm/simd_utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ template <class _VecT, class _Iter>
114114
}(make_index_sequence<__simd_vector_size_v<_VecT>>{});
115115
}
116116

117+
template <class _Tp, size_t _Np>
118+
[[__nodiscard__]] _LIBCPP_HIDE_FROM_ABI bool __any_of(__simd_vector<_Tp, _Np> __vec) noexcept {
119+
return __builtin_reduce_or(__builtin_convertvector(__vec, __simd_vector<bool, _Np>));
120+
}
121+
117122
template <class _Tp, size_t _Np>
118123
[[__nodiscard__]] _LIBCPP_HIDE_FROM_ABI bool __all_of(__simd_vector<_Tp, _Np> __vec) noexcept {
119124
return __builtin_reduce_and(__builtin_convertvector(__vec, __simd_vector<bool, _Np>));

libcxx/include/module.modulemap.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,6 +1225,7 @@ module std [system] {
12251225
header "deque"
12261226
export *
12271227
export std.iterator.reverse_iterator
1228+
export std.algorithm.simd_utils // This is a workaround for https://llvm.org/PR120108.
12281229
}
12291230

12301231
module exception {
@@ -2238,6 +2239,7 @@ module std [system] {
22382239
header "vector"
22392240
export std.iterator.reverse_iterator
22402241
export *
2242+
export std.algorithm.simd_utils // This is a workaround for https://llvm.org/PR120108.
22412243
}
22422244

22432245
// Experimental C++ Standard Library interfaces

libcxx/test/benchmarks/algorithms/nonmodifying/find.bench.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ int main(int argc, char** argv) {
5151
// find
5252
bm.template operator()<std::vector<char>>("std::find(vector<char>) (" + comment + ")", std_find);
5353
bm.template operator()<std::vector<int>>("std::find(vector<int>) (" + comment + ")", std_find);
54+
bm.template operator()<std::vector<long long>>("std::find(vector<long long>) (" + comment + ")", std_find);
5455
bm.template operator()<std::deque<int>>("std::find(deque<int>) (" + comment + ")", std_find);
5556
bm.template operator()<std::list<int>>("std::find(list<int>) (" + comment + ")", std_find);
5657

0 commit comments

Comments
 (0)