Skip to content

Commit 34b7a91

Browse files
committed
Replace rebind_pointer_traits
1 parent 3fbd348 commit 34b7a91

File tree

1 file changed

+85
-1
lines changed

1 file changed

+85
-1
lines changed

thrust/thrust/detail/pointer.h

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
#include <cuda/std/__memory/addressof.h>
4141
#include <cuda/std/__memory/pointer_traits.h>
42+
#include <cuda/std/__type_traits/add_lvalue_reference.h>
4243
#include <cuda/std/__type_traits/enable_if.h>
4344
#include <cuda/std/__type_traits/is_comparable.h>
4445
#include <cuda/std/__type_traits/is_reference.h>
@@ -402,6 +403,83 @@ class pointer : public detail::pointer_base<Element, Tag, Reference, Derived>::t
402403
operator<=(pointer const& lhs, pointer<OtherElement, OtherTag, OtherReference, OtherDerived> const& rhs) = delete;
403404
};
404405

406+
namespace detail
407+
{
408+
template <typename Ptr, typename T>
409+
struct thrust_pointer_rebind;
410+
411+
template <typename T, typename U>
412+
struct thrust_pointer_rebind<T*, U>
413+
{
414+
using type = U*;
415+
};
416+
417+
// Rebind generic fancy pointers.
418+
template <template <typename, typename...> class Ptr, typename OldT, typename... Tail, typename T>
419+
struct thrust_pointer_rebind<Ptr<OldT, Tail...>, T>
420+
{
421+
using type = Ptr<T, Tail...>;
422+
};
423+
424+
// Rebind `thrust::pointer`-like things with `thrust::reference`-like references.
425+
template <template <typename, typename, typename, typename...> class Ptr,
426+
typename OldT,
427+
typename Tag,
428+
template <typename...> class Ref,
429+
typename... RefTail,
430+
typename... PtrTail,
431+
typename T>
432+
struct thrust_pointer_rebind<Ptr<OldT, Tag, Ref<OldT, RefTail...>, PtrTail...>, T>
433+
{
434+
// static_assert(is_same<OldT, Tag>::value, "0");
435+
using type = Ptr<T, Tag, Ref<T, RefTail...>, PtrTail...>;
436+
};
437+
438+
// Rebind `thrust::pointer`-like things with `thrust::reference`-like references
439+
// and templated derived types.
440+
template <template <typename, typename, typename, typename...> class Ptr,
441+
typename OldT,
442+
typename Tag,
443+
template <typename...> class Ref,
444+
typename... RefTail,
445+
template <typename...> class DerivedPtr,
446+
typename... DerivedPtrTail,
447+
typename T>
448+
struct thrust_pointer_rebind<Ptr<OldT, Tag, Ref<OldT, RefTail...>, DerivedPtr<OldT, DerivedPtrTail...>>, T>
449+
{
450+
// static_assert(::cuda::std::is_same<OldT, Tag>::value, "1");
451+
using type = Ptr<T, Tag, Ref<T, RefTail...>, DerivedPtr<T, DerivedPtrTail...>>;
452+
};
453+
454+
// Rebind `thrust::pointer`-like things with native reference types.
455+
template <template <typename, typename, typename, typename...> class Ptr,
456+
typename OldT,
457+
typename Tag,
458+
typename... PtrTail,
459+
typename T>
460+
struct thrust_pointer_rebind<Ptr<OldT, Tag, ::cuda::std::add_lvalue_reference_t<OldT>, PtrTail...>, T>
461+
{
462+
// static_assert(::cuda::std::is_same<OldT, Tag>::value, "2");
463+
using type = Ptr<T, Tag, ::cuda::std::add_lvalue_reference_t<T>, PtrTail...>;
464+
};
465+
466+
// Rebind `thrust::pointer`-like things with native reference types and templated
467+
// derived types.
468+
template <template <typename, typename, typename, typename...> class Ptr,
469+
typename OldT,
470+
typename Tag,
471+
template <typename...> class DerivedPtr,
472+
typename... DerivedPtrTail,
473+
typename T>
474+
struct thrust_pointer_rebind<
475+
Ptr<OldT, Tag, ::cuda::std::add_lvalue_reference_t<OldT>, DerivedPtr<OldT, DerivedPtrTail...>>,
476+
T>
477+
{
478+
// static_assert(is_same<OldT, Tag>::value, "3");
479+
using type = Ptr<T, Tag, ::cuda::std::add_lvalue_reference_t<T>, DerivedPtr<T, DerivedPtrTail...>>;
480+
};
481+
} // namespace detail
482+
405483
/*! \} // memory_management
406484
*/
407485

@@ -420,7 +498,7 @@ struct pointer_traits<Pointer, void_t<typename Pointer::raw_pointer>>
420498
template <typename U>
421499
struct rebind
422500
{
423-
using other = typename THRUST_NS_QUALIFIER::detail::rebind_pointer<pointer, U>::type;
501+
using other = typename THRUST_NS_QUALIFIER::detail::thrust_pointer_rebind<pointer, U>::type;
424502
};
425503

426504
// Backwards compatability with thrust::detail::pointer_traits
@@ -440,5 +518,11 @@ struct pointer_traits<Pointer, void_t<typename Pointer::raw_pointer>>
440518
{
441519
return iter.get();
442520
}
521+
522+
//! For backwards compatability with old pointer traits
523+
[[nodiscard]] _CCCL_API static constexpr raw_pointer get(const pointer iter) noexcept
524+
{
525+
return iter.get();
526+
}
443527
};
444528
_CCCL_END_NAMESPACE_CUDA_STD

0 commit comments

Comments
 (0)