3939
4040#include < cuda/std/__memory/addressof.h>
4141#include < cuda/std/__memory/pointer_traits.h>
42+ #include < cuda/std/__type_traits/add_lvalue_reference.h>
4243#include < cuda/std/__type_traits/enable_if.h>
4344#include < cuda/std/__type_traits/is_comparable.h>
4445#include < cuda/std/__type_traits/is_reference.h>
@@ -402,6 +403,83 @@ class pointer : public detail::pointer_base<Element, Tag, Reference, Derived>::t
402403 operator <=(pointer const & lhs, pointer<OtherElement, OtherTag, OtherReference, OtherDerived> const & rhs) = delete ;
403404};
404405
406+ namespace detail
407+ {
408+ template <typename Ptr, typename T>
409+ struct thrust_pointer_rebind ;
410+
411+ template <typename T, typename U>
412+ struct thrust_pointer_rebind <T*, U>
413+ {
414+ using type = U*;
415+ };
416+
417+ // Rebind generic fancy pointers.
418+ template <template <typename , typename ...> class Ptr , typename OldT, typename ... Tail, typename T>
419+ struct thrust_pointer_rebind <Ptr<OldT, Tail...>, T>
420+ {
421+ using type = Ptr<T, Tail...>;
422+ };
423+
424+ // Rebind `thrust::pointer`-like things with `thrust::reference`-like references.
425+ template <template <typename , typename , typename , typename ...> class Ptr ,
426+ typename OldT,
427+ typename Tag,
428+ template <typename ...> class Ref ,
429+ typename ... RefTail,
430+ typename ... PtrTail,
431+ typename T>
432+ struct thrust_pointer_rebind <Ptr<OldT, Tag, Ref<OldT, RefTail...>, PtrTail...>, T>
433+ {
434+ // static_assert(is_same<OldT, Tag>::value, "0");
435+ using type = Ptr<T, Tag, Ref<T, RefTail...>, PtrTail...>;
436+ };
437+
438+ // Rebind `thrust::pointer`-like things with `thrust::reference`-like references
439+ // and templated derived types.
440+ template <template <typename , typename , typename , typename ...> class Ptr ,
441+ typename OldT,
442+ typename Tag,
443+ template <typename ...> class Ref ,
444+ typename ... RefTail,
445+ template <typename ...> class DerivedPtr ,
446+ typename ... DerivedPtrTail,
447+ typename T>
448+ struct thrust_pointer_rebind <Ptr<OldT, Tag, Ref<OldT, RefTail...>, DerivedPtr<OldT, DerivedPtrTail...>>, T>
449+ {
450+ // static_assert(::cuda::std::is_same<OldT, Tag>::value, "1");
451+ using type = Ptr<T, Tag, Ref<T, RefTail...>, DerivedPtr<T, DerivedPtrTail...>>;
452+ };
453+
454+ // Rebind `thrust::pointer`-like things with native reference types.
455+ template <template <typename , typename , typename , typename ...> class Ptr ,
456+ typename OldT,
457+ typename Tag,
458+ typename ... PtrTail,
459+ typename T>
460+ struct thrust_pointer_rebind <Ptr<OldT, Tag, ::cuda::std::add_lvalue_reference_t <OldT>, PtrTail...>, T>
461+ {
462+ // static_assert(::cuda::std::is_same<OldT, Tag>::value, "2");
463+ using type = Ptr<T, Tag, ::cuda::std::add_lvalue_reference_t <T>, PtrTail...>;
464+ };
465+
466+ // Rebind `thrust::pointer`-like things with native reference types and templated
467+ // derived types.
468+ template <template <typename , typename , typename , typename ...> class Ptr ,
469+ typename OldT,
470+ typename Tag,
471+ template <typename ...> class DerivedPtr ,
472+ typename ... DerivedPtrTail,
473+ typename T>
474+ struct thrust_pointer_rebind <
475+ Ptr<OldT, Tag, ::cuda::std::add_lvalue_reference_t <OldT>, DerivedPtr<OldT, DerivedPtrTail...>>,
476+ T>
477+ {
478+ // static_assert(is_same<OldT, Tag>::value, "3");
479+ using type = Ptr<T, Tag, ::cuda::std::add_lvalue_reference_t <T>, DerivedPtr<T, DerivedPtrTail...>>;
480+ };
481+ } // namespace detail
482+
405483/* ! \} // memory_management
406484 */
407485
@@ -420,7 +498,7 @@ struct pointer_traits<Pointer, void_t<typename Pointer::raw_pointer>>
420498 template <typename U>
421499 struct rebind
422500 {
423- using other = typename THRUST_NS_QUALIFIER::detail::rebind_pointer <pointer, U>::type;
501+ using other = typename THRUST_NS_QUALIFIER::detail::thrust_pointer_rebind <pointer, U>::type;
424502 };
425503
426504 // Backwards compatability with thrust::detail::pointer_traits
@@ -440,5 +518,11 @@ struct pointer_traits<Pointer, void_t<typename Pointer::raw_pointer>>
440518 {
441519 return iter.get ();
442520 }
521+
522+ // ! For backwards compatability with old pointer traits
523+ [[nodiscard]] _CCCL_API static constexpr raw_pointer get (const pointer iter) noexcept
524+ {
525+ return iter.get ();
526+ }
443527};
444528_CCCL_END_NAMESPACE_CUDA_STD
0 commit comments