Skip to content

Commit 1a5c52d

Browse files
committed
Also specialize for device_ptr
1 parent cab82ae commit 1a5c52d

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

thrust/thrust/device_ptr.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,36 @@ THRUST_NAMESPACE_END
221221

222222
#include <thrust/detail/device_ptr.inl>
223223
#include <thrust/detail/raw_pointer_cast.h>
224+
225+
// Specialize pointer traits for everything that has the raw_pointer alias
226+
template <class T>
227+
struct ::cuda::std::pointer_traits<THRUST_NS_QUALIFIER::device_ptr<T>>
228+
{
229+
using pointer = THRUST_NS_QUALIFIER::device_ptr<T>;
230+
using element_type = T;
231+
using difference_type = ptrdiff_t;
232+
233+
template <typename U>
234+
struct rebind
235+
{
236+
using other = typename THRUST_NS_QUALIFIER::detail::rebind_pointer<pointer, U>::type;
237+
};
238+
239+
// Backwards compatability with thrust::detail::pointer_traits
240+
using raw_pointer = typename pointer::raw_pointer;
241+
242+
// Thrust historically provided a non-standard pointer_to for pointer<void>
243+
template <class U, ::cuda::std::enable_if_t<(::cuda::std::is_void_v<T> || ::cuda::std::is_same_v<U, T>), int> = 0>
244+
[[nodiscard]] _CCCL_API inline static pointer pointer_to(U& r) noexcept(noexcept(::cuda::std::addressof(r)))
245+
{
246+
return static_cast<element_type*>(::cuda::std::addressof(r));
247+
}
248+
249+
//! @brief Retrieve the address of the element pointed at by an thrust pointer
250+
//! @param iter A thrust::device_ptr
251+
//! @return A pointer to the element pointed to by the thrust pointer
252+
[[nodiscard]] _CCCL_API static constexpr raw_pointer to_address(const pointer iter) noexcept
253+
{
254+
return iter.get();
255+
}
256+
};

0 commit comments

Comments
 (0)