diff --git a/libcxx/include/__algorithm/equal.h b/libcxx/include/__algorithm/equal.h index bfc8f72f6eb19..90e586eff2d1d 100644 --- a/libcxx/include/__algorithm/equal.h +++ b/libcxx/include/__algorithm/equal.h @@ -17,6 +17,7 @@ #include <__functional/invoke.h> #include <__iterator/distance.h> #include <__iterator/iterator_traits.h> +#include <__iterator/segmented_iterator.h> #include <__string/constexpr_c_functions.h> #include <__type_traits/desugars_to.h> #include <__type_traits/enable_if.h> @@ -95,6 +96,81 @@ __equal_impl(_Tp* __first1, _Tp* __last1, _Up* __first2, _Up*, _Pred&, _Proj1&, return std::__constexpr_memcmp_equal(__first1, __first2, __element_count(__last1 - __first1)); } +template +_LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_segment( + _SegmentedIterator1 __first1, + _SegmentedIterator1 __last1, + _SegmentedIterator2 __first2, + _SegmentedIterator2 __last2, + _BinaryPredicate __pred, + _Proj1& __proj1, + _Proj2& __proj2) { + using _Traits1 = __segmented_iterator_traits<_SegmentedIterator1>; + using _Traits2 = __segmented_iterator_traits<_SegmentedIterator2>; + + auto __sfirst1 = _Traits1::__segment(__first1); + auto __slast1 = _Traits1::__segment(__last1); + + auto __sfirst2 = _Traits2::__segment(__first2); + auto __slast2 = _Traits2::__segment(__last2); + + // Both have only 1 segment + if (__sfirst1 == __slast1 && __sfirst2 == __slast2) + return std::__equal_impl( + std::__unwrap_iter(_Traits1::__local(__first1)), + std::__unwrap_iter(_Traits1::__local(__last1)), + std::__unwrap_iter(_Traits2::__local(__first2)), + std::__unwrap_iter(_Traits2::__local(__last2)), + __pred, + __proj1, + __proj2); + + { // We have more than one segment. Iterate over the first segment, since we might not start at the beginning + if (!std::__equal_impl( + std::__unwrap_iter(_Traits1::__local(__first1)), + std::__unwrap_iter(_Traits1::__end(__sfirst1)), + std::__unwrap_iter(_Traits2::__local(__first2)), + std::__unwrap_iter(_Traits2::__end(__sfirst2)), + __pred, + __proj1, + __proj2)) { + return false; + } + } + ++__sfirst1; + ++__sfirst2; + + // Iterate over the segments which are guaranteed to be completely in the range + while (__sfirst1 != __slast1 && __sfirst2 != __slast2) { + if (!std::__equal_impl( + std::__unwrap_iter(_Traits1::__begin(__sfirst1)), + std::__unwrap_iter(_Traits1::__end(__sfirst1)), + std::__unwrap_iter(_Traits2::__begin(__sfirst2)), + std::__unwrap_iter(_Traits2::__end(__sfirst2)), + __pred, + __proj1, + __proj2)) { + return false; + } + ++__sfirst1; + ++__sfirst2; + } + + // Iterate over the last segment + if (!std::__equal_impl( + std::__unwrap_iter(_Traits1::__begin(__sfirst1)), + std::__unwrap_iter(_Traits1::__local(__last1)), + std::__unwrap_iter(_Traits2::__begin(__sfirst2)), + std::__unwrap_iter(_Traits2::__local(__last2)), + __pred, + __proj1, + __proj2)) { + return false; + } + + return __sfirst1 == __slast1 && __sfirst2 == __slast2; +} + template _LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool equal(_InputIterator1 __first1, @@ -108,6 +184,11 @@ equal(_InputIterator1 __first1, return false; } __identity __proj; + + if constexpr (__is_segmented_iterator<_InputIterator1>::value && __is_segmented_iterator<_InputIterator2>::value) { + return std::__equal_segment(__first1, __last1, __first2, __last2, __pred, __proj, __proj); + } + return std::__equal_impl( std::__unwrap_iter(__first1), std::__unwrap_iter(__last1),