Skip to content

Commit b8dd0af

Browse files
committed
Specialize pointer_traits for thrust::pointer
1 parent 3b85d38 commit b8dd0af

File tree

14 files changed

+103
-41
lines changed

14 files changed

+103
-41
lines changed

libcudacxx/include/cuda/std/__memory/pointer_traits.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,6 @@ struct _CCCL_TYPE_VISIBILITY_DEFAULT pointer_traits<_Tp*>
159159
}
160160
};
161161

162-
template <class _From, class _To>
163-
struct __rebind_pointer
164-
{
165-
using type = typename pointer_traits<_From>::template rebind<_To>;
166-
};
167-
168162
// to_address
169163

170164
template <class _Pointer, class = void>

thrust/testing/memory.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ get_temporary_buffer(my_old_temporary_allocation_system, std::ptrdiff_t)
7171
template <typename Pointer>
7272
void return_temporary_buffer(my_old_temporary_allocation_system, Pointer p)
7373
{
74-
using RP = typename thrust::detail::pointer_traits<Pointer>::raw_pointer;
75-
ASSERT_EQUAL(p.get(), reinterpret_cast<RP>(4217));
74+
using RP = typename cuda::std::pointer_traits<Pointer>::raw_pointer;
75+
ASSERT_EQUAL(::cuda::std::to_address(p), reinterpret_cast<RP>(4217));
7676
}
7777
} // namespace my_old_namespace
7878

@@ -101,8 +101,8 @@ void return_temporary_buffer(my_new_temporary_allocation_system, Pointer)
101101
template <typename Pointer>
102102
void return_temporary_buffer(my_new_temporary_allocation_system, Pointer p, std::ptrdiff_t n)
103103
{
104-
using RP = typename thrust::detail::pointer_traits<Pointer>::raw_pointer;
105-
ASSERT_EQUAL(p.get(), reinterpret_cast<RP>(1742));
104+
using RP = typename cuda::std::pointer_traits<Pointer>::raw_pointer;
105+
ASSERT_EQUAL(::cuda::std::to_address(p), reinterpret_cast<RP>(1742));
106106
ASSERT_EQUAL(n, 413);
107107
}
108108
} // namespace my_new_namespace

thrust/testing/mr_pool.cu

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include <thrust/mr/pool.h>
55
#include <thrust/mr/sync_pool.h>
66

7+
#include <cuda/std/memory>
8+
79
#include <unittest/unittest.h>
810

911
template <typename T>
@@ -100,6 +102,25 @@ struct tracked_pointer
100102
}
101103
};
102104

