Skip to content

[libc++] Optimize __tree::find and __tree::__erase_unique #152370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions libcxx/docs/ReleaseNotes/22.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ Improvements and New Features
- The performance of ``map::operator=(const map&)`` has been improved by up to 11x
- The performance of ``unordered_set::unordered_set(const unordered_set&)`` has been improved by up to 3.3x.
- The performance of ``unordered_set::operator=(const unordered_set&)`` has been improved by up to 5x.
- The performance of ``map::erase`` and ``set::erase`` has been improved by up to 2x
- The performance of ``find(key)`` in ``map``, ``set``, ``multimap`` and ``multiset`` has been improved by up to 2.3x

Deprecations and Removals
-------------------------
Expand Down
36 changes: 15 additions & 21 deletions libcxx/include/__tree
Original file line number Diff line number Diff line change
Expand Up @@ -1038,9 +1038,22 @@ public:
__insert_node_at(__end_node_pointer __parent, __node_base_pointer& __child, __node_base_pointer __new_node) _NOEXCEPT;

template <class _Key>
_LIBCPP_HIDE_FROM_ABI iterator find(const _Key& __v);
_LIBCPP_HIDE_FROM_ABI iterator find(const _Key& __key) {
__end_node_pointer __parent;
__node_base_pointer __match = __find_equal(__parent, __key);
if (__match == nullptr)
return end();
return iterator(static_cast<__node_pointer>(__match));
}

template <class _Key>
_LIBCPP_HIDE_FROM_ABI const_iterator find(const _Key& __v) const;
_LIBCPP_HIDE_FROM_ABI const_iterator find(const _Key& __key) const {
__end_node_pointer __parent;
__node_base_pointer __match = __find_equal(__parent, __key);
if (__match == nullptr)
return end();
return const_iterator(static_cast<__node_pointer>(__match));
}

template <class _Key>
_LIBCPP_HIDE_FROM_ABI size_type __count_unique(const _Key& __k) const;
Expand Down Expand Up @@ -2060,25 +2073,6 @@ __tree<_Tp, _Compare, _Allocator>::__erase_multi(const _Key& __k) {
return __r;
}

template <class _Tp, class _Compare, class _Allocator>
template <class _Key>
typename __tree<_Tp, _Compare, _Allocator>::iterator __tree<_Tp, _Compare, _Allocator>::find(const _Key& __v) {
iterator __p = __lower_bound(__v, __root(), __end_node());
if (__p != end() && !value_comp()(__v, *__p))
return __p;
return end();
}

template <class _Tp, class _Compare, class _Allocator>
template <class _Key>
typename __tree<_Tp, _Compare, _Allocator>::const_iterator
__tree<_Tp, _Compare, _Allocator>::find(const _Key& __v) const {
const_iterator __p = __lower_bound(__v, __root(), __end_node());
if (__p != end() && !value_comp()(__v, *__p))
return __p;
return end();
}

template <class _Tp, class _Compare, class _Allocator>
template <class _Key>
typename __tree<_Tp, _Compare, _Allocator>::size_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@
#include "private_constructor.h"
#include "is_transparent.h"

template <class Iter>
bool iter_in_range(Iter first, Iter last, Iter to_find) {
for (; first != last; ++first) {
if (first == to_find)
return true;
}
return false;
}

