Skip to content

Commit de8c84c

Browse files
(Semi-)vectorize includes (#5590)
Co-authored-by: Stephan T. Lavavej <[email protected]>
1 parent 78bb452 commit de8c84c

File tree

4 files changed

+554
-1
lines changed

4 files changed

+554
-1
lines changed

benchmarks/src/includes.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,16 @@ void common_args(auto bm) {
111111
}
112112
}
113113

114+
BENCHMARK(bm_includes<uint8_t, alg_type::std_fn>)->Apply(common_args);
115+
BENCHMARK(bm_includes<uint16_t, alg_type::std_fn>)->Apply(common_args);
116+
BENCHMARK(bm_includes<uint32_t, alg_type::std_fn>)->Apply(common_args);
117+
BENCHMARK(bm_includes<uint64_t, alg_type::std_fn>)->Apply(common_args);
118+
119+
BENCHMARK(bm_includes<uint8_t, alg_type::rng>)->Apply(common_args);
120+
BENCHMARK(bm_includes<uint16_t, alg_type::rng>)->Apply(common_args);
121+
BENCHMARK(bm_includes<uint32_t, alg_type::rng>)->Apply(common_args);
122+
BENCHMARK(bm_includes<uint64_t, alg_type::rng>)->Apply(common_args);
123+
114124
BENCHMARK(bm_includes<int8_t, alg_type::std_fn>)->Apply(common_args);
115125
BENCHMARK(bm_includes<int16_t, alg_type::std_fn>)->Apply(common_args);
116126
BENCHMARK(bm_includes<int32_t, alg_type::std_fn>)->Apply(common_args);

