Skip to content

[libc++] Refactor __tree::__find_equal to not have an out parameter #147345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 98 additions & 120 deletions libcxx/include/__tree
Original file line number Diff line number Diff line change
Expand Up @@ -1086,15 +1086,89 @@ public:

// FIXME: Make this function const qualified. Unfortunately doing so
// breaks existing code which uses non-const callable comparators.

// Find place to insert if __v doesn't exist
// Set __parent to parent of null leaf
// Return reference to null leaf
// If __v exists, set parent to node of __v and return reference to node of __v
template <class _Key>
_LIBCPP_HIDE_FROM_ABI __node_base_pointer& __find_equal(__end_node_pointer& __parent, const _Key& __v);
_LIBCPP_HIDE_FROM_ABI pair<__end_node_pointer, __node_base_pointer&> __find_equal(const _Key& __v) {
using _PairT = pair<__end_node_pointer, __node_base_pointer&>;

__node_pointer __node = __root();

if (__node == nullptr) {
auto __end = __end_node();
return _PairT(__end, __end->__left_);
}

__node_base_pointer* __node_ptr = __root_ptr();
while (true) {
if (value_comp()(__v, __node->__value_)) {
if (__node->__left_ == nullptr)
return _PairT(static_cast<__end_node_pointer>(__node), __node->__left_);

__node_ptr = std::addressof(__node->__left_);
__node = static_cast<__node_pointer>(__node->__left_);
} else if (value_comp()(__node->__value_, __v)) {
if (__node->__right_ == nullptr)
return _PairT(static_cast<__end_node_pointer>(__node), __node->__right_);

__node_ptr = std::addressof(__node->__right_);
__node = static_cast<__node_pointer>(__node->__right_);
} else {
return _PairT(static_cast<__end_node_pointer>(__node), *__node_ptr);
}
}
}