int main(int, char**) {
typedef std::pair<const int, double> V;
{
Expand All @@ -30,15 +39,15 @@ int main(int, char**) {
V ar[] = {V(5, 1), V(5, 2), V(5, 3), V(7, 1), V(7, 2), V(7, 3), V(9, 1), V(9, 2), V(9, 3)};
M m(ar, ar + sizeof(ar) / sizeof(ar[0]));
R r = m.find(5);
assert(r == m.begin());
assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
r = m.find(6);
assert(r == m.end());
r = m.find(7);
assert(r == std::next(m.begin(), 3));
assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
r = m.find(8);
assert(r == m.end());
r = m.find(9);
assert(r == std::next(m.begin(), 6));
assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
r = m.find(10);
assert(r == m.end());
}
Expand All @@ -47,15 +56,15 @@ int main(int, char**) {
V ar[] = {V(5, 1), V(5, 2), V(5, 3), V(7, 1), V(7, 2), V(7, 3), V(9, 1), V(9, 2), V(9, 3)};
const M m(ar, ar + sizeof(ar) / sizeof(ar[0]));
R r = m.find(5);
assert(r == m.begin());
assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
r = m.find(6);
assert(r == m.end());
r = m.find(7);
assert(r == std::next(m.begin(), 3));
assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
r = m.find(8);
assert(r == m.end());
r = m.find(9);
assert(r == std::next(m.begin(), 6));
assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
r = m.find(10);
assert(r == m.end());
}
Expand All @@ -68,15 +77,15 @@ int main(int, char**) {
V ar[] = {V(5, 1), V(5, 2), V(5, 3), V(7, 1), V(7, 2), V(7, 3), V(9, 1), V(9, 2), V(9, 3)};
M m(ar, ar + sizeof(ar) / sizeof(ar[0]));
R r = m.find(5);
assert(r == m.begin());
assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
r = m.find(6);
assert(r == m.end());
r = m.find(7);
assert(r == std::next(m.begin(), 3));
assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
r = m.find(8);
assert(r == m.end());
r = m.find(9);
assert(r == std::next(m.begin(), 6));
assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
r = m.find(10);
assert(r == m.end());
}
Expand All @@ -85,15 +94,15 @@ int main(int, char**) {
V ar[] = {V(5, 1), V(5, 2), V(5, 3), V(7, 1), V(7, 2), V(7, 3), V(9, 1), V(9, 2), V(9, 3)};
const M m(ar, ar + sizeof(ar) / sizeof(ar[0]));
R r = m.find(5);
assert(r == m.begin());
assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
r = m.find(6);
assert(r == m.end());
r = m.find(7);
assert(r == std::next(m.begin(), 3));
assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
r = m.find(8);
assert(r == m.end());
r = m.find(9);
assert(r == std::next(m.begin(), 6));
assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
r = m.find(10);
assert(r == m.end());
}
Expand All @@ -107,28 +116,28 @@ int main(int, char**) {
V ar[] = {V(5, 1), V(5, 2), V(5, 3), V(7, 1), V(7, 2), V(7, 3), V(9, 1), V(9, 2), V(9, 3)};
M m(ar, ar + sizeof(ar) / sizeof(ar[0]));
R r = m.find(5);
assert(r == m.begin());
assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
r = m.find(6);
assert(r == m.end());
r = m.find(7);
assert(r == std::next(m.begin(), 3));
assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
r = m.find(8);
assert(r == m.end());
r = m.find(9);
assert(r == std::next(m.begin(), 6));
assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
r = m.find(10);
assert(r == m.end());

r = m.find(C2Int(5));
assert(r == m.begin());
assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
r = m.find(C2Int(6));
assert(r == m.end());
r = m.find(C2Int(7));
assert(r == std::next(m.begin(), 3));
assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
r = m.find(C2Int(8));
assert(r == m.end());
r = m.find(C2Int(9));
assert(r == std::next(m.begin(), 6));
assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
r = m.find(C2Int(10));
assert(r == m.end());
}
Expand All @@ -150,15 +159,15 @@ int main(int, char**) {
m.insert(std::make_pair<PC, double>(PC::make(9), 3));

R r = m.find(5);
assert(r == m.begin());
assert(iter_in_range(std::next(m.begin(), 0), std::next(m.begin(), 3), r));
r = m.find(6);
assert(r == m.end());
r = m.find(7);
assert(r == std::next(m.begin(), 3));
assert(iter_in_range(std::next(m.begin(), 3), std::next(m.begin(), 6), r));
r = m.find(8);
assert(r == m.end());
r = m.find(9);
assert(r == std::next(m.begin(), 6));
assert(iter_in_range(std::next(m.begin(), 6), std::next(m.begin(), 9), r));
r = m.find(10);
assert(r == m.end());
}
Expand Down
Loading