diff --git a/benchmarks/CMakeLists.txt b/benchmarks/CMakeLists.txt index 9aa8a8c045..5a054a4e42 100644 --- a/benchmarks/CMakeLists.txt +++ b/benchmarks/CMakeLists.txt @@ -141,5 +141,6 @@ add_benchmark(unique src/unique.cpp) add_benchmark(vector_bool_copy src/vector_bool_copy.cpp) add_benchmark(vector_bool_copy_n src/vector_bool_copy_n.cpp) add_benchmark(vector_bool_count src/vector_bool_count.cpp) +add_benchmark(vector_bool_meow_of src/vector_bool_meow_of.cpp) add_benchmark(vector_bool_move src/vector_bool_move.cpp) add_benchmark(vector_bool_transform src/vector_bool_transform.cpp) diff --git a/benchmarks/src/vector_bool_meow_of.cpp b/benchmarks/src/vector_bool_meow_of.cpp new file mode 100644 index 0000000000..77a6839ec9 --- /dev/null +++ b/benchmarks/src/vector_bool_meow_of.cpp @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +// +#include +#include +#include +#include + +#include "skewed_allocator.hpp" + +using namespace std; + +enum class alg { any_, all_, none_ }; +enum class content { ones_then_zeros, zeros_then_ones }; + +template +void meow_of(benchmark::State& state) { + const auto size = static_cast(state.range(0)); + vector> source(size); + + if constexpr (Content == content::ones_then_zeros) { + fill(source.begin(), source.begin() + source.size() / 2, true); + } else { + fill(source.begin() + source.size() / 2, source.end(), true); + } + + for (auto _ : state) { + benchmark::DoNotOptimize(source); + bool result; + if constexpr (Alg == alg::any_) { + result = any_of(source.begin(), source.end(), Pred{}); + } else if constexpr (Alg == alg::all_) { + result = all_of(source.begin(), source.end(), Pred{}); + } else { + result = none_of(source.begin(), source.end(), Pred{}); + } + benchmark::DoNotOptimize(result); + } +} + +void common_args(auto bm) { + bm->RangeMultiplier(64)->Range(64, 64 << 10); +} + +using not_ = logical_not<>; + +BENCHMARK(meow_of)->Apply(common_args); +BENCHMARK(meow_of)->Apply(common_args); +BENCHMARK(meow_of)->Apply(common_args); +BENCHMARK(meow_of)->Apply(common_args); + +BENCHMARK_MAIN(); diff --git a/stl/inc/algorithm b/stl/inc/algorithm index d09fa5084c..9fe789cc13 100644 --- a/stl/inc/algorithm +++ b/stl/inc/algorithm @@ -1561,18 +1561,30 @@ namespace ranges { } // namespace ranges #endif // _HAS_CXX20 +struct _All_of_vbool_traits; +struct _Any_of_vbool_traits; +struct _None_of_vbool_traits; + +template +_NODISCARD _CONSTEXPR20 bool _Meow_of_vbool(_VbIt _First, _VbIt _Last, _Mapped_fn _Mapped_func); + _EXPORT_STD template _NODISCARD _CONSTEXPR20 bool all_of(_InIt _First, _InIt _Last, _Pr _Pred) { // test if all elements satisfy _Pred _STD _Adl_verify_range(_First, _Last); auto _UFirst = _STD _Get_unwrapped(_First); const auto _ULast = _STD _Get_unwrapped(_Last); - for (; _UFirst != _ULast; ++_UFirst) { - if (!_Pred(*_UFirst)) { - return false; + + if constexpr (_Is_vb_iterator && !is_void_v<_Map_vb_functor_t<_Pr>>) { + return _Meow_of_vbool<_All_of_vbool_traits>(_UFirst, _ULast, _Map_vb_functor_t<_Pr>{}); + } else { + for (; _UFirst != _ULast; ++_UFirst) { + if (!_Pred(*_UFirst)) { + return false; + } } - } - return true; + return true; + } } #if _HAS_CXX17 @@ -1628,13 +1640,18 @@ _NODISCARD _CONSTEXPR20 bool any_of(const _InIt _First, const _InIt _Last, _Pr _ _STD _Adl_verify_range(_First, _Last); auto _UFirst = _STD _Get_unwrapped(_First); const auto _ULast = _STD _Get_unwrapped(_Last); - for (; _UFirst != _ULast; ++_UFirst) { - if (_Pred(*_UFirst)) { - return true; + + if constexpr (_Is_vb_iterator && !is_void_v<_Map_vb_functor_t<_Pr>>) { + return _Meow_of_vbool<_Any_of_vbool_traits>(_UFirst, _ULast, _Map_vb_functor_t<_Pr>{}); + } else { + for (; _UFirst != _ULast; ++_UFirst) { + if (_Pred(*_UFirst)) { + return true; + } } - } - return false; + return false; + } } #if _HAS_CXX17 @@ -1690,13 +1707,17 @@ _NODISCARD _CONSTEXPR20 bool none_of(const _InIt _First, const _InIt _Last, _Pr _STD _Adl_verify_range(_First, _Last); auto _UFirst = _STD _Get_unwrapped(_First); const auto _ULast = _STD _Get_unwrapped(_Last); - for (; _UFirst != _ULast; ++_UFirst) { - if (_Pred(*_UFirst)) { - return false; + if constexpr (_Is_vb_iterator && !is_void_v<_Map_vb_functor_t<_Pr>>) { + return _Meow_of_vbool<_None_of_vbool_traits>(_UFirst, _ULast, _Map_vb_functor_t<_Pr>{}); + } else { + for (; _UFirst != _ULast; ++_UFirst) { + if (_Pred(*_UFirst)) { + return false; + } } - } - return true; + return true; + } } #if _HAS_CXX17 diff --git a/stl/inc/vector b/stl/inc/vector index dafac7c593..48c270d862 100644 --- a/stl/inc/vector +++ b/stl/inc/vector @@ -4047,6 +4047,76 @@ _CONSTEXPR20 _OutIt _Transform_vbool_aligned( return _Dest; } +struct _All_of_vbool_traits { + static constexpr bool _Default_result = true; + + static _CONSTEXPR20 bool _Check(const _Vbase _Value) { + return _Value != ~_Vbase{0}; + } + + static _CONSTEXPR20 bool _Check(const _Vbase _Value, const _Vbase _Mask) { + return (_Value & _Mask) != _Mask; + } +}; + +struct _Any_of_vbool_traits_base { + static _CONSTEXPR20 bool _Check(const _Vbase _Value) { + return _Value != 0; + } + + static _CONSTEXPR20 bool _Check(const _Vbase _Value, const _Vbase _Mask) { + return (_Value & _Mask) != 0; + } +}; + +struct _Any_of_vbool_traits : _Any_of_vbool_traits_base { + static constexpr bool _Default_result = false; +}; + +struct _None_of_vbool_traits : _Any_of_vbool_traits_base { + static constexpr bool _Default_result = true; +}; + +template +_NODISCARD _CONSTEXPR20 bool _Meow_of_vbool(const _VbIt _First, const _VbIt _Last, const _Mapped_fn _Mapped_func) { + constexpr bool _Early_result = !_Traits::_Default_result; + auto _First_ptr = _First._Myptr; + const auto _Last_ptr = _Last._Myptr; + + if (_First_ptr == _Last_ptr) { + const _Vbase _Mask = (_Vbase{1} << _Last._Myoff) - (_Vbase{1} << _First._Myoff); + if (_Mask == 0) { + return _Traits::_Default_result; + } else { + return _Traits::_Check(_Mapped_func(*_First_ptr), _Mask) ? _Early_result : _Traits::_Default_result; + } + } + + if (_First._Myoff != 0) { + const _Vbase _Mask = static_cast<_Vbase>(-1) << _First._Myoff; + if (_Traits::_Check(_Mapped_func(*_First_ptr), _Mask)) { + return _Early_result; + } + + ++_First_ptr; + } + + for (; _First_ptr != _Last_ptr; ++_First_ptr) { + if (_Traits::_Check(_Mapped_func(*_First_ptr))) { + return _Early_result; + } + } + + if (_Last._Myoff != 0) { + const _Vbase _Mask = (_Vbase{1} << _Last._Myoff) - 1; + if (_Traits::_Check(_Mapped_func(*_First_ptr), _Mask)) { + return _Early_result; + } + } + + return _Traits::_Default_result; +} + #undef _ASAN_VECTOR_MODIFY #undef _ASAN_VECTOR_REMOVE #undef _ASAN_VECTOR_CREATE diff --git a/stl/inc/xutility b/stl/inc/xutility index da3e663bde..c02d9d1d8c 100644 --- a/stl/inc/xutility +++ b/stl/inc/xutility @@ -4971,6 +4971,13 @@ struct _Map_vb_functor { using type = void; }; +#if _HAS_CXX20 +template <> +struct _Map_vb_functor { + using type = identity; +}; +#endif // _HAS_CXX20 + template using _Map_vb_functor_t = typename _Map_vb_functor<_Fn>::type; diff --git a/tests/std/tests/GH_000625_vector_bool_optimization/test.cpp b/tests/std/tests/GH_000625_vector_bool_optimization/test.cpp index 41e73cff35..c94c0b6ed1 100644 --- a/tests/std/tests/GH_000625_vector_bool_optimization/test.cpp +++ b/tests/std/tests/GH_000625_vector_bool_optimization/test.cpp @@ -190,6 +190,104 @@ CONSTEXPR20 bool test_transform() { return true; } +CONSTEXPR20 bool test_meow_of_helper(const size_t length_before, const size_t length, const size_t length_after) { + const size_t total_length = length_before + length + length_after; + + vector zeros(total_length); + vector ones(total_length); + vector mix(total_length); + fill(zeros.begin(), zeros.begin() + length_before, true); + fill(zeros.end() - length_after, zeros.end(), true); + fill(ones.begin() + length_before, ones.end() - length_after, true); + fill(mix.begin(), mix.begin() + length_before, true); + fill(mix.begin() + length_before + length / 2, mix.end() - length_after, true); + + const auto first_0 = zeros.begin() + length_before; + const auto last_0 = zeros.end() - length_after; + const auto first_1 = ones.cbegin() + length_before; + const auto last_1 = ones.cend() - length_after; + const auto first_m = mix.cbegin() + length_before; + const auto last_m = mix.cend() - length_after; + + if (length == 0) { +#if _HAS_CXX20 + assert(all_of(first_0, last_0, identity{}) == true); + assert(all_of(first_1, last_1, identity{}) == true); + assert(all_of(first_m, last_m, identity{}) == true); + + assert(any_of(first_0, last_0, identity{}) == false); + assert(any_of(first_1, last_1, identity{}) == false); + assert(any_of(first_m, last_m, identity{}) == false); + + assert(none_of(first_0, last_0, identity{}) == true); + assert(none_of(first_1, last_1, identity{}) == true); + assert(none_of(first_m, last_m, identity{}) == true); +#endif // _HAS_CXX20 + + assert(all_of(first_0, last_0, logical_not<>{}) == true); + assert(all_of(first_1, last_1, logical_not<>{}) == true); + assert(all_of(first_m, last_m, logical_not<>{}) == true); + + assert(any_of(first_0, last_0, logical_not<>{}) == false); + assert(any_of(first_1, last_1, logical_not<>{}) == false); + assert(any_of(first_m, last_m, logical_not<>{}) == false); + + assert(none_of(first_0, last_0, logical_not<>{}) == true); + assert(none_of(first_1, last_1, logical_not<>{}) == true); + assert(none_of(first_m, last_m, logical_not<>{}) == true); + } else { +#if _HAS_CXX20 + assert(all_of(first_0, last_0, identity{}) == false); + assert(all_of(first_1, last_1, identity{}) == true); + assert(all_of(first_m, last_m, identity{}) == false); + + assert(any_of(first_0, last_0, identity{}) == false); + assert(any_of(first_1, last_1, identity{}) == true); + assert(any_of(first_m, last_m, identity{}) == true); + + assert(none_of(first_0, last_0, identity{}) == true); + assert(none_of(first_1, last_1, identity{}) == false); + assert(none_of(first_m, last_m, identity{}) == false); +#endif // _HAS_CXX20 + + assert(all_of(first_0, last_0, logical_not<>{}) == true); + assert(all_of(first_1, last_1, logical_not<>{}) == false); + assert(all_of(first_m, last_m, logical_not<>{}) == false); + + assert(any_of(first_0, last_0, logical_not<>{}) == true); + assert(any_of(first_1, last_1, logical_not<>{}) == false); + assert(any_of(first_m, last_m, logical_not<>{}) == true); + + assert(none_of(first_0, last_0, logical_not<>{}) == false); + assert(none_of(first_1, last_1, logical_not<>{}) == true); + assert(none_of(first_m, last_m, logical_not<>{}) == false); + } + + return true; +} + +CONSTEXPR20 bool test_meow_of() { + // Empty range + test_meow_of_helper(0, 0, 3); + test_meow_of_helper(3, 0, 3); + + // One block, ends within block + test_meow_of_helper(0, 10, 3); + test_meow_of_helper(3, 10, 3); + + // One block, exactly + test_meow_of_helper(0, blockSize, 0); + + // Multiple blocks, spanning + test_meow_of_helper(3, blockSize - 2, 3); + test_meow_of_helper(3, blockSize + 2, 3); + + // Many blocks, exaclty + test_meow_of_helper(blockSize, 4 * blockSize, blockSize); + + return true; +} + CONSTEXPR20 void test_fill_helper(const size_t length) { // No offset { @@ -1531,6 +1629,7 @@ static_assert(test_fill()); static_assert(test_find()); static_assert(test_count()); static_assert(test_transform()); +static_assert(test_meow_of()); #if defined(__clang__) || defined(__EDG__) // TRANSITION, VSO-2574489 static_assert(test_copy_part_1()); @@ -1543,6 +1642,7 @@ int main() { test_find(); test_count(); test_transform(); + test_meow_of(); test_copy_part_1(); test_copy_part_2();