Skip to content

Commit 1d131ff

Browse files
authored
[libc++] Optimize most of the __tree search algorithms (#155245)
This patch introduces a new comparator, namely `__lazy_synth_three_way`, which tries to provide an efficient three way comparator for known types and falls back to using the provided comparator if it doesn't know how to do that. Currently, an efficient three way comparison is only provided when using one of the `less` comparions object from the standard library and `std::string`. This will be extended in future patches. ``` ------------------------------------------------------------------------------------------------------------------------------ Benchmark old new ------------------------------------------------------------------------------------------------------------------------------ std::map<std::string, int>::ctor(const&)/0 12.6 ns 12.6 ns std::map<std::string, int>::ctor(const&)/32 858 ns 837 ns std::map<std::string, int>::ctor(const&)/1024 46700 ns 46739 ns std::map<std::string, int>::ctor(const&)/8192 458100 ns 449806 ns std::map<std::string, int>::ctor(iterator, iterator) (unsorted sequence)/0 12.8 ns 12.7 ns std::map<std::string, int>::ctor(iterator, iterator) (unsorted sequence)/32 1286 ns 1266 ns std::map<std::string, int>::ctor(iterator, iterator) (unsorted sequence)/1024 93812 ns 84686 ns std::map<std::string, int>::ctor(iterator, iterator) (unsorted sequence)/8192 1480346 ns 1385924 ns std::map<std::string, int>::ctor(iterator, iterator) (sorted sequence)/0 12.9 ns 12.8 ns std::map<std::string, int>::ctor(iterator, iterator) (sorted sequence)/32 1044 ns 1055 ns std::map<std::string, int>::ctor(iterator, iterator) (sorted sequence)/1024 63071 ns 62861 ns std::map<std::string, int>::ctor(iterator, iterator) (sorted sequence)/8192 595046 ns 590223 ns std::map<std::string, int>::operator=(const&) (into cleared Container)/0 13.6 ns 13.6 ns std::map<std::string, int>::operator=(const&) (into cleared Container)/32 880 ns 911 ns std::map<std::string, int>::operator=(const&) (into cleared Container)/1024 48627 ns 47808 ns std::map<std::string, int>::operator=(const&) (into cleared Container)/8192 458552 ns 454497 ns std::map<std::string, int>::operator=(const&) (into partially populated Container)/0 13.8 ns 13.6 ns std::map<std::string, int>::operator=(const&) (into partially populated Container)/32 864 ns 851 ns std::map<std::string, int>::operator=(const&) (into partially populated Container)/1024 49483 ns 49555 ns std::map<std::string, int>::operator=(const&) (into partially populated Container)/8192 456977 ns 457894 ns std::map<std::string, int>::operator=(const&) (into populated Container)/0 1.31 ns 1.31 ns std::map<std::string, int>::operator=(const&) (into populated Container)/32 425 ns 415 ns std::map<std::string, int>::operator=(const&) (into populated Container)/1024 14248 ns 14225 ns std::map<std::string, int>::operator=(const&) (into populated Container)/8192 136684 ns 133696 ns std::map<std::string, int>::insert(value) (already present)/0 21.5 ns 16.2 ns std::map<std::string, int>::insert(value) (already present)/32 22.7 ns 25.1 ns std::map<std::string, int>::insert(value) (already present)/1024 54.5 ns 29.1 ns std::map<std::string, int>::insert(value) (already present)/8192 78.4 ns 30.4 ns std::map<std::string, int>::insert(value) (new value)/0 40.9 ns 39.0 ns std::map<std::string, int>::insert(value) (new value)/32 58.3 ns 47.2 ns std::map<std::string, int>::insert(value) (new value)/1024 120 ns 71.3 ns std::map<std::string, int>::insert(value) (new value)/8192 157 ns 129 ns std::map<std::string, int>::insert(hint, value) (good hint)/0 40.3 ns 40.7 ns std::map<std::string, int>::insert(hint, value) (good hint)/32 48.0 ns 30.0 ns std::map<std::string, int>::insert(hint, value) (good hint)/1024 107 ns 63.2 ns std::map<std::string, int>::insert(hint, value) (good hint)/8192 132 ns 107 ns std::map<std::string, int>::insert(hint, value) (bad hint)/0 27.0 ns 40.9 ns std::map<std::string, int>::insert(hint, value) (bad hint)/32 68.3 ns 58.4 ns std::map<std::string, int>::insert(hint, value) (bad hint)/1024 125 ns 82.0 ns std::map<std::string, int>::insert(hint, value) (bad hint)/8192 155 ns 150 ns std::map<std::string, int>::insert(iterator, iterator) (all new keys)/0 404 ns 405 ns std::map<std::string, int>::insert(iterator, iterator) (all new keys)/32 2004 ns 1805 ns std::map<std::string, int>::insert(iterator, iterator) (all new keys)/1024 102820 ns 76102 ns std::map<std::string, int>::insert(iterator, iterator) (all new keys)/8192 1144590 ns 949266 ns std::map<std::string, int>::insert(iterator, iterator) (half new keys)/0 408 ns 404 ns std::map<std::string, int>::insert(iterator, iterator) (half new keys)/32 1592 ns 1377 ns std::map<std::string, int>::insert(iterator, iterator) (half new keys)/1024 74847 ns 53921 ns std::map<std::string, int>::insert(iterator, iterator) (half new keys)/8192 828505 ns 698716 ns std::map<std::string, int>::insert(iterator, iterator) (product_iterator from same type)/0 407 ns 407 ns std::map<std::string, int>::insert(iterator, iterator) (product_iterator from same type)/32 1584 ns 1557 ns std::map<std::string, int>::insert(iterator, iterator) (product_iterator from same type)/1024 47157 ns 47443 ns std::map<std::string, int>::insert(iterator, iterator) (product_iterator from same type)/8192 623887 ns 628385 ns std::map<std::string, int>::insert(iterator, iterator) (product_iterator from zip_view)/0 405 ns 403 ns std::map<std::string, int>::insert(iterator, iterator) (product_iterator from zip_view)/32 1478 ns 1510 ns std::map<std::string, int>::insert(iterator, iterator) (product_iterator from zip_view)/1024 47852 ns 47835 ns std::map<std::string, int>::insert(iterator, iterator) (product_iterator from zip_view)/8192 605311 ns 606951 ns std::map<std::string, int>::erase(key) (existent)/0 129 ns 94.0 ns std::map<std::string, int>::erase(key) (existent)/32 110 ns 106 ns std::map<std::string, int>::erase(key) (existent)/1024 121 ns 128 ns std::map<std::string, int>::erase(key) (existent)/8192 165 ns 66.9 ns std::map<std::string, int>::erase(key) (non-existent)/0 0.269 ns 0.257 ns std::map<std::string, int>::erase(key) (non-existent)/32 21.9 ns 11.3 ns std::map<std::string, int>::erase(key) (non-existent)/1024 53.5 ns 25.4 ns std::map<std::string, int>::erase(key) (non-existent)/8192 67.3 ns 31.9 ns std::map<std::string, int>::erase(iterator)/0 46.3 ns 46.7 ns std::map<std::string, int>::erase(iterator)/32 44.4 ns 41.8 ns std::map<std::string, int>::erase(iterator)/1024 43.7 ns 46.4 ns std::map<std::string, int>::erase(iterator)/8192 45.2 ns 44.1 ns std::map<std::string, int>::erase(iterator, iterator) (erase half the container)/0 407 ns 407 ns std::map<std::string, int>::erase(iterator, iterator) (erase half the container)/32 876 ns 906 ns std::map<std::string, int>::erase(iterator, iterator) (erase half the container)/1024 20880 ns 20444 ns std::map<std::string, int>::erase(iterator, iterator) (erase half the container)/8192 252881 ns 241583 ns std::map<std::string, int>::clear()/0 407 ns 408 ns std::map<std::string, int>::clear()/32 1252 ns 1323 ns std::map<std::string, int>::clear()/1024 38488 ns 38017 ns std::map<std::string, int>::clear()/8192 416492 ns 428534 ns std::map<std::string, int>::find(key) (existent)/0 0.008 ns 0.008 ns std::map<std::string, int>::find(key) (existent)/32 33.9 ns 15.3 ns std::map<std::string, int>::find(key) (existent)/1024 43.0 ns 25.5 ns std::map<std::string, int>::find(key) (existent)/8192 44.6 ns 29.3 ns std::map<std::string, int>::find(key) (non-existent)/0 0.259 ns 0.257 ns std::map<std::string, int>::find(key) (non-existent)/32 22.6 ns 11.4 ns std::map<std::string, int>::find(key) (non-existent)/1024 48.6 ns 25.1 ns std::map<std::string, int>::find(key) (non-existent)/8192 64.1 ns 31.1 ns std::map<std::string, int>::count(key) (existent)/0 0.008 ns 0.008 ns std::map<std::string, int>::count(key) (existent)/32 32.2 ns 17.3 ns std::map<std::string, int>::count(key) (existent)/1024 42.4 ns 25.3 ns std::map<std::string, int>::count(key) (existent)/8192 44.4 ns 31.6 ns std::map<std::string, int>::count(key) (non-existent)/0 0.260 ns 0.259 ns std::map<std::string, int>::count(key) (non-existent)/32 22.9 ns 11.3 ns std::map<std::string, int>::count(key) (non-existent)/1024 49.8 ns 25.5 ns std::map<std::string, int>::count(key) (non-existent)/8192 66.3 ns 31.9 ns std::map<std::string, int>::contains(key) (existent)/0 0.008 ns 0.008 ns std::map<std::string, int>::contains(key) (existent)/32 31.4 ns 18.0 ns std::map<std::string, int>::contains(key) (existent)/1024 44.3 ns 26.5 ns std::map<std::string, int>::contains(key) (existent)/8192 47.4 ns 30.2 ns std::map<std::string, int>::contains(key) (non-existent)/0 0.452 ns 0.441 ns std::map<std::string, int>::contains(key) (non-existent)/32 23.1 ns 11.5 ns std::map<std::string, int>::contains(key) (non-existent)/1024 46.2 ns 26.3 ns std::map<std::string, int>::contains(key) (non-existent)/8192 63.4 ns 31.4 ns std::map<std::string, int>::lower_bound(key) (existent)/0 0.008 ns 0.008 ns std::map<std::string, int>::lower_bound(key) (existent)/32 17.2 ns 19.0 ns std::map<std::string, int>::lower_bound(key) (existent)/1024 27.1 ns 26.2 ns std::map<std::string, int>::lower_bound(key) (existent)/8192 34.0 ns 36.0 ns std::map<std::string, int>::lower_bound(key) (non-existent)/0 0.259 ns 0.257 ns std::map<std::string, int>::lower_bound(key) (non-existent)/32 11.6 ns 11.5 ns std::map<std::string, int>::lower_bound(key) (non-existent)/1024 24.8 ns 25.6 ns std::map<std::string, int>::lower_bound(key) (non-existent)/8192 31.7 ns 31.6 ns std::map<std::string, int>::upper_bound(key) (existent)/0 0.008 ns 0.008 ns std::map<std::string, int>::upper_bound(key) (existent)/32 18.8 ns 19.7 ns std::map<std::string, int>::upper_bound(key) (existent)/1024 25.3 ns 27.7 ns std::map<std::string, int>::upper_bound(key) (existent)/8192 30.2 ns 29.9 ns std::map<std::string, int>::upper_bound(key) (non-existent)/0 0.260 ns 0.259 ns std::map<std::string, int>::upper_bound(key) (non-existent)/32 11.3 ns 12.0 ns std::map<std::string, int>::upper_bound(key) (non-existent)/1024 25.6 ns 25.9 ns std::map<std::string, int>::upper_bound(key) (non-existent)/8192 33.1 ns 34.2 ns std::map<std::string, int>::equal_range(key) (existent)/0 0.008 ns 0.008 ns std::map<std::string, int>::equal_range(key) (existent)/32 33.5 ns 15.8 ns std::map<std::string, int>::equal_range(key) (existent)/1024 43.0 ns 25.1 ns std::map<std::string, int>::equal_range(key) (existent)/8192 54.1 ns 30.7 ns std::map<std::string, int>::equal_range(key) (non-existent)/0 0.265 ns 0.259 ns std::map<std::string, int>::equal_range(key) (non-existent)/32 22.1 ns 12.1 ns std::map<std::string, int>::equal_range(key) (non-existent)/1024 44.8 ns 24.4 ns std::map<std::string, int>::equal_range(key) (non-existent)/8192 62.2 ns 40.1 ns ``` Fixes #66577
1 parent 6a571a1 commit 1d131ff

File tree

14 files changed

+277
-46
lines changed

14 files changed

+277
-46
lines changed

libcxx/include/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,7 @@ set(files
911911
__utility/cmp.h
912912
__utility/convert_to_integral.h
913913
__utility/declval.h
914+
__utility/default_three_way_comparator.h
914915
__utility/element_count.h
915916
__utility/empty.h
916917
__utility/exception_guard.h
@@ -921,6 +922,7 @@ set(files
921922
__utility/integer_sequence.h
922923
__utility/is_pointer_in_range.h
923924
__utility/is_valid_range.h
925+
__utility/lazy_synth_three_way_comparator.h
924926
__utility/move.h
925927
__utility/no_destroy.h
926928
__utility/pair.h

libcxx/include/__config

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,13 @@ typedef __char32_t char32_t;
11561156
# define _LIBCPP_LIFETIMEBOUND
11571157
# endif
11581158

1159+
// This is to work around https://llvm.org/PR156809
1160+
# ifndef _LIBCPP_CXX03_LANG
1161+
# define _LIBCPP_CTOR_LIFETIMEBOUND _LIBCPP_LIFETIMEBOUND
1162+
# else
1163+
# define _LIBCPP_CTOR_LIFETIMEBOUND
1164+
# endif
1165+
11591166
# if __has_cpp_attribute(_Clang::__noescape__)
11601167
# define _LIBCPP_NOESCAPE [[_Clang::__noescape__]]
11611168
# else

libcxx/include/__tree

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <__type_traits/is_swappable.h>
3535
#include <__type_traits/remove_const.h>
3636
#include <__utility/forward.h>
37+
#include <__utility/lazy_synth_three_way_comparator.h>
3738
#include <__utility/move.h>
3839
#include <__utility/pair.h>
3940
#include <__utility/swap.h>
@@ -1749,14 +1750,18 @@ __tree<_Tp, _Compare, _Allocator>::__find_equal(const _Key& __v) {
17491750
}
17501751

17511752
__node_base_pointer* __node_ptr = __root_ptr();
1753+
auto __comp = __lazy_synth_three_way_comparator<_Compare, _Key, value_type>(value_comp());
1754+
17521755
while (true) {
1753-
if (value_comp()(__v, __nd->__get_value())) {
1756+
auto __comp_res = __comp(__v, __nd->__get_value());
1757+
1758+
if (__comp_res.__less()) {
17541759
if (__nd->__left_ == nullptr)
17551760
return _Pair(static_cast<__end_node_pointer>(__nd), __nd->__left_);
17561761

17571762
__node_ptr = std::addressof(__nd->__left_);
17581763
__nd = static_cast<__node_pointer>(__nd->__left_);
1759-
} else if (value_comp()(__nd->__get_value(), __v)) {
1764+
} else if (__comp_res.__greater()) {
17601765
if (__nd->__right_ == nullptr)
17611766
return _Pair(static_cast<__end_node_pointer>(__nd), __nd->__right_);
17621767

@@ -2065,10 +2070,12 @@ template <class _Key>
20652070
typename __tree<_Tp, _Compare, _Allocator>::size_type
20662071
__tree<_Tp, _Compare, _Allocator>::__count_unique(const _Key& __k) const {
20672072
__node_pointer __rt = __root();
2073+
auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
20682074
while (__rt != nullptr) {
2069-
if (value_comp()(__k, __rt->__get_value())) {
2075+
auto __comp_res = __comp(__k, __rt->__get_value());
2076+
if (__comp_res.__less()) {
20702077
__rt = static_cast<__node_pointer>(__rt->__left_);
2071-
} else if (value_comp()(__rt->__get_value(), __k))
2078+
} else if (__comp_res.__greater())
20722079
__rt = static_cast<__node_pointer>(__rt->__right_);
20732080
else
20742081
return 1;
@@ -2082,11 +2089,13 @@ typename __tree<_Tp, _Compare, _Allocator>::size_type
20822089
__tree<_Tp, _Compare, _Allocator>::__count_multi(const _Key& __k) const {
20832090
__end_node_pointer __result = __end_node();
20842091
__node_pointer __rt = __root();
2092+
auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
20852093
while (__rt != nullptr) {
2086-
if (value_comp()(__k, __rt->__get_value())) {
2094+
auto __comp_res = __comp(__k, __rt->__get_value());
2095+
if (__comp_res.__less()) {
20872096
__result = static_cast<__end_node_pointer>(__rt);
20882097
__rt = static_cast<__node_pointer>(__rt->__left_);
2089-
} else if (value_comp()(__rt->__get_value(), __k))
2098+
} else if (__comp_res.__greater())
20902099
__rt = static_cast<__node_pointer>(__rt->__right_);
20912100
else
20922101
return std::distance(
@@ -2159,11 +2168,13 @@ __tree<_Tp, _Compare, _Allocator>::__equal_range_unique(const _Key& __k) {
21592168
using _Pp = pair<iterator, iterator>;
21602169
__end_node_pointer __result = __end_node();
21612170
__node_pointer __rt = __root();
2171+
auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
21622172
while (__rt != nullptr) {
2163-
if (value_comp()(__k, __rt->__get_value())) {
2173+
auto __comp_res = __comp(__k, __rt->__get_value());
2174+
if (__comp_res.__less()) {
21642175
__result = static_cast<__end_node_pointer>(__rt);
21652176
__rt = static_cast<__node_pointer>(__rt->__left_);
2166-
} else if (value_comp()(__rt->__get_value(), __k))
2177+
} else if (__comp_res.__greater())
21672178
__rt = static_cast<__node_pointer>(__rt->__right_);
21682179
else
21692180
return _Pp(iterator(__rt),
@@ -2181,11 +2192,13 @@ __tree<_Tp, _Compare, _Allocator>::__equal_range_unique(const _Key& __k) const {
21812192
using _Pp = pair<const_iterator, const_iterator>;
21822193
__end_node_pointer __result = __end_node();
21832194
__node_pointer __rt = __root();
2195+
auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
21842196
while (__rt != nullptr) {
2185-
if (value_comp()(__k, __rt->__get_value())) {
2197+
auto __comp_res = __comp(__k, __rt->__get_value());
2198+
if (__comp_res.__less()) {
21862199
__result = static_cast<__end_node_pointer>(__rt);
21872200
__rt = static_cast<__node_pointer>(__rt->__left_);
2188-
} else if (value_comp()(__rt->__get_value(), __k))
2201+
} else if (__comp_res.__greater())
21892202
__rt = static_cast<__node_pointer>(__rt->__right_);
21902203
else
21912204
return _Pp(
@@ -2202,12 +2215,14 @@ pair<typename __tree<_Tp, _Compare, _Allocator>::iterator, typename __tree<_Tp,
22022215
__tree<_Tp, _Compare, _Allocator>::__equal_range_multi(const _Key& __k) {
22032216
using _Pp = pair<iterator, iterator>;
22042217
__end_node_pointer __result = __end_node();
2205-
__node_pointer __rt = __root();
2218+
__node_pointer __rt = __root();
2219+
auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
22062220
while (__rt != nullptr) {
2207-
if (value_comp()(__k, __rt->__get_value())) {
2221+
auto __comp_res = __comp(__k, __rt->__get_value());
2222+
if (__comp_res.__less()) {
22082223
__result = static_cast<__end_node_pointer>(__rt);
22092224
__rt = static_cast<__node_pointer>(__rt->__left_);
2210-
} else if (value_comp()(__rt->__get_value(), __k))
2225+
} else if (__comp_res.__greater())
22112226
__rt = static_cast<__node_pointer>(__rt->__right_);
22122227
else
22132228
return _Pp(__lower_bound(__k, static_cast<__node_pointer>(__rt->__left_), static_cast<__end_node_pointer>(__rt)),
@@ -2223,12 +2238,14 @@ pair<typename __tree<_Tp, _Compare, _Allocator>::const_iterator,
22232238
__tree<_Tp, _Compare, _Allocator>::__equal_range_multi(const _Key& __k) const {
22242239
using _Pp = pair<const_iterator, const_iterator>;
22252240
__end_node_pointer __result = __end_node();
2226-
__node_pointer __rt = __root();
2241+
__node_pointer __rt = __root();
2242+
auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
22272243
while (__rt != nullptr) {
2228-
if (value_comp()(__k, __rt->__get_value())) {
2244+
auto __comp_res = __comp(__k, __rt->__get_value());
2245+
if (__comp_res.__less()) {
22292246
__result = static_cast<__end_node_pointer>(__rt);
22302247
__rt = static_cast<__node_pointer>(__rt->__left_);
2231-
} else if (value_comp()(__rt->__get_value(), __k))
2248+
} else if (__comp_res.__greater())
22322249
__rt = static_cast<__node_pointer>(__rt->__right_);
22332250
else
22342251
return _Pp(__lower_bound(__k, static_cast<__node_pointer>(__rt->__left_), static_cast<__end_node_pointer>(__rt)),
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef _LIBCPP___UTILITY_DEFAULT_THREE_WAY_COMPARATOR_H
10+
#define _LIBCPP___UTILITY_DEFAULT_THREE_WAY_COMPARATOR_H
11+
12+
#include <__config>
13+
#include <__type_traits/enable_if.h>
14+
#include <__type_traits/is_arithmetic.h>
15+
16+
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
17+
# pragma GCC system_header
18+
#endif
19+
20+
_LIBCPP_BEGIN_NAMESPACE_STD
21+
22+
// This struct can be specialized to provide a three way comparator between _LHS and _RHS.
23+
// The return value should be
24+
// - less than zero if (lhs_val < rhs_val)
25+
// - greater than zero if (rhs_val < lhs_val)
26+
// - zero otherwise
27+
template <class _LHS, class _RHS, class = void>
28+
struct __default_three_way_comparator;
29+
30+
template <class _Tp>
31+
struct __default_three_way_comparator<_Tp, _Tp, __enable_if_t<is_arithmetic<_Tp>::value> > {
32+
_LIBCPP_HIDE_FROM_ABI static int operator()(_Tp __lhs, _Tp __rhs) {
33+
if (__lhs < __rhs)
34+
return -1;
35+
if (__lhs > __rhs)
36+
return 1;
37+
return 0;
38+
}
39+
};
40+
41+
template <class _LHS, class _RHS, bool = true>
42+
inline const bool __has_default_three_way_comparator_v = false;
43+
44+
template <class _LHS, class _RHS>
45+
inline const bool
46+
__has_default_three_way_comparator_v< _LHS, _RHS, sizeof(__default_three_way_comparator<_LHS, _RHS>) >= 0> = true;
47+
48+
_LIBCPP_END_NAMESPACE_STD
49+
50+
#endif // _LIBCPP___UTILITY_DEFAULT_THREE_WAY_COMPARATOR_H
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef _LIBCPP___UTILITY_LAZY_SYNTH_THREE_WAY_COMPARATOR_H
10+
#define _LIBCPP___UTILITY_LAZY_SYNTH_THREE_WAY_COMPARATOR_H
11+
12+
#include <__config>
13+
#include <__type_traits/desugars_to.h>
14+
#include <__type_traits/enable_if.h>
15+
#include <__utility/default_three_way_comparator.h>
16+
17+
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
18+
# pragma GCC system_header
19+
#endif
20+
21+
// This file implements a __lazy_synth_three_way_comparator, which tries to build an efficient three way comparison from
22+
// a binary comparator. That is done in multiple steps:
23+
// 1) Check whether the comparator desugars to a less-than operator
24+
// If that is the case, check whether there exists a specialization of `__default_three_way_comparator`, which
25+
// can be specialized to implement a three way comparator for the specific types.
26+
// 2) Fall back to doing a lazy less than/greater than comparison
27+
28+
_LIBCPP_BEGIN_NAMESPACE_STD
29+
30+
template <class _Comparator, class _LHS, class _RHS>
31+
struct __lazy_compare_result {
32+
const _Comparator& __comp_;
33+
const _LHS& __lhs_;
34+
const _RHS& __rhs_;
35+
36+
_LIBCPP_HIDE_FROM_ABI
37+
__lazy_compare_result(_LIBCPP_CTOR_LIFETIMEBOUND const _Comparator& __comp,
38+
_LIBCPP_CTOR_LIFETIMEBOUND const _LHS& __lhs,
39+
_LIBCPP_CTOR_LIFETIMEBOUND const _RHS& __rhs)
40+
: __comp_(__comp), __lhs_(__lhs), __rhs_(__rhs) {}
41+
42+
_LIBCPP_HIDE_FROM_ABI bool __less() const { return __comp_(__lhs_, __rhs_); }
43+
_LIBCPP_HIDE_FROM_ABI bool __greater() const { return __comp_(__rhs_, __lhs_); }
44+
};
45+
46+
// This class provides three way comparison between _LHS and _RHS as efficiently as possible. This can be specialized if
47+
// a comparator only compares part of the object, potentially allowing an efficient three way comparison between the
48+
// subobjects. The specialization should use the __lazy_synth_three_way_comparator for the subobjects to achieve this.
49+
template <class _Comparator, class _LHS, class _RHS, class = void>
50+
struct __lazy_synth_three_way_comparator {
51+
const _Comparator& __comp_;
52+
53+
_LIBCPP_HIDE_FROM_ABI __lazy_synth_three_way_comparator(_LIBCPP_CTOR_LIFETIMEBOUND const _Comparator& __comp)
54+
: __comp_(__comp) {}
55+
56+
_LIBCPP_HIDE_FROM_ABI __lazy_compare_result<_Comparator, _LHS, _RHS>
57+
operator()(_LIBCPP_LIFETIMEBOUND const _LHS& __lhs, _LIBCPP_LIFETIMEBOUND const _RHS& __rhs) const {
58+
return __lazy_compare_result<_Comparator, _LHS, _RHS>(__comp_, __lhs, __rhs);
59+
}
60+
};
61+
62+
struct __eager_compare_result {
63+
int __res_;
64+
65+
_LIBCPP_HIDE_FROM_ABI explicit __eager_compare_result(int __res) : __res_(__res) {}
66+
67+
_LIBCPP_HIDE_FROM_ABI bool __less() const { return __res_ < 0; }
68+
_LIBCPP_HIDE_FROM_ABI bool __greater() const { return __res_ > 0; }
69+
};
70+
71+
template <class _Comparator, class _LHS, class _RHS>
72+
struct __lazy_synth_three_way_comparator<_Comparator,
73+
_LHS,
74+
_RHS,
75+
__enable_if_t<__desugars_to_v<__less_tag, _Comparator, _LHS, _RHS> &&
76+
__has_default_three_way_comparator_v<_LHS, _RHS> > > {
77+
// This lifetimebound annotation is technically incorrect, but other specializations actually capture the lifetime of
78+
// the comparator.
79+
_LIBCPP_HIDE_FROM_ABI __lazy_synth_three_way_comparator(_LIBCPP_CTOR_LIFETIMEBOUND const _Comparator&) {}
80+
81+
// Same comment as above.
82+
_LIBCPP_HIDE_FROM_ABI static __eager_compare_result
83+
operator()(_LIBCPP_LIFETIMEBOUND const _LHS& __lhs, _LIBCPP_LIFETIMEBOUND const _RHS& __rhs) {
84+
return __eager_compare_result(__default_three_way_comparator<_LHS, _RHS>()(__lhs, __rhs));
85+
}
86+
};
87+
88+
_LIBCPP_END_NAMESPACE_STD
89+
90+
#endif // _LIBCPP___UTILITY_LAZY_SYNTH_THREE_WAY_COMPARATOR_H

libcxx/include/map

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,7 @@ erase_if(multimap<Key, T, Compare, Allocator>& c, Predicate pred); // C++20
603603
# include <__type_traits/remove_const.h>
604604
# include <__type_traits/type_identity.h>
605605
# include <__utility/forward.h>
606+
# include <__utility/lazy_synth_three_way_comparator.h>
606607
# include <__utility/pair.h>
607608
# include <__utility/piecewise_construct.h>
608609
# include <__utility/swap.h>
@@ -702,6 +703,50 @@ public:
702703
# endif
703704
};
704705

706+
# if _LIBCPP_STD_VER >= 14
707+
template <class _MapValueT, class _Key, class _Compare>
708+
struct __lazy_synth_three_way_comparator<__map_value_compare<_Key, _MapValueT, _Compare>, _MapValueT, _MapValueT> {
709+
__lazy_synth_three_way_comparator<_Compare, _Key, _Key> __comp_;
710+
711+
__lazy_synth_three_way_comparator(
712+
_LIBCPP_CTOR_LIFETIMEBOUND const __map_value_compare<_Key, _MapValueT, _Compare>& __comp)
713+
: __comp_(__comp.key_comp()) {}
714+
715+
_LIBCPP_HIDE_FROM_ABI auto
716+
operator()(_LIBCPP_LIFETIMEBOUND const _MapValueT& __lhs, _LIBCPP_LIFETIMEBOUND const _MapValueT& __rhs) const {
717+
return __comp_(__lhs.first, __rhs.first);
718+
}
719+
};
720+
721+
template <class _MapValueT, class _Key, class _TransparentKey, class _Compare>
722+
struct __lazy_synth_three_way_comparator<__map_value_compare<_Key, _MapValueT, _Compare>, _TransparentKey, _MapValueT> {
723+
__lazy_synth_three_way_comparator<_Compare, _TransparentKey, _Key> __comp_;
724+
725+
__lazy_synth_three_way_comparator(
726+
_LIBCPP_CTOR_LIFETIMEBOUND const __map_value_compare<_Key, _MapValueT, _Compare>& __comp)
727+
: __comp_(__comp.key_comp()) {}
728+
729+
_LIBCPP_HIDE_FROM_ABI auto
730+
operator()(_LIBCPP_LIFETIMEBOUND const _TransparentKey& __lhs, _LIBCPP_LIFETIMEBOUND const _MapValueT& __rhs) const {
731+
return __comp_(__lhs, __rhs.first);
732+
}
733+
};
734+
735+
template <class _MapValueT, class _Key, class _TransparentKey, class _Compare>
736+
struct __lazy_synth_three_way_comparator<__map_value_compare<_Key, _MapValueT, _Compare>, _MapValueT, _TransparentKey> {
737+
__lazy_synth_three_way_comparator<_Compare, _Key, _TransparentKey> __comp_;
738+
739+
__lazy_synth_three_way_comparator(
740+
_LIBCPP_CTOR_LIFETIMEBOUND const __map_value_compare<_Key, _MapValueT, _Compare>& __comp)
741+
: __comp_(__comp.key_comp()) {}
742+
743+
_LIBCPP_HIDE_FROM_ABI auto
744+
operator()(_LIBCPP_LIFETIMEBOUND const _MapValueT& __lhs, _LIBCPP_LIFETIMEBOUND const _TransparentKey& __rhs) const {
745+
return __comp_(__lhs.first, __rhs);
746+
}
747+
};
748+
# endif // _LIBCPP_STD_VER >= 14
749+
705750
template <class _Key, class _CP, class _Compare, bool __b>
706751
inline _LIBCPP_HIDE_FROM_ABI void
707752
swap(__map_value_compare<_Key, _CP, _Compare, __b>& __x, __map_value_compare<_Key, _CP, _Compare, __b>& __y)

0 commit comments

Comments
 (0)