template <class _Key>
_LIBCPP_HIDE_FROM_ABI __node_base_pointer& __find_equal(__end_node_pointer& __parent, const _Key& __v) const {
return const_cast<__tree*>(this)->__find_equal(__parent, __v);
_LIBCPP_HIDE_FROM_ABI pair<__end_node_pointer, __node_base_pointer&> __find_equal(const _Key& __v) const {
return const_cast<__tree*>(this)->__find_equal(__v);
}

// Find place to insert if __v doesn't exist
// First check prior to __hint.
// Next check after __hint.
// Next do O(log N) search.
// Set __parent to parent of null leaf
// Return reference to null leaf
// If __v exists, set parent to node of __v and return reference to node of __v
template <class _Key>
_LIBCPP_HIDE_FROM_ABI __node_base_pointer&
__find_equal(const_iterator __hint, __end_node_pointer& __parent, __node_base_pointer& __dummy, const _Key& __v);
_LIBCPP_HIDE_FROM_ABI pair<__end_node_pointer, __node_base_pointer&>
__find_equal(const_iterator __hint, __node_base_pointer& __dummy, const _Key& __v) {
using _PairT = pair<__end_node_pointer, __node_base_pointer&>;

if (__hint == end() || value_comp()(__v, *__hint)) { // check before
// __v < *__hint
const_iterator __prior = __hint;
if (__prior == begin() || value_comp()(*--__prior, __v)) {
// *prev(__hint) < __v < *__hint
if (__hint.__ptr_->__left_ == nullptr)
return _PairT(__hint.__ptr_, __hint.__ptr_->__left_);
return _PairT(__prior.__ptr_, static_cast<__node_pointer>(__prior.__ptr_)->__right_);
}
// __v <= *prev(__hint)
return __find_equal(__v);
}

if (value_comp()(*__hint, __v)) { // check after
// *__hint < __v
const_iterator __next = std::next(__hint);
if (__next == end() || value_comp()(__v, *__next)) {
// *__hint < __v < *std::next(__hint)
if (__hint.__get_np()->__right_ == nullptr)
return _PairT(__hint.__ptr_, static_cast<__node_pointer>(__hint.__ptr_)->__right_);
return _PairT(__next.__ptr_, __next.__ptr_->__left_);
}
// *next(__hint) <= __v
return __find_equal(__v);
}

// else __v == *__hint
__dummy = static_cast<__node_base_pointer>(__hint.__ptr_);
return _PairT(__hint.__ptr_, __dummy);
}

_LIBCPP_HIDE_FROM_ABI void __copy_assign_alloc(const __tree& __t) {
__copy_assign_alloc(__t, integral_constant<bool, __node_traits::propagate_on_container_copy_assignment::value>());
Expand Down Expand Up @@ -1647,94 +1721,6 @@ typename __tree<_Tp, _Compare, _Allocator>::__node_base_pointer& __tree<_Tp, _Co
return __find_leaf_low(__parent, __v);
}

// Find place to insert if __v doesn't exist
// Set __parent to parent of null leaf
// Return reference to null leaf
// If __v exists, set parent to node of __v and return reference to node of __v
template <class _Tp, class _Compare, class _Allocator>
template <class _Key>
typename __tree<_Tp, _Compare, _Allocator>::__node_base_pointer&
__tree<_Tp, _Compare, _Allocator>::__find_equal(__end_node_pointer& __parent, const _Key& __v) {
__node_pointer __nd = __root();
__node_base_pointer* __nd_ptr = __root_ptr();
if (__nd != nullptr) {
while (true) {
if (value_comp()(__v, __nd->__value_)) {
if (__nd->__left_ != nullptr) {
__nd_ptr = std::addressof(__nd->__left_);
__nd = static_cast<__node_pointer>(__nd->__left_);
} else {
__parent = static_cast<__end_node_pointer>(__nd);
return __parent->__left_;
}
} else if (value_comp()(__nd->__value_, __v)) {
if (__nd->__right_ != nullptr) {
__nd_ptr = std::addressof(__nd->__right_);
__nd = static_cast<__node_pointer>(__nd->__right_);
} else {
__parent = static_cast<__end_node_pointer>(__nd);
return __nd->__right_;
}
} else {
__parent = static_cast<__end_node_pointer>(__nd);
return *__nd_ptr;
}
}
}
__parent = __end_node();
return __parent->__left_;
}

// Find place to insert if __v doesn't exist
// First check prior to __hint.
// Next check after __hint.
// Next do O(log N) search.
// Set __parent to parent of null leaf
// Return reference to null leaf
// If __v exists, set parent to node of __v and return reference to node of __v
template <class _Tp, class _Compare, class _Allocator>
template <class _Key>
typename __tree<_Tp, _Compare, _Allocator>::__node_base_pointer& __tree<_Tp, _Compare, _Allocator>::__find_equal(
const_iterator __hint, __end_node_pointer& __parent, __node_base_pointer& __dummy, const _Key& __v) {
if (__hint == end() || value_comp()(__v, *__hint)) // check before
{
// __v < *__hint
const_iterator __prior = __hint;
if (__prior == begin() || value_comp()(*--__prior, __v)) {
// *prev(__hint) < __v < *__hint
if (__hint.__ptr_->__left_ == nullptr) {
__parent = __hint.__ptr_;
return __parent->__left_;
} else {
__parent = __prior.__ptr_;
return static_cast<__node_base_pointer>(__prior.__ptr_)->__right_;
}
}
// __v <= *prev(__hint)
return __find_equal(__parent, __v);
} else if (value_comp()(*__hint, __v)) // check after
{
// *__hint < __v
const_iterator __next = std::next(__hint);
if (__next == end() || value_comp()(__v, *__next)) {
// *__hint < __v < *std::next(__hint)
if (__hint.__get_np()->__right_ == nullptr) {
__parent = __hint.__ptr_;
return static_cast<__node_base_pointer>(__hint.__ptr_)->__right_;
} else {
__parent = __next.__ptr_;
return __parent->__left_;
}
}
// *next(__hint) <= __v
return __find_equal(__parent, __v);
}
// else __v == *__hint
__parent = __hint.__ptr_;
__dummy = static_cast<__node_base_pointer>(__hint.__ptr_);
return __dummy;
}

template <class _Tp, class _Compare, class _Allocator>
void __tree<_Tp, _Compare, _Allocator>::__insert_node_at(
__end_node_pointer __parent, __node_base_pointer& __child, __node_base_pointer __new_node) _NOEXCEPT {
Expand All @@ -1753,10 +1739,9 @@ template <class _Tp, class _Compare, class _Allocator>
template <class _Key, class... _Args>
pair<typename __tree<_Tp, _Compare, _Allocator>::iterator, bool>
__tree<_Tp, _Compare, _Allocator>::__emplace_unique_key_args(_Key const& __k, _Args&&... __args) {
__end_node_pointer __parent;
__node_base_pointer& __child = __find_equal(__parent, __k);
__node_pointer __r = static_cast<__node_pointer>(__child);
bool __inserted = false;
auto [__parent, __child] = __find_equal(__k);
__node_pointer __r = static_cast<__node_pointer>(__child);
bool __inserted = false;
if (__child == nullptr) {
__node_holder __h = __construct_node(std::forward<_Args>(__args)...);
__insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get()));
Expand All @@ -1771,11 +1756,10 @@ template <class _Key, class... _Args>
pair<typename __tree<_Tp, _Compare, _Allocator>::iterator, bool>
__tree<_Tp, _Compare, _Allocator>::__emplace_hint_unique_key_args(
const_iterator __p, _Key const& __k, _Args&&... __args) {
__end_node_pointer __parent;
__node_base_pointer __dummy;
__node_base_pointer& __child = __find_equal(__p, __parent, __dummy, __k);
__node_pointer __r = static_cast<__node_pointer>(__child);
bool __inserted = false;
auto [__parent, __child] = __find_equal(__p, __dummy, __k);
__node_pointer __r = static_cast<__node_pointer>(__child);
bool __inserted = false;
if (__child == nullptr) {
__node_holder __h = __construct_node(std::forward<_Args>(__args)...);
__insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get()));
Expand All @@ -1800,11 +1784,10 @@ template <class _Tp, class _Compare, class _Allocator>
template <class... _Args>
pair<typename __tree<_Tp, _Compare, _Allocator>::iterator, bool>
__tree<_Tp, _Compare, _Allocator>::__emplace_unique_impl(_Args&&... __args) {
__node_holder __h = __construct_node(std::forward<_Args>(__args)...);
__end_node_pointer __parent;
__node_base_pointer& __child = __find_equal(__parent, __h->__value_);
__node_pointer __r = static_cast<__node_pointer>(__child);
bool __inserted = false;
__node_holder __h = __construct_node(std::forward<_Args>(__args)...);
auto [__parent, __child] = __find_equal(__h->__value_);
__node_pointer __r = static_cast<__node_pointer>(__child);
bool __inserted = false;
if (__child == nullptr) {
__insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get()));
__r = __h.release();
Expand All @@ -1818,10 +1801,9 @@ template <class... _Args>
typename __tree<_Tp, _Compare, _Allocator>::iterator
__tree<_Tp, _Compare, _Allocator>::__emplace_hint_unique_impl(const_iterator __p, _Args&&... __args) {
__node_holder __h = __construct_node(std::forward<_Args>(__args)...);
__end_node_pointer __parent;
__node_base_pointer __dummy;
__node_base_pointer& __child = __find_equal(__p, __parent, __dummy, __h->__value_);
__node_pointer __r = static_cast<__node_pointer>(__child);
auto [__parent, __child] = __find_equal(__p, __dummy, __h->__value_);
__node_pointer __r = static_cast<__node_pointer>(__child);
if (__child == nullptr) {
__insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get()));
__r = __h.release();
Expand Down Expand Up @@ -1854,10 +1836,9 @@ __tree<_Tp, _Compare, _Allocator>::__emplace_hint_multi(const_iterator __p, _Arg
template <class _Tp, class _Compare, class _Allocator>
pair<typename __tree<_Tp, _Compare, _Allocator>::iterator, bool>
__tree<_Tp, _Compare, _Allocator>::__node_assign_unique(const value_type& __v, __node_pointer __nd) {
__end_node_pointer __parent;
__node_base_pointer& __child = __find_equal(__parent, __v);
__node_pointer __r = static_cast<__node_pointer>(__child);
bool __inserted = false;
auto [__parent, __child] = __find_equal(__v);
__node_pointer __r = static_cast<__node_pointer>(__child);
bool __inserted = false;
if (__child == nullptr) {
__assign_value(__nd->__value_, __v);
__insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__nd));
Expand Down Expand Up @@ -1906,8 +1887,7 @@ __tree<_Tp, _Compare, _Allocator>::__node_handle_insert_unique(_NodeHandle&& __n
return _InsertReturnType{end(), false, _NodeHandle()};

__node_pointer __ptr = __nh.__ptr_;
__end_node_pointer __parent;
__node_base_pointer& __child = __find_equal(__parent, __ptr->__value_);
auto [__parent, __child] = __find_equal(__ptr->__value_);
if (__child != nullptr)
return _InsertReturnType{iterator(static_cast<__node_pointer>(__child)), false, std::move(__nh)};

Expand All @@ -1924,10 +1904,9 @@ __tree<_Tp, _Compare, _Allocator>::__node_handle_insert_unique(const_iterator __
return end();

__node_pointer __ptr = __nh.__ptr_;
__end_node_pointer __parent;
__node_base_pointer __dummy;
__node_base_pointer& __child = __find_equal(__hint, __parent, __dummy, __ptr->__value_);
__node_pointer __r = static_cast<__node_pointer>(__child);
auto [__parent, __child] = __find_equal(__hint, __dummy, __ptr->__value_);
__node_pointer __r = static_cast<__node_pointer>(__child);
if (__child == nullptr) {
__insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__ptr));
__r = __ptr;
Expand Down Expand Up @@ -1960,8 +1939,7 @@ _LIBCPP_HIDE_FROM_ABI void __tree<_Tp, _Compare, _Allocator>::__node_handle_merg

for (typename _Tree::iterator __i = __source.begin(); __i != __source.end();) {
__node_pointer __src_ptr = __i.__get_np();
__end_node_pointer __parent;
__node_base_pointer& __child = __find_equal(__parent, __src_ptr->__value_);
auto [__parent, __child] = __find_equal(__src_ptr->__value_);
++__i;
if (__child != nullptr)
continue;
Expand Down
11 changes: 4 additions & 7 deletions libcxx/include/map
Original file line number Diff line number Diff line change
Expand Up @@ -1428,9 +1428,8 @@ map<_Key, _Tp, _Compare, _Allocator>::__construct_node_with_key(const key_type&

template <class _Key, class _Tp, class _Compare, class _Allocator>
_Tp& map<_Key, _Tp, _Compare, _Allocator>::operator[](const key_type& __k) {
__parent_pointer __parent;
__node_base_pointer& __child = __tree_.__find_equal(__parent, __k);
__node_pointer __r = static_cast<__node_pointer>(__child);
auto [__parent, __child] = __tree_.__find_equal(__k);
__node_pointer __r = static_cast<__node_pointer>(__child);
if (__child == nullptr) {
__node_holder __h = __construct_node_with_key(__k);
__tree_.__insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get()));
Expand All @@ -1443,17 +1442,15 @@ _Tp& map<_Key, _Tp, _Compare, _Allocator>::operator[](const key_type& __k) {

template <class _Key, class _Tp, class _Compare, class _Allocator>
_Tp& map<_Key, _Tp, _Compare, _Allocator>::at(const key_type& __k) {
__parent_pointer __parent;
__node_base_pointer& __child = __tree_.__find_equal(__parent, __k);
auto [_, __child] = __tree_.__find_equal(__k);
if (__child == nullptr)
std::__throw_out_of_range("map::at: key not found");
return static_cast<__node_pointer>(__child)->__value_.second;
}

template <class _Key, class _Tp, class _Compare, class _Allocator>
const _Tp& map<_Key, _Tp, _Compare, _Allocator>::at(const key_type& __k) const {
__parent_pointer __parent;
__node_base_pointer __child = __tree_.__find_equal(__parent, __k);
auto [_, __child] = __tree_.__find_equal(__k);
if (__child == nullptr)
std::__throw_out_of_range("map::at: key not found");
return static_cast<__node_pointer>(__child)->__value_.second;
Expand Down
Loading