105+
template <typename T>
106+
struct cuda::std::pointer_traits<tracked_pointer<T>>
107+
{
108+
using pointer = tracked_pointer<T>;
109+
using element_type = T;
110+
using difference_type = cuda::std::ptrdiff_t;
111+
112+
template <typename U>
113+
struct rebind
114+
{
115+
using other = typename THRUST_NS_QUALIFIER::detail::rebind_pointer<pointer, U>::type;
116+
};
117+
118+
[[nodiscard]] _CCCL_API static constexpr element_type* to_address(const pointer iter) noexcept
119+
{
120+
return iter.get();
121+
}
122+
};
123+
103124
class tracked_resource final : public thrust::mr::memory_resource<tracked_pointer<void>>
104125
{
105126
public:

thrust/thrust/allocate_unique.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ struct allocator_delete final
6767
using traits = ::cuda::std::allocator_traits<::cuda::std::remove_cvref_t<Allocator>>;
6868
typename traits::allocator_type alloc_T(alloc_);
6969

70-
if (nullptr != detail::pointer_traits<pointer>::get(p))
70+
if (nullptr != ::cuda::std::to_address(p))
7171
{
7272
if constexpr (!Uninitialized)
7373
{
@@ -142,7 +142,7 @@ struct array_allocator_delete final
142142
{
143143
using traits = ::cuda::std::allocator_traits<::cuda::std::remove_cvref_t<Allocator>>;
144144
typename traits::allocator_type alloc_T(get_allocator());
145-
if (nullptr != detail::pointer_traits<pointer>::get(p))
145+
if (nullptr != ::cuda::std::to_address(p))
146146
{
147147
if constexpr (!Uninitialized)
148148
{

thrust/thrust/detail/allocator/tagged_allocator.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
1313
# pragma system_header
1414
#endif // no system header
15+
1516
#include <thrust/detail/allocator/tagged_allocator.h>
1617
#include <thrust/detail/type_traits/pointer_traits.h>
1718
#include <thrust/iterator/iterator_traits.h>
1819

20+
#include <cuda/std/__memory/pointer_traits.h>
1921
#include <cuda/std/limits>
2022

2123
THRUST_NAMESPACE_BEGIN
@@ -29,10 +31,10 @@ class tagged_allocator<void, Tag, Pointer>
2931
{
3032
public:
3133
using value_type = void;
32-
using pointer = typename thrust::detail::pointer_traits<Pointer>::template rebind<void>::other;
33-
using const_pointer = typename thrust::detail::pointer_traits<Pointer>::template rebind<const void>::other;
34+
using pointer = typename ::cuda::std::pointer_traits<Pointer>::template rebind<void>::other;
35+
using const_pointer = typename ::cuda::std::pointer_traits<Pointer>::template rebind<const void>::other;
3436
using size_type = std::size_t;
35-
using difference_type = typename thrust::detail::pointer_traits<Pointer>::difference_type;
37+
using difference_type = typename ::cuda::std::pointer_traits<Pointer>::difference_type;
3638
using system_type = Tag;
3739

3840
template <typename U>
@@ -47,12 +49,12 @@ class tagged_allocator
4749
{
4850
public:
4951
using value_type = T;
50-
using pointer = typename thrust::detail::pointer_traits<Pointer>::template rebind<T>::other;
51-
using const_pointer = typename thrust::detail::pointer_traits<Pointer>::template rebind<const T>::other;
52+
using pointer = typename ::cuda::std::pointer_traits<Pointer>::template rebind<T>::other;
53+
using const_pointer = typename ::cuda::std::pointer_traits<Pointer>::template rebind<const T>::other;
5254
using reference = thrust::detail::it_reference_t<pointer>;
5355
using const_reference = thrust::detail::it_reference_t<const_pointer>;
5456
using size_type = std::size_t;
55-
using difference_type = typename thrust::detail::pointer_traits<pointer>::difference_type;
57+
using difference_type = typename ::cuda::std::pointer_traits<pointer>::difference_type;
5658
using system_type = Tag;
5759

5860
template <typename U>

thrust/thrust/detail/execute_with_allocator.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include <cuda/__cmath/ceil_div.h>
2121
#include <cuda/std/__memory/allocator_traits.h>
22+
#include <cuda/std/__memory/pointer_traits.h>
2223
#include <cuda/std/__utility/pair.h>
2324

2425
THRUST_NAMESPACE_BEGIN
@@ -54,7 +55,7 @@ _CCCL_HOST void return_temporary_buffer(
5455
using pointer = typename alloc_traits::pointer;
5556
using size_type = typename alloc_traits::size_type;
5657
using value_type = typename alloc_traits::value_type;
57-
using T = typename thrust::detail::pointer_traits<Pointer>::element_type;
58+
using T = typename ::cuda::std::pointer_traits<Pointer>::element_type;
5859

5960
size_type num_elements = ::cuda::ceil_div(sizeof(T) * n, sizeof(value_type));
6061

thrust/thrust/detail/get_iterator_value.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include <thrust/iterator/iterator_traits.h>
1919
#include <thrust/system/detail/generic/memory.h> // for get_value()
2020

21+
#include <cuda/std/__memory/pointer_traits.h>
22+
2123
THRUST_NAMESPACE_BEGIN
2224

2325
namespace detail
@@ -38,7 +40,7 @@ _CCCL_HOST_DEVICE it_value_t<Iterator> get_iterator_value(thrust::execution_poli
3840
// we use get_value(exec,pointer*) function
3941
// to perform a dereferencing consistent with the execution policy
4042
template <typename DerivedPolicy, typename Pointer>
41-
_CCCL_HOST_DEVICE typename thrust::detail::pointer_traits<Pointer*>::element_type
43+
_CCCL_HOST_DEVICE typename ::cuda::std::pointer_traits<Pointer*>::element_type
4244
get_iterator_value(thrust::execution_policy<DerivedPolicy>& exec, Pointer* ptr)
4345
{
4446
return get_value(derived_cast(exec), ptr);

thrust/thrust/detail/pointer.h

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,16 @@
2424
#include <thrust/iterator/iterator_adaptor.h>
2525
#include <thrust/iterator/iterator_traversal_tags.h>
2626

27+
#include <cuda/std/__memory/addressof.h>
28+
#include <cuda/std/__memory/pointer_traits.h>
29+
#include <cuda/std/__type_traits/enable_if.h>
2730
#include <cuda/std/__type_traits/is_comparable.h>
2831
#include <cuda/std/__type_traits/is_reference.h>
2932
#include <cuda/std/__type_traits/is_same.h>
3033
#include <cuda/std/__type_traits/is_void.h>
3134
#include <cuda/std/__type_traits/remove_cv.h>
3235
#include <cuda/std/__type_traits/remove_cvref.h>
36+
#include <cuda/std/__type_traits/remove_pointer.h>
3337
#include <cuda/std/__type_traits/type_identity.h>
3438
#include <cuda/std/cstddef>
3539

@@ -197,7 +201,7 @@ class pointer : public detail::pointer_base<Element, Tag, Reference, Derived>::t
197201
*/
198202
template <typename OtherPointer, detail::enable_if_pointer_is_convertible_t<OtherPointer, pointer>* = nullptr>
199203
_CCCL_HOST_DEVICE pointer(const OtherPointer& other)
200-
: super_t(detail::pointer_traits<OtherPointer>::get(other))
204+
: super_t(static_cast<Element*>(::cuda::std::to_address(other)))
201205
{}
202206

203207
#ifndef _CCCL_DOXYGEN_INVOKED // Doxygen cannot handle this constructor and creates a duplicate ID with the ctor above
@@ -206,7 +210,7 @@ class pointer : public detail::pointer_base<Element, Tag, Reference, Derived>::t
206210
template <typename OtherPointer,
207211
detail::enable_if_void_pointer_is_system_convertible_t<OtherPointer, pointer>* = nullptr>
208212
_CCCL_HOST_DEVICE explicit pointer(const OtherPointer& other)
209-
: super_t(static_cast<Element*>(detail::pointer_traits<OtherPointer>::get(other)))
213+
: super_t(static_cast<Element*>(::cuda::std::to_address(other)))
210214
{}
211215
#endif
212216

@@ -232,7 +236,7 @@ class pointer : public detail::pointer_base<Element, Tag, Reference, Derived>::t
232236
_CCCL_HOST_DEVICE detail::enable_if_pointer_is_convertible_t<OtherPointer, pointer, derived_type&>
233237
operator=(const OtherPointer& other)
234238
{
235-
super_t::base_reference() = detail::pointer_traits<OtherPointer>::get(other);
239+
super_t::base_reference() = ::cuda::std::to_address(other);
236240
return static_cast<derived_type&>(*this);
237241
}
238242

@@ -257,10 +261,11 @@ class pointer : public detail::pointer_base<Element, Tag, Reference, Derived>::t
257261
return bool(get());
258262
}
259263

260-
_CCCL_HOST_DEVICE static derived_type
261-
pointer_to(typename detail::pointer_traits_detail::pointer_to_param<Element>::type r)
264+
template <class T,
265+
::cuda::std::enable_if_t<(::cuda::std::is_void_v<Element> || ::cuda::std::is_same_v<T, Element>), int> = 0>
266+
_CCCL_HOST_DEVICE static derived_type pointer_to(T& r)
262267
{
263-
return detail::pointer_traits<derived_type>::pointer_to(r);
268+
return static_cast<derived_type>(::cuda::std::addressof(r));
264269
}
265270

266271
#if !_CCCL_COMPILER(NVRTC)
@@ -388,3 +393,39 @@ class pointer : public detail::pointer_base<Element, Tag, Reference, Derived>::t
388393
*/
389394

390395
THRUST_NAMESPACE_END
396+
397+
_CCCL_BEGIN_NAMESPACE_CUDA_STD
398+
399+
// Specialize pointer traits for everything that has the raw_pointer alias
400+
template <typename Pointer>
401+
struct pointer_traits<Pointer, void_t<typename Pointer::raw_pointer>>
402+
{
403+
using pointer = Pointer;
404+
using element_type = remove_pointer_t<typename Pointer::raw_pointer>;
405+
using difference_type = ptrdiff_t;
406+
407+
template <typename U>
408+
struct rebind
409+
{
410+
using other = typename THRUST_NS_QUALIFIER::detail::rebind_pointer<pointer, U>::type;
411+
};
412+
413+
// Backwards compatability with thrust::detail::pointer_traits
414+
using raw_pointer = typename pointer::raw_pointer;
415+
416+
// Thrust historically provided a non-standard pointer_to for pointer<void>
417+
template <class T, enable_if_t<(is_void_v<element_type> || is_same_v<T, element_type>), int> = 0>
418+
[[nodiscard]] _CCCL_API inline static pointer pointer_to(T& r) noexcept(noexcept(::cuda::std::addressof(r)))
419+
{
420+
return static_cast<element_type*>(::cuda::std::addressof(r));
421+
}
422+
423+
//! @brief Retrieve the address of the element pointed at by an thrust pointer
424+
//! @param iter A thrust::pointer
425+
//! @return A pointer to the element pointed to by the thrust pointer
426+
[[nodiscard]] _CCCL_API static constexpr raw_pointer to_address(const pointer iter) noexcept
427+
{
428+
return iter.get();
429+
}
430+
};
431+
_CCCL_END_NAMESPACE_CUDA_STD

thrust/thrust/detail/raw_pointer_cast.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,17 @@
1212
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
1313
# pragma system_header
1414
#endif // no system header
15+
1516
#include <thrust/detail/type_traits/pointer_traits.h>
1617

18+
#include <cuda/std/__memory/pointer_traits.h>
19+
1720
THRUST_NAMESPACE_BEGIN
1821

1922
template <typename Pointer>
20-
_CCCL_HOST_DEVICE typename thrust::detail::pointer_traits<Pointer>::raw_pointer raw_pointer_cast(Pointer ptr)
23+
_CCCL_HOST_DEVICE auto raw_pointer_cast(Pointer ptr)
2124
{
22-
return thrust::detail::pointer_traits<Pointer>::get(ptr);
25+
return ::cuda::std::to_address(ptr);
2326
}
2427

2528
template <typename ToPointer, typename FromPointer>

thrust/thrust/memory.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#include <thrust/detail/temporary_buffer.h>
2424
#include <thrust/detail/type_traits/pointer_traits.h>
2525

26+
#include <cuda/std/__memory/pointer_traits.h>
27+
2628
THRUST_NAMESPACE_BEGIN
2729

2830
/*! \addtogroup memory_management Memory Management
@@ -274,7 +276,7 @@ _CCCL_HOST_DEVICE void return_temporary_buffer(
274276
* \endverbatim
275277
*/
276278
template <typename Pointer>
277-
_CCCL_HOST_DEVICE typename thrust::detail::pointer_traits<Pointer>::raw_pointer raw_pointer_cast(Pointer ptr);
279+
_CCCL_HOST_DEVICE auto raw_pointer_cast(Pointer ptr);
278280

279281
/*! \p raw_reference_cast creates a "raw" reference from a wrapped reference type,
280282
* simply returning the underlying reference, should it exist.

0 commit comments

Comments
 (0)