stl/inc/algorithm

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,23 @@ const void* __stdcall __std_is_sorted_until_8u(const void* _First, const void* _
8686
const void* __stdcall __std_is_sorted_until_f(const void* _First, const void* _Last, bool _Greater) noexcept;
8787
const void* __stdcall __std_is_sorted_until_d(const void* _First, const void* _Last, bool _Greater) noexcept;
8888

89+
__declspec(noalias) bool __stdcall __std_includes_less_1i(
90+
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
91+
__declspec(noalias) bool __stdcall __std_includes_less_1u(
92+
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
93+
__declspec(noalias) bool __stdcall __std_includes_less_2i(
94+
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
95+
__declspec(noalias) bool __stdcall __std_includes_less_2u(
96+
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
97+
__declspec(noalias) bool __stdcall __std_includes_less_4i(
98+
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
99+
__declspec(noalias) bool __stdcall __std_includes_less_4u(
100+
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
101+
__declspec(noalias) bool __stdcall __std_includes_less_8i(
102+
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
103+
__declspec(noalias) bool __stdcall __std_includes_less_8u(
104+
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
105+
89106
// TRANSITION, DevCom-10610477
90107
__declspec(noalias) void __stdcall __std_replace_4(
91108
void* _First, void* _Last, uint32_t _Old_val, uint32_t _New_val) noexcept;
@@ -256,6 +273,40 @@ _Ty* _Is_sorted_until_vectorized(_Ty* const _First, _Ty* const _Last, const bool
256273
}
257274
}
258275

276+
template <class _Ty>
277+
bool _Includes_vectorized(
278+
const _Ty* const _First1, const _Ty* const _Last1, const _Ty* const _First2, const _Ty* const _Last2) noexcept {
279+
constexpr bool _Signed = is_signed_v<_Ty>;
280+
281+
if constexpr (sizeof(_Ty) == 1) {
282+
if constexpr (_Signed) {
283+
return ::__std_includes_less_1i(_First1, _Last1, _First2, _Last2);
284+
} else {
285+
return ::__std_includes_less_1u(_First1, _Last1, _First2, _Last2);
286+
}
287+
} else if constexpr (sizeof(_Ty) == 2) {
288+
if constexpr (_Signed) {
289+
return ::__std_includes_less_2i(_First1, _Last1, _First2, _Last2);
290+
} else {
291+
return ::__std_includes_less_2u(_First1, _Last1, _First2, _Last2);
292+
}
293+
} else if constexpr (sizeof(_Ty) == 4) {
294+
if constexpr (_Signed) {
295+
return ::__std_includes_less_4i(_First1, _Last1, _First2, _Last2);
296+
} else {
297+
return ::__std_includes_less_4u(_First1, _Last1, _First2, _Last2);
298+
}
299+
} else if constexpr (sizeof(_Ty) == 8) {
300+
if constexpr (_Signed) {
301+
return ::__std_includes_less_8i(_First1, _Last1, _First2, _Last2);
302+
} else {
303+
return ::__std_includes_less_8u(_First1, _Last1, _First2, _Last2);
304+
}
305+
} else {
306+
_STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
307+
}
308+
}
309+
259310
template <class _Ty, class _TVal1, class _TVal2>
260311
__declspec(noalias) void _Replace_vectorized(
261312
_Ty* const _First, _Ty* const _Last, const _TVal1 _Old_val, const _TVal2 _New_val) noexcept {
@@ -384,6 +435,13 @@ constexpr bool _Output_iterator_for_vector_alg_is_safe() {
384435
}
385436
}
386437

438+
// Can we activate the vector algorithms for includes?
439+
template <class _Iter1, class _Iter2, class _Elem = _Iter_value_t<_Iter1>>
440+
constexpr bool _Vector_alg_includes_iterators_safe =
441+
_Iterators_are_contiguous<_Iter1, _Iter2> // Iterators must be contiguous so we can get raw pointers.
442+
&& !_Iterator_is_volatile<_Iter1> && !_Iterator_is_volatile<_Iter2> // Iterators must not be volatile.
443+
&& is_same_v<_Elem, _Iter_value_t<_Iter2>> // Iterators have the same value type.
444+
&& disjunction_v<is_integral<_Elem>, is_pointer<_Elem>>; // Integral or pointer type.
387445
_STD_END
388446
#endif // _USE_STD_VECTOR_ALGORITHMS
389447

@@ -10244,6 +10302,15 @@ _NODISCARD _CONSTEXPR20 bool includes(_InIt1 _First1, _InIt1 _Last1, _InIt2 _Fir
1024410302
return false;
1024510303
}
1024610304

10305+
#if _USE_STD_VECTOR_ALGORITHMS
10306+
if constexpr (_Vector_alg_includes_iterators_safe<_InIt1, _InIt2> && _Is_predicate_less<_InIt1, _Pr>) {
10307+
if (!_STD _Is_constant_evaluated()) {
10308+
return _STD _Includes_vectorized(_STD _To_address(_First1), _STD _To_address(_Last1),
10309+
_STD _To_address(_First2), _STD _To_address(_Last2));
10310+
}
10311+
}
10312+
#endif // _USE_STD_VECTOR_ALGORITHMS
10313+
1024710314
for (;;) {
1024810315
if (_DEBUG_LT_PRED(_Pred, *_UFirst1, *_UFirst2)) {
1024910316
++_UFirst1;
@@ -10333,6 +10400,20 @@ namespace ranges {
1033310400
return false;
1033410401
}
1033510402

10403+
#if _USE_STD_VECTOR_ALGORITHMS
10404+
if constexpr (_Vector_alg_includes_iterators_safe<_It1, _It2> && _Is_predicate_less<_It1, _Pr>
10405+
&& sized_sentinel_for<_Se1, _It1> && sized_sentinel_for<_Se2, _It2>
10406+
&& is_same_v<_Pj1, identity> && is_same_v<_Pj2, identity>) {
10407+
if (!_STD is_constant_evaluated()) {
10408+
const auto _First1_ptr = _STD to_address(_First1);
10409+
const auto _First2_ptr = _STD to_address(_First2);
10410+
const auto _Last1_ptr = _First1_ptr + static_cast<ptrdiff_t>(_Last1 - _First1);
10411+
const auto _Last2_ptr = _First2_ptr + static_cast<ptrdiff_t>(_Last2 - _First2);
10412+
return _STD _Includes_vectorized(_First1_ptr, _Last1_ptr, _First2_ptr, _Last2_ptr);
10413+
}
10414+
}
10415+
#endif // _USE_STD_VECTOR_ALGORITHMS
10416+
1033610417
for (;;) {
1033710418
if (_STD invoke(_Pred, _STD invoke(_Proj1, *_First1), _STD invoke(_Proj2, *_First2))) {
1033810419
++_First1;

0 commit comments

Comments
 (0)