2424#include < thrust/iterator/iterator_adaptor.h>
2525#include < thrust/iterator/iterator_traversal_tags.h>
2626
27+ #include < cuda/std/__memory/addressof.h>
28+ #include < cuda/std/__memory/pointer_traits.h>
29+ #include < cuda/std/__type_traits/enable_if.h>
2730#include < cuda/std/__type_traits/is_comparable.h>
2831#include < cuda/std/__type_traits/is_reference.h>
2932#include < cuda/std/__type_traits/is_same.h>
3033#include < cuda/std/__type_traits/is_void.h>
3134#include < cuda/std/__type_traits/remove_cv.h>
3235#include < cuda/std/__type_traits/remove_cvref.h>
36+ #include < cuda/std/__type_traits/remove_pointer.h>
3337#include < cuda/std/__type_traits/type_identity.h>
3438#include < cuda/std/cstddef>
3539
@@ -197,7 +201,7 @@ class pointer : public detail::pointer_base<Element, Tag, Reference, Derived>::t
197201 */
198202 template <typename OtherPointer, detail::enable_if_pointer_is_convertible_t <OtherPointer, pointer>* = nullptr >
199203 _CCCL_HOST_DEVICE pointer (const OtherPointer& other)
200- : super_t(detail::pointer_traits<OtherPointer>::get (other))
204+ : super_t(static_cast <Element*>(::cuda::std::to_address (other) ))
201205 {}
202206
203207#ifndef _CCCL_DOXYGEN_INVOKED // Doxygen cannot handle this constructor and creates a duplicate ID with the ctor above
@@ -206,7 +210,7 @@ class pointer : public detail::pointer_base<Element, Tag, Reference, Derived>::t
206210 template <typename OtherPointer,
207211 detail::enable_if_void_pointer_is_system_convertible_t <OtherPointer, pointer>* = nullptr >
208212 _CCCL_HOST_DEVICE explicit pointer (const OtherPointer& other)
209- : super_t(static_cast <Element*>(detail::pointer_traits<OtherPointer>::get (other)))
213+ : super_t(static_cast <Element*>(::cuda::std::to_address (other)))
210214 {}
211215#endif
212216
@@ -232,7 +236,7 @@ class pointer : public detail::pointer_base<Element, Tag, Reference, Derived>::t
232236 _CCCL_HOST_DEVICE detail::enable_if_pointer_is_convertible_t <OtherPointer, pointer, derived_type&>
233237 operator =(const OtherPointer& other)
234238 {
235- super_t::base_reference () = detail::pointer_traits<OtherPointer>:: get (other);
239+ super_t::base_reference () = :: cuda::std::to_address (other);
236240 return static_cast <derived_type&>(*this );
237241 }
238242
@@ -257,10 +261,11 @@ class pointer : public detail::pointer_base<Element, Tag, Reference, Derived>::t
257261 return bool (get ());
258262 }
259263
260- _CCCL_HOST_DEVICE static derived_type
261- pointer_to (typename detail::pointer_traits_detail::pointer_to_param<Element>::type r)
264+ template <class T ,
265+ ::cuda::std::enable_if_t <(::cuda::std::is_void_v<Element> || ::cuda::std::is_same_v<T, Element>), int > = 0 >
266+ _CCCL_HOST_DEVICE static derived_type pointer_to (T& r)
262267 {
263- return detail::pointer_traits <derived_type>:: pointer_to (r );
268+ return static_cast <derived_type>(:: cuda::std::addressof (r) );
264269 }
265270
266271#if !_CCCL_COMPILER(NVRTC)
@@ -388,3 +393,39 @@ class pointer : public detail::pointer_base<Element, Tag, Reference, Derived>::t
388393 */
389394
390395THRUST_NAMESPACE_END
396+
397+ _CCCL_BEGIN_NAMESPACE_CUDA_STD
398+
399+ // Specialize pointer traits for everything that has the raw_pointer alias
400+ template <typename Pointer>
401+ struct pointer_traits <Pointer, void_t <typename Pointer::raw_pointer>>
402+ {
403+ using pointer = Pointer;
404+ using element_type = remove_pointer_t <typename Pointer::raw_pointer>;
405+ using difference_type = ptrdiff_t ;
406+
407+ template <typename U>
408+ struct rebind
409+ {
410+ using other = typename THRUST_NS_QUALIFIER::detail::rebind_pointer<pointer, U>::type;
411+ };
412+
413+ // Backwards compatability with thrust::detail::pointer_traits
414+ using raw_pointer = typename pointer::raw_pointer;
415+
416+ // Thrust historically provided a non-standard pointer_to for pointer<void>
417+ template <class T , enable_if_t <(is_void_v<element_type> || is_same_v<T, element_type>), int > = 0 >
418+ [[nodiscard]] _CCCL_API inline static pointer pointer_to (T& r) noexcept (noexcept (::cuda::std::addressof(r)))
419+ {
420+ return static_cast <element_type*>(::cuda::std::addressof (r));
421+ }
422+
423+ // ! @brief Retrieve the address of the element pointed at by an thrust pointer
424+ // ! @param iter A thrust::pointer
425+ // ! @return A pointer to the element pointed to by the thrust pointer
426+ [[nodiscard]] _CCCL_API static constexpr raw_pointer to_address (const pointer iter) noexcept
427+ {
428+ return iter.get ();
429+ }
430+ };
431+ _CCCL_END_NAMESPACE_CUDA_STD
0 commit comments