Skip to content

Commit fdfeed6

Browse files
committed
[libc++] Optimize __tree::__find_equal
1 parent a271d07 commit fdfeed6

File tree

14 files changed

+277
-46
lines changed

14 files changed

+277
-46
lines changed

libcxx/include/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,7 @@ set(files
911911
__utility/cmp.h
912912
__utility/convert_to_integral.h
913913
__utility/declval.h
914+
__utility/default_three_way_comparator.h
914915
__utility/element_count.h
915916
__utility/empty.h
916917
__utility/exception_guard.h
@@ -921,6 +922,7 @@ set(files
921922
__utility/integer_sequence.h
922923
__utility/is_pointer_in_range.h
923924
__utility/is_valid_range.h
925+
__utility/lazy_synth_three_way_comparator.h
924926
__utility/move.h
925927
__utility/no_destroy.h
926928
__utility/pair.h

libcxx/include/__config

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,13 @@ typedef __char32_t char32_t;
11561156
# define _LIBCPP_LIFETIMEBOUND
11571157
# endif
11581158

1159+
// This is to work around https://llvm.org/PR156809
1160+
# ifndef _LIBCPP_CXX03_LANG
1161+
# define _LIBCPP_CTOR_LIFETIMEBOUND _LIBCPP_LIFETIMEBOUND
1162+
# else
1163+
# define _LIBCPP_CTOR_LIFETIMEBOUND
1164+
# endif
1165+
11591166
# if __has_cpp_attribute(_Clang::__noescape__)
11601167
# define _LIBCPP_NOESCAPE [[_Clang::__noescape__]]
11611168
# else

libcxx/include/__tree

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <__type_traits/is_swappable.h>
3535
#include <__type_traits/remove_const.h>
3636
#include <__utility/forward.h>
37+
#include <__utility/lazy_synth_three_way_comparator.h>
3738
#include <__utility/move.h>
3839
#include <__utility/pair.h>
3940
#include <__utility/swap.h>
@@ -1749,14 +1750,18 @@ __tree<_Tp, _Compare, _Allocator>::__find_equal(const _Key& __v) {
17491750
}
17501751

17511752
__node_base_pointer* __node_ptr = __root_ptr();
1753+
auto __comp = __lazy_synth_three_way_comparator<_Compare, _Key, value_type>(value_comp());
1754+
17521755
while (true) {
1753-
if (value_comp()(__v, __nd->__get_value())) {
1756+
auto __comp_res = __comp(__v, __nd->__get_value());
1757+
1758+
if (__comp_res.__less()) {
17541759
if (__nd->__left_ == nullptr)
17551760
return _Pair(static_cast<__end_node_pointer>(__nd), __nd->__left_);
17561761

17571762
__node_ptr = std::addressof(__nd->__left_);
17581763
__nd = static_cast<__node_pointer>(__nd->__left_);
1759-
} else if (value_comp()(__nd->__get_value(), __v)) {
1764+
} else if (__comp_res.__greater()) {
17601765
if (__nd->__right_ == nullptr)
17611766
return _Pair(static_cast<__end_node_pointer>(__nd), __nd->__right_);
17621767

@@ -2065,10 +2070,12 @@ template <class _Key>
20652070
typename __tree<_Tp, _Compare, _Allocator>::size_type
20662071
__tree<_Tp, _Compare, _Allocator>::__count_unique(const _Key& __k) const {
20672072
__node_pointer __rt = __root();
2073+
auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
20682074
while (__rt != nullptr) {
2069-
if (value_comp()(__k, __rt->__get_value())) {
2075+
auto __comp_res = __comp(__k, __rt->__get_value());
2076+
if (__comp_res.__less()) {
20702077
__rt = static_cast<__node_pointer>(__rt->__left_);
2071-
} else if (value_comp()(__rt->__get_value(), __k))
2078+
} else if (__comp_res.__greater())
20722079
__rt = static_cast<__node_pointer>(__rt->__right_);
20732080
else
20742081
return 1;
@@ -2082,11 +2089,13 @@ typename __tree<_Tp, _Compare, _Allocator>::size_type
20822089
__tree<_Tp, _Compare, _Allocator>::__count_multi(const _Key& __k) const {
20832090
__end_node_pointer __result = __end_node();
20842091
__node_pointer __rt = __root();
2092+
auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
20852093
while (__rt != nullptr) {
2086-
if (value_comp()(__k, __rt->__get_value())) {
2094+
auto __comp_res = __comp(__k, __rt->__get_value());
2095+
if (__comp_res.__less()) {
20872096
__result = static_cast<__end_node_pointer>(__rt);
20882097
__rt = static_cast<__node_pointer>(__rt->__left_);
2089-
} else if (value_comp()(__rt->__get_value(), __k))
2098+
} else if (__comp_res.__greater())
20902099
__rt = static_cast<__node_pointer>(__rt->__right_);
20912100
else
20922101
return std::distance(
@@ -2159,11 +2168,13 @@ __tree<_Tp, _Compare, _Allocator>::__equal_range_unique(const _Key& __k) {
21592168
using _Pp = pair<iterator, iterator>;
21602169
__end_node_pointer __result = __end_node();
21612170
__node_pointer __rt = __root();
2171+
auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
21622172
while (__rt != nullptr) {
2163-
if (value_comp()(__k, __rt->__get_value())) {
2173+
auto __comp_res = __comp(__k, __rt->__get_value());
2174+
if (__comp_res.__less()) {
21642175
__result = static_cast<__end_node_pointer>(__rt);
21652176
__rt = static_cast<__node_pointer>(__rt->__left_);
2166-
} else if (value_comp()(__rt->__get_value(), __k))
2177+
} else if (__comp_res.__greater())
21672178
__rt = static_cast<__node_pointer>(__rt->__right_);
21682179
else
21692180
return _Pp(iterator(__rt),
@@ -2181,11 +2192,13 @@ __tree<_Tp, _Compare, _Allocator>::__equal_range_unique(const _Key& __k) const {
21812192
using _Pp = pair<const_iterator, const_iterator>;
21822193
__end_node_pointer __result = __end_node();
21832194
__node_pointer __rt = __root();
2195+
auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
21842196
while (__rt != nullptr) {
2185-
if (value_comp()(__k, __rt->__get_value())) {
2197+
auto __comp_res = __comp(__k, __rt->__get_value());
2198+
if (__comp_res.__less()) {
21862199
__result = static_cast<__end_node_pointer>(__rt);
21872200
__rt = static_cast<__node_pointer>(__rt->__left_);
2188-
} else if (value_comp()(__rt->__get_value(), __k))
2201+
} else if (__comp_res.__greater())
21892202
__rt = static_cast<__node_pointer>(__rt->__right_);
21902203
else
21912204
return _Pp(
@@ -2202,12 +2215,14 @@ pair<typename __tree<_Tp, _Compare, _Allocator>::iterator, typename __tree<_Tp,
22022215
__tree<_Tp, _Compare, _Allocator>::__equal_range_multi(const _Key& __k) {
22032216
using _Pp = pair<iterator, iterator>;
22042217
__end_node_pointer __result = __end_node();
2205-
__node_pointer __rt = __root();
2218+
__node_pointer __rt = __root();
2219+
auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
22062220
while (__rt != nullptr) {
2207-
if (value_comp()(__k, __rt->__get_value())) {
2221+
auto __comp_res = __comp(__k, __rt->__get_value());
2222+
if (__comp_res.__less()) {
22082223
__result = static_cast<__end_node_pointer>(__rt);
22092224
__rt = static_cast<__node_pointer>(__rt->__left_);
2210-
} else if (value_comp()(__rt->__get_value(), __k))
2225+
} else if (__comp_res.__greater())
22112226
__rt = static_cast<__node_pointer>(__rt->__right_);
22122227
else
22132228
return _Pp(__lower_bound(__k, static_cast<__node_pointer>(__rt->__left_), static_cast<__end_node_pointer>(__rt)),
@@ -2223,12 +2238,14 @@ pair<typename __tree<_Tp, _Compare, _Allocator>::const_iterator,
22232238
__tree<_Tp, _Compare, _Allocator>::__equal_range_multi(const _Key& __k) const {
22242239
using _Pp = pair<const_iterator, const_iterator>;
22252240
__end_node_pointer __result = __end_node();
2226-
__node_pointer __rt = __root();
2241+
__node_pointer __rt = __root();
2242+
auto __comp = __lazy_synth_three_way_comparator<value_compare, _Key, value_type>(value_comp());
22272243
while (__rt != nullptr) {
2228-
if (value_comp()(__k, __rt->__get_value())) {
2244+
auto __comp_res = __comp(__k, __rt->__get_value());
2245+
if (__comp_res.__less()) {
22292246
__result = static_cast<__end_node_pointer>(__rt);
22302247
__rt = static_cast<__node_pointer>(__rt->__left_);
2231-
} else if (value_comp()(__rt->__get_value(), __k))
2248+
} else if (__comp_res.__greater())
22322249
__rt = static_cast<__node_pointer>(__rt->__right_);
22332250
else
22342251
return _Pp(__lower_bound(__k, static_cast<__node_pointer>(__rt->__left_), static_cast<__end_node_pointer>(__rt)),
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef _LIBCPP___UTILITY_DEFAULT_THREE_WAY_COMPARATOR_H
10+
#define _LIBCPP___UTILITY_DEFAULT_THREE_WAY_COMPARATOR_H
11+
12+
#include <__config>
13+
#include <__type_traits/enable_if.h>
14+
#include <__type_traits/is_arithmetic.h>
15+
16+
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
17+
# pragma GCC system_header
18+
#endif
19+
20+
_LIBCPP_BEGIN_NAMESPACE_STD
21+
22+
// This struct can be specialized to provide a three way comparator between _LHS and _RHS.
23+
// The return value should be
24+
// - less than zero if (lhs_val < rhs_val)
25+
// - greater than zero if (rhs_val < lhs_val)
26+
// - zero otherwise
27+
template <class _LHS, class _RHS, class = void>
28+
struct __default_three_way_comparator;
29+
30+
template <class _Tp>
31+
struct __default_three_way_comparator<_Tp, _Tp, __enable_if_t<is_arithmetic<_Tp>::value> > {
32+
_LIBCPP_HIDE_FROM_ABI static int operator()(_Tp __lhs, _Tp __rhs) {
33+
if (__lhs < __rhs)
34+
return -1;
35+
if (__lhs > __rhs)
36+
return 1;
37+
return 0;
38+
}
39+
};
40+
41+
template <class _LHS, class _RHS, bool = true>
42+
inline const bool __has_default_three_way_comparator_v = false;
43+
44+
template <class _LHS, class _RHS>
45+
inline const bool
46+
__has_default_three_way_comparator_v< _LHS, _RHS, sizeof(__default_three_way_comparator<_LHS, _RHS>) >= 0> = true;
47+
48+
_LIBCPP_END_NAMESPACE_STD
49+
50+
#endif // _LIBCPP___UTILITY_DEFAULT_THREE_WAY_COMPARATOR_H
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef _LIBCPP___UTILITY_LAZY_SYNTH_THREE_WAY_COMPARATOR_H
10+
#define _LIBCPP___UTILITY_LAZY_SYNTH_THREE_WAY_COMPARATOR_H
11+
12+
#include <__config>
13+
#include <__type_traits/desugars_to.h>
14+
#include <__type_traits/enable_if.h>
15+
#include <__utility/default_three_way_comparator.h>
16+
17+
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
18+
# pragma GCC system_header
19+
#endif
20+
21+
// This file implements a __lazy_synth_three_way_comparator, which tries to build an efficient three way comparison from
22+
// a binary comparator. That is done in multiple steps:
23+
// 1) Check whether the comparator desugars to a less-than operator
24+
// If that is the case, check whether there exists a specialization of `__default_three_way_comparator`, which
25+
// can be specialized to implement a three way comparator for the specific types.
26+
// 2) Fall back to doing a lazy less than/greater than comparison
27+
28+
_LIBCPP_BEGIN_NAMESPACE_STD
29+
30+
template <class _Comparator, class _LHS, class _RHS>
31+
struct __lazy_compare_result {
32+
const _Comparator& __comp_;
33+
const _LHS& __lhs_;
34+
const _RHS& __rhs_;
35+
36+
_LIBCPP_HIDE_FROM_ABI
37+
__lazy_compare_result(_LIBCPP_CTOR_LIFETIMEBOUND const _Comparator& __comp,
38+
_LIBCPP_CTOR_LIFETIMEBOUND const _LHS& __lhs,
39+
_LIBCPP_CTOR_LIFETIMEBOUND const _RHS& __rhs)
40+
: __comp_(__comp), __lhs_(__lhs), __rhs_(__rhs) {}
41+
42+
_LIBCPP_HIDE_FROM_ABI bool __less() const { return __comp_(__lhs_, __rhs_); }
43+
_LIBCPP_HIDE_FROM_ABI bool __greater() const { return __comp_(__rhs_, __lhs_); }
44+
};
45+
46+
// This class provides three way comparison between _LHS and _RHS as efficiently as possible. This can be specialized if
47+
// a comparator only compares part of the object, potentially allowing an efficient three way comparison between the
48+
// subobjects. The specialization should use the __lazy_synth_three_way_comparator for the subobjects to achieve this.
49+
template <class _Comparator, class _LHS, class _RHS, class = void>
50+
struct __lazy_synth_three_way_comparator {
51+
const _Comparator& __comp_;
52+
53+
_LIBCPP_HIDE_FROM_ABI __lazy_synth_three_way_comparator(_LIBCPP_CTOR_LIFETIMEBOUND const _Comparator& __comp)
54+
: __comp_(__comp) {}
55+
56+
_LIBCPP_HIDE_FROM_ABI __lazy_compare_result<_Comparator, _LHS, _RHS>
57+
operator()(_LIBCPP_LIFETIMEBOUND const _LHS& __lhs, _LIBCPP_LIFETIMEBOUND const _RHS& __rhs) const {
58+
return __lazy_compare_result<_Comparator, _LHS, _RHS>(__comp_, __lhs, __rhs);
59+
}
60+
};
61+
62+
struct __eager_compare_result {
63+
int __res_;
64+
65+
_LIBCPP_HIDE_FROM_ABI explicit __eager_compare_result(int __res) : __res_(__res) {}
66+
67+
_LIBCPP_HIDE_FROM_ABI bool __less() const { return __res_ < 0; }
68+
_LIBCPP_HIDE_FROM_ABI bool __greater() const { return __res_ > 0; }
69+
};
70+
71+
template <class _Comparator, class _LHS, class _RHS>
72+
struct __lazy_synth_three_way_comparator<_Comparator,
73+
_LHS,
74+
_RHS,
75+
__enable_if_t<__desugars_to_v<__less_tag, _Comparator, _LHS, _RHS> &&
76+
__has_default_three_way_comparator_v<_LHS, _RHS> > > {
77+
// This lifetimebound annotation is technically incorrect, but other specializations actually capture the lifetime of
78+
// the comparator.
79+
_LIBCPP_HIDE_FROM_ABI __lazy_synth_three_way_comparator(_LIBCPP_CTOR_LIFETIMEBOUND const _Comparator&) {}
80+
81+
// Same comment as above.
82+
_LIBCPP_HIDE_FROM_ABI static __eager_compare_result
83+
operator()(_LIBCPP_LIFETIMEBOUND const _LHS& __lhs, _LIBCPP_LIFETIMEBOUND const _RHS& __rhs) {
84+
return __eager_compare_result(__default_three_way_comparator<_LHS, _RHS>()(__lhs, __rhs));
85+
}
86+
};
87+
88+
_LIBCPP_END_NAMESPACE_STD
89+
90+
#endif // _LIBCPP___UTILITY_LAZY_SYNTH_THREE_WAY_COMPARATOR_H

libcxx/include/map

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,7 @@ erase_if(multimap<Key, T, Compare, Allocator>& c, Predicate pred); // C++20
603603
# include <__type_traits/remove_const.h>
604604
# include <__type_traits/type_identity.h>
605605
# include <__utility/forward.h>
606+
# include <__utility/lazy_synth_three_way_comparator.h>
606607
# include <__utility/pair.h>
607608
# include <__utility/piecewise_construct.h>
608609
# include <__utility/swap.h>
@@ -702,6 +703,50 @@ public:
702703
# endif
703704
};
704705

706+
# if _LIBCPP_STD_VER >= 14
707+
template <class _MapValueT, class _Key, class _Compare>
708+
struct __lazy_synth_three_way_comparator<__map_value_compare<_Key, _MapValueT, _Compare>, _MapValueT, _MapValueT> {
709+
__lazy_synth_three_way_comparator<_Compare, _Key, _Key> __comp_;
710+
711+
__lazy_synth_three_way_comparator(
712+
_LIBCPP_CTOR_LIFETIMEBOUND const __map_value_compare<_Key, _MapValueT, _Compare>& __comp)
713+
: __comp_(__comp.key_comp()) {}
714+
715+
_LIBCPP_HIDE_FROM_ABI auto
716+
operator()(_LIBCPP_LIFETIMEBOUND const _MapValueT& __lhs, _LIBCPP_LIFETIMEBOUND const _MapValueT& __rhs) const {
717+
return __comp_(__lhs.first, __rhs.first);
718+
}
719+
};
720+
721+
template <class _MapValueT, class _Key, class _TransparentKey, class _Compare>
722+
struct __lazy_synth_three_way_comparator<__map_value_compare<_Key, _MapValueT, _Compare>, _TransparentKey, _MapValueT> {
723+
__lazy_synth_three_way_comparator<_Compare, _TransparentKey, _Key> __comp_;
724+
725+
__lazy_synth_three_way_comparator(
726+
_LIBCPP_CTOR_LIFETIMEBOUND const __map_value_compare<_Key, _MapValueT, _Compare>& __comp)
727+
: __comp_(__comp.key_comp()) {}
728+
729+
_LIBCPP_HIDE_FROM_ABI auto
730+
operator()(_LIBCPP_LIFETIMEBOUND const _TransparentKey& __lhs, _LIBCPP_LIFETIMEBOUND const _MapValueT& __rhs) const {
731+
return __comp_(__lhs, __rhs.first);
732+
}
733+
};
734+
735+
template <class _MapValueT, class _Key, class _TransparentKey, class _Compare>
736+
struct __lazy_synth_three_way_comparator<__map_value_compare<_Key, _MapValueT, _Compare>, _MapValueT, _TransparentKey> {
737+
__lazy_synth_three_way_comparator<_Compare, _Key, _TransparentKey> __comp_;
738+
739+
__lazy_synth_three_way_comparator(
740+
_LIBCPP_CTOR_LIFETIMEBOUND const __map_value_compare<_Key, _MapValueT, _Compare>& __comp)
741+
: __comp_(__comp.key_comp()) {}
742+
743+
_LIBCPP_HIDE_FROM_ABI auto
744+
operator()(_LIBCPP_LIFETIMEBOUND const _MapValueT& __lhs, _LIBCPP_LIFETIMEBOUND const _TransparentKey& __rhs) const {
745+
return __comp_(__lhs.first, __rhs);
746+
}
747+
};
748+
# endif // _LIBCPP_STD_VER >= 14
749+
705750
template <class _Key, class _CP, class _Compare, bool __b>
706751
inline _LIBCPP_HIDE_FROM_ABI void
707752
swap(__map_value_compare<_Key, _CP, _Compare, __b>& __x, __map_value_compare<_Key, _CP, _Compare, __b>& __y)

0 commit comments

Comments
 (0)