diff --git a/libcxx/include/__tree b/libcxx/include/__tree index 3dd5ae585e1db..2e0e4488f28a5 100644 --- a/libcxx/include/__tree +++ b/libcxx/include/__tree @@ -1086,15 +1086,89 @@ public: // FIXME: Make this function const qualified. Unfortunately doing so // breaks existing code which uses non-const callable comparators. + + // Find place to insert if __v doesn't exist + // Set __parent to parent of null leaf + // Return reference to null leaf + // If __v exists, set parent to node of __v and return reference to node of __v template - _LIBCPP_HIDE_FROM_ABI __node_base_pointer& __find_equal(__end_node_pointer& __parent, const _Key& __v); + _LIBCPP_HIDE_FROM_ABI pair<__end_node_pointer, __node_base_pointer&> __find_equal(const _Key& __v) { + using _PairT = pair<__end_node_pointer, __node_base_pointer&>; + + __node_pointer __node = __root(); + + if (__node == nullptr) { + auto __end = __end_node(); + return _PairT(__end, __end->__left_); + } + + __node_base_pointer* __node_ptr = __root_ptr(); + while (true) { + if (value_comp()(__v, __node->__value_)) { + if (__node->__left_ == nullptr) + return _PairT(static_cast<__end_node_pointer>(__node), __node->__left_); + + __node_ptr = std::addressof(__node->__left_); + __node = static_cast<__node_pointer>(__node->__left_); + } else if (value_comp()(__node->__value_, __v)) { + if (__node->__right_ == nullptr) + return _PairT(static_cast<__end_node_pointer>(__node), __node->__right_); + + __node_ptr = std::addressof(__node->__right_); + __node = static_cast<__node_pointer>(__node->__right_); + } else { + return _PairT(static_cast<__end_node_pointer>(__node), *__node_ptr); + } + } + } + template - _LIBCPP_HIDE_FROM_ABI __node_base_pointer& __find_equal(__end_node_pointer& __parent, const _Key& __v) const { - return const_cast<__tree*>(this)->__find_equal(__parent, __v); + _LIBCPP_HIDE_FROM_ABI pair<__end_node_pointer, __node_base_pointer&> __find_equal(const _Key& __v) const { + return const_cast<__tree*>(this)->__find_equal(__v); } + + // Find place to insert if __v doesn't exist + // First check prior to __hint. + // Next check after __hint. + // Next do O(log N) search. + // Set __parent to parent of null leaf + // Return reference to null leaf + // If __v exists, set parent to node of __v and return reference to node of __v template - _LIBCPP_HIDE_FROM_ABI __node_base_pointer& - __find_equal(const_iterator __hint, __end_node_pointer& __parent, __node_base_pointer& __dummy, const _Key& __v); + _LIBCPP_HIDE_FROM_ABI pair<__end_node_pointer, __node_base_pointer&> + __find_equal(const_iterator __hint, __node_base_pointer& __dummy, const _Key& __v) { + using _PairT = pair<__end_node_pointer, __node_base_pointer&>; + + if (__hint == end() || value_comp()(__v, *__hint)) { // check before + // __v < *__hint + const_iterator __prior = __hint; + if (__prior == begin() || value_comp()(*--__prior, __v)) { + // *prev(__hint) < __v < *__hint + if (__hint.__ptr_->__left_ == nullptr) + return _PairT(__hint.__ptr_, __hint.__ptr_->__left_); + return _PairT(__prior.__ptr_, static_cast<__node_pointer>(__prior.__ptr_)->__right_); + } + // __v <= *prev(__hint) + return __find_equal(__v); + } + + if (value_comp()(*__hint, __v)) { // check after + // *__hint < __v + const_iterator __next = std::next(__hint); + if (__next == end() || value_comp()(__v, *__next)) { + // *__hint < __v < *std::next(__hint) + if (__hint.__get_np()->__right_ == nullptr) + return _PairT(__hint.__ptr_, static_cast<__node_pointer>(__hint.__ptr_)->__right_); + return _PairT(__next.__ptr_, __next.__ptr_->__left_); + } + // *next(__hint) <= __v + return __find_equal(__v); + } + + // else __v == *__hint + __dummy = static_cast<__node_base_pointer>(__hint.__ptr_); + return _PairT(__hint.__ptr_, __dummy); + } _LIBCPP_HIDE_FROM_ABI void __copy_assign_alloc(const __tree& __t) { __copy_assign_alloc(__t, integral_constant()); @@ -1647,94 +1721,6 @@ typename __tree<_Tp, _Compare, _Allocator>::__node_base_pointer& __tree<_Tp, _Co return __find_leaf_low(__parent, __v); } -// Find place to insert if __v doesn't exist -// Set __parent to parent of null leaf -// Return reference to null leaf -// If __v exists, set parent to node of __v and return reference to node of __v -template -template -typename __tree<_Tp, _Compare, _Allocator>::__node_base_pointer& -__tree<_Tp, _Compare, _Allocator>::__find_equal(__end_node_pointer& __parent, const _Key& __v) { - __node_pointer __nd = __root(); - __node_base_pointer* __nd_ptr = __root_ptr(); - if (__nd != nullptr) { - while (true) { - if (value_comp()(__v, __nd->__value_)) { - if (__nd->__left_ != nullptr) { - __nd_ptr = std::addressof(__nd->__left_); - __nd = static_cast<__node_pointer>(__nd->__left_); - } else { - __parent = static_cast<__end_node_pointer>(__nd); - return __parent->__left_; - } - } else if (value_comp()(__nd->__value_, __v)) { - if (__nd->__right_ != nullptr) { - __nd_ptr = std::addressof(__nd->__right_); - __nd = static_cast<__node_pointer>(__nd->__right_); - } else { - __parent = static_cast<__end_node_pointer>(__nd); - return __nd->__right_; - } - } else { - __parent = static_cast<__end_node_pointer>(__nd); - return *__nd_ptr; - } - } - } - __parent = __end_node(); - return __parent->__left_; -} - -// Find place to insert if __v doesn't exist -// First check prior to __hint. -// Next check after __hint. -// Next do O(log N) search. -// Set __parent to parent of null leaf -// Return reference to null leaf -// If __v exists, set parent to node of __v and return reference to node of __v -template -template -typename __tree<_Tp, _Compare, _Allocator>::__node_base_pointer& __tree<_Tp, _Compare, _Allocator>::__find_equal( - const_iterator __hint, __end_node_pointer& __parent, __node_base_pointer& __dummy, const _Key& __v) { - if (__hint == end() || value_comp()(__v, *__hint)) // check before - { - // __v < *__hint - const_iterator __prior = __hint; - if (__prior == begin() || value_comp()(*--__prior, __v)) { - // *prev(__hint) < __v < *__hint - if (__hint.__ptr_->__left_ == nullptr) { - __parent = __hint.__ptr_; - return __parent->__left_; - } else { - __parent = __prior.__ptr_; - return static_cast<__node_base_pointer>(__prior.__ptr_)->__right_; - } - } - // __v <= *prev(__hint) - return __find_equal(__parent, __v); - } else if (value_comp()(*__hint, __v)) // check after - { - // *__hint < __v - const_iterator __next = std::next(__hint); - if (__next == end() || value_comp()(__v, *__next)) { - // *__hint < __v < *std::next(__hint) - if (__hint.__get_np()->__right_ == nullptr) { - __parent = __hint.__ptr_; - return static_cast<__node_base_pointer>(__hint.__ptr_)->__right_; - } else { - __parent = __next.__ptr_; - return __parent->__left_; - } - } - // *next(__hint) <= __v - return __find_equal(__parent, __v); - } - // else __v == *__hint - __parent = __hint.__ptr_; - __dummy = static_cast<__node_base_pointer>(__hint.__ptr_); - return __dummy; -} - template void __tree<_Tp, _Compare, _Allocator>::__insert_node_at( __end_node_pointer __parent, __node_base_pointer& __child, __node_base_pointer __new_node) _NOEXCEPT { @@ -1753,10 +1739,9 @@ template template pair::iterator, bool> __tree<_Tp, _Compare, _Allocator>::__emplace_unique_key_args(_Key const& __k, _Args&&... __args) { - __end_node_pointer __parent; - __node_base_pointer& __child = __find_equal(__parent, __k); - __node_pointer __r = static_cast<__node_pointer>(__child); - bool __inserted = false; + auto [__parent, __child] = __find_equal(__k); + __node_pointer __r = static_cast<__node_pointer>(__child); + bool __inserted = false; if (__child == nullptr) { __node_holder __h = __construct_node(std::forward<_Args>(__args)...); __insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get())); @@ -1771,11 +1756,10 @@ template pair::iterator, bool> __tree<_Tp, _Compare, _Allocator>::__emplace_hint_unique_key_args( const_iterator __p, _Key const& __k, _Args&&... __args) { - __end_node_pointer __parent; __node_base_pointer __dummy; - __node_base_pointer& __child = __find_equal(__p, __parent, __dummy, __k); - __node_pointer __r = static_cast<__node_pointer>(__child); - bool __inserted = false; + auto [__parent, __child] = __find_equal(__p, __dummy, __k); + __node_pointer __r = static_cast<__node_pointer>(__child); + bool __inserted = false; if (__child == nullptr) { __node_holder __h = __construct_node(std::forward<_Args>(__args)...); __insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get())); @@ -1800,11 +1784,10 @@ template template pair::iterator, bool> __tree<_Tp, _Compare, _Allocator>::__emplace_unique_impl(_Args&&... __args) { - __node_holder __h = __construct_node(std::forward<_Args>(__args)...); - __end_node_pointer __parent; - __node_base_pointer& __child = __find_equal(__parent, __h->__value_); - __node_pointer __r = static_cast<__node_pointer>(__child); - bool __inserted = false; + __node_holder __h = __construct_node(std::forward<_Args>(__args)...); + auto [__parent, __child] = __find_equal(__h->__value_); + __node_pointer __r = static_cast<__node_pointer>(__child); + bool __inserted = false; if (__child == nullptr) { __insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get())); __r = __h.release(); @@ -1818,10 +1801,9 @@ template typename __tree<_Tp, _Compare, _Allocator>::iterator __tree<_Tp, _Compare, _Allocator>::__emplace_hint_unique_impl(const_iterator __p, _Args&&... __args) { __node_holder __h = __construct_node(std::forward<_Args>(__args)...); - __end_node_pointer __parent; __node_base_pointer __dummy; - __node_base_pointer& __child = __find_equal(__p, __parent, __dummy, __h->__value_); - __node_pointer __r = static_cast<__node_pointer>(__child); + auto [__parent, __child] = __find_equal(__p, __dummy, __h->__value_); + __node_pointer __r = static_cast<__node_pointer>(__child); if (__child == nullptr) { __insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get())); __r = __h.release(); @@ -1854,10 +1836,9 @@ __tree<_Tp, _Compare, _Allocator>::__emplace_hint_multi(const_iterator __p, _Arg template pair::iterator, bool> __tree<_Tp, _Compare, _Allocator>::__node_assign_unique(const value_type& __v, __node_pointer __nd) { - __end_node_pointer __parent; - __node_base_pointer& __child = __find_equal(__parent, __v); - __node_pointer __r = static_cast<__node_pointer>(__child); - bool __inserted = false; + auto [__parent, __child] = __find_equal(__v); + __node_pointer __r = static_cast<__node_pointer>(__child); + bool __inserted = false; if (__child == nullptr) { __assign_value(__nd->__value_, __v); __insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__nd)); @@ -1906,8 +1887,7 @@ __tree<_Tp, _Compare, _Allocator>::__node_handle_insert_unique(_NodeHandle&& __n return _InsertReturnType{end(), false, _NodeHandle()}; __node_pointer __ptr = __nh.__ptr_; - __end_node_pointer __parent; - __node_base_pointer& __child = __find_equal(__parent, __ptr->__value_); + auto [__parent, __child] = __find_equal(__ptr->__value_); if (__child != nullptr) return _InsertReturnType{iterator(static_cast<__node_pointer>(__child)), false, std::move(__nh)}; @@ -1924,10 +1904,9 @@ __tree<_Tp, _Compare, _Allocator>::__node_handle_insert_unique(const_iterator __ return end(); __node_pointer __ptr = __nh.__ptr_; - __end_node_pointer __parent; __node_base_pointer __dummy; - __node_base_pointer& __child = __find_equal(__hint, __parent, __dummy, __ptr->__value_); - __node_pointer __r = static_cast<__node_pointer>(__child); + auto [__parent, __child] = __find_equal(__hint, __dummy, __ptr->__value_); + __node_pointer __r = static_cast<__node_pointer>(__child); if (__child == nullptr) { __insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__ptr)); __r = __ptr; @@ -1960,8 +1939,7 @@ _LIBCPP_HIDE_FROM_ABI void __tree<_Tp, _Compare, _Allocator>::__node_handle_merg for (typename _Tree::iterator __i = __source.begin(); __i != __source.end();) { __node_pointer __src_ptr = __i.__get_np(); - __end_node_pointer __parent; - __node_base_pointer& __child = __find_equal(__parent, __src_ptr->__value_); + auto [__parent, __child] = __find_equal(__src_ptr->__value_); ++__i; if (__child != nullptr) continue; diff --git a/libcxx/include/map b/libcxx/include/map index 6378218945ca0..dd8fa0e0c7867 100644 --- a/libcxx/include/map +++ b/libcxx/include/map @@ -1428,9 +1428,8 @@ map<_Key, _Tp, _Compare, _Allocator>::__construct_node_with_key(const key_type& template _Tp& map<_Key, _Tp, _Compare, _Allocator>::operator[](const key_type& __k) { - __parent_pointer __parent; - __node_base_pointer& __child = __tree_.__find_equal(__parent, __k); - __node_pointer __r = static_cast<__node_pointer>(__child); + auto [__parent, __child] = __tree_.__find_equal(__k); + __node_pointer __r = static_cast<__node_pointer>(__child); if (__child == nullptr) { __node_holder __h = __construct_node_with_key(__k); __tree_.__insert_node_at(__parent, __child, static_cast<__node_base_pointer>(__h.get())); @@ -1443,8 +1442,7 @@ _Tp& map<_Key, _Tp, _Compare, _Allocator>::operator[](const key_type& __k) { template _Tp& map<_Key, _Tp, _Compare, _Allocator>::at(const key_type& __k) { - __parent_pointer __parent; - __node_base_pointer& __child = __tree_.__find_equal(__parent, __k); + auto [_, __child] = __tree_.__find_equal(__k); if (__child == nullptr) std::__throw_out_of_range("map::at: key not found"); return static_cast<__node_pointer>(__child)->__value_.second; @@ -1452,8 +1450,7 @@ _Tp& map<_Key, _Tp, _Compare, _Allocator>::at(const key_type& __k) { template const _Tp& map<_Key, _Tp, _Compare, _Allocator>::at(const key_type& __k) const { - __parent_pointer __parent; - __node_base_pointer __child = __tree_.__find_equal(__parent, __k); + auto [_, __child] = __tree_.__find_equal(__k); if (__child == nullptr) std::__throw_out_of_range("map::at: key not found"); return static_cast<__node_pointer>(__child)->__value_.second;