Skip to content

Commit 809069c

Browse files
Fixes dereferencing zip_iterator over tuples of length-1 as well as copies from nested zip_iterators using OpenMP/TBB (#7882)
* Fixes dereferencing zip_iterator over tuples of length-1 as well as copies from nested zip_iterators using OpenMP/TBB - implements is_compatible to verify that the tuple structure matches to fix conversions - TestZipIteratorDereferenceToValue verifies that casting from the result of dereferencing to the corresponding value type works - TestZipIteratorNestedCopy verifies that copying using nested zip_iterators works Fixes #7855 * Simplify implementation and make it more efficient * Bring back normalization for omp and tbb --------- Co-authored-by: Michael Schellenberger Costa <miscco@nvidia.com>
1 parent c0b9df6 commit 809069c

File tree

2 files changed

+155
-4
lines changed

2 files changed

+155
-4
lines changed

thrust/testing/zip_iterator.cu

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,3 +386,84 @@ void TestZipIteratorCopySoAToAoS()
386386
ASSERT_EQUAL_QUIET(13, cuda::std::get<1>(h_soa[0]));
387387
};
388388
DECLARE_UNITTEST(TestZipIteratorCopySoAToAoS);
389+
390+
template <typename T>
391+
void TestZipIteratorDereferenceToValueType(const T& t)
392+
{
393+
thrust::device_vector<T> data(1, t);
394+
395+
// verify that storing the result of dereferencing a zip_iterator and then subsequently converting to its value type
396+
// is handled correctly
397+
398+
auto a = thrust::make_zip_iterator(data.begin());
399+
static_assert(cuda::std::is_same_v<cuda::std::tuple<T>, cuda::std::iter_value_t<decltype(a)>>);
400+
401+
auto b = a[0];
402+
static_assert(
403+
cuda::std::is_same_v<thrust::detail::tuple_of_iterator_references<thrust::device_reference<T>>, decltype(b)>);
404+
405+
// verify that the stored tuple_of_iterator_references<device_reference<T>> can be cast to tuple<T>
406+
auto c = cuda::std::tuple<T>(b);
407+
static_assert(cuda::std::is_same_v<cuda::std::tuple<T>, decltype(c)>);
408+
409+
ASSERT_EQUAL_QUIET(c, cuda::std::make_tuple(t));
410+
}
411+
412+
void TestZipIteratorDereferenceToValue()
413+
{
414+
TestZipIteratorDereferenceToValueType(1);
415+
TestZipIteratorDereferenceToValueType(cuda::std::make_tuple(1));
416+
TestZipIteratorDereferenceToValueType(cuda::std::make_tuple(1, cuda::std::make_tuple(1)));
417+
TestZipIteratorDereferenceToValueType(cuda::std::make_tuple(1, cuda::std::make_tuple(1, 1)));
418+
TestZipIteratorDereferenceToValueType(cuda::std::make_tuple(cuda::std::make_tuple(1), cuda::std::make_tuple(1, 1)));
419+
}
420+
DECLARE_UNITTEST(TestZipIteratorDereferenceToValue);
421+
422+
void TestZipIteratorNestedCopy()
423+
{
424+
using T = int;
425+
426+
{
427+
thrust::device_vector<T> a(10, 1);
428+
thrust::device_vector<cuda::std::tuple<cuda::std::tuple<T>>> b(a.size());
429+
430+
thrust::copy_n(thrust::make_zip_iterator(thrust::make_zip_iterator(a.begin())), a.size(), b.begin());
431+
432+
decltype(b) b_expected(b.size(), cuda::std::make_tuple(cuda::std::make_tuple(1)));
433+
434+
ASSERT_EQUAL_QUIET(b, b_expected);
435+
}
436+
437+
{
438+
thrust::device_vector<T> a(10, 1);
439+
thrust::device_vector<cuda::std::tuple<cuda::std::tuple<T, T>, cuda::std::tuple<T, T>>> b(a.size());
440+
441+
thrust::copy_n(thrust::make_zip_iterator(thrust::make_zip_iterator(a.begin(), a.begin()),
442+
thrust::make_zip_iterator(a.begin(), a.begin())),
443+
a.size(),
444+
b.begin());
445+
446+
decltype(b) b_expected(b.size(), cuda::std::make_tuple(cuda::std::make_tuple(1, 1), cuda::std::make_tuple(1, 1)));
447+
448+
ASSERT_EQUAL_QUIET(b, b_expected);
449+
}
450+
451+
{
452+
thrust::device_vector<cuda::std::tuple<T, T>> a(10, cuda::std::make_tuple(1, 1));
453+
thrust::device_vector<
454+
cuda::std::tuple<cuda::std::tuple<T, T>, cuda::std::tuple<T, T>, cuda::std::tuple<T, T>, cuda::std::tuple<T, T>>>
455+
b(a.size());
456+
457+
thrust::copy_n(thrust::make_zip_iterator(a.begin(), a.begin(), a.begin(), a.begin()), a.size(), b.begin());
458+
459+
decltype(b) b_expected(
460+
b.size(),
461+
cuda::std::make_tuple(cuda::std::make_tuple(1, 1),
462+
cuda::std::make_tuple(1, 1),
463+
cuda::std::make_tuple(1, 1),
464+
cuda::std::make_tuple(1, 1)));
465+
466+
ASSERT_EQUAL_QUIET(b, b_expected);
467+
}
468+
}
469+
DECLARE_UNITTEST(TestZipIteratorNestedCopy);

thrust/thrust/iterator/detail/tuple_of_iterator_references.h

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,69 @@ namespace detail
2828
template <typename... Ts>
2929
class tuple_of_iterator_references;
3030

31-
template <class U, class T>
31+
// is_compatible_tuple_normalize:
32+
// device_reference<T> --> T
33+
// tuple_of_iterator_references<Ts...> --> tuple<Ts...>
34+
// T& --> T
35+
36+
template <typename T>
37+
struct is_compatible_tuple_normalize
38+
{
39+
using type = T;
40+
};
41+
42+
template <typename... Ts>
43+
struct is_compatible_tuple_normalize<tuple_of_iterator_references<Ts...>>
44+
{
45+
using type = ::cuda::std::tuple<Ts...>;
46+
};
47+
48+
template <typename T>
49+
struct is_compatible_tuple_normalize<thrust::device_reference<T>>
50+
{
51+
using type = T;
52+
};
53+
54+
template <typename T>
55+
struct is_compatible_tuple_normalize<T&>
56+
{
57+
using type = T;
58+
};
59+
60+
template <typename T>
61+
using is_compatible_tuple_normalize_t = typename is_compatible_tuple_normalize<T>::type;
62+
63+
// is_compatible_tuple_v:
64+
// - checks if the tuple structure matches
65+
// - rather than just testing the top-level size, this handles nesting with length-1 tuples,
66+
67+
// is_compatible_tuple_v:
68+
// - case of two non-tuple types are compatible
69+
// - case of mixing tuples is not compatible
70+
template <typename U, typename T>
71+
inline constexpr bool is_compatible_tuple_v = ::cuda::std::__tuple_like<U> == ::cuda::std::__tuple_like<T>;
72+
73+
// is_compatible_tuple_helper_v: verifies that the outer-most tuple_size matches prior to recursing further
74+
// - case1: non-viable, sizes don't even match, do not recurse
75+
template <typename U, typename T, bool TupleSizeMatches>
76+
inline constexpr bool is_compatible_tuple_helper_v = false;
77+
78+
// is_compatible_tuple_helper_v: viable, sizes match, recurse further but unwrap references
79+
template <template <class...> class Tuple1, template <class...> class Tuple2, typename... Ts, typename... Us>
80+
inline constexpr bool is_compatible_tuple_helper_v<Tuple1<Us...>, Tuple2<Ts...>, true> =
81+
(is_compatible_tuple_v<is_compatible_tuple_normalize_t<Us>, is_compatible_tuple_normalize_t<Ts>> && ...);
82+
83+
// is_compatible_tuple_v: recurse via is_compatible_tuple_helper_v to see if the two tuples are compatible
84+
template <template <class...> class Tuple1, template <class...> class Tuple2, typename... Ts, typename... Us>
85+
inline constexpr bool is_compatible_tuple_v<Tuple1<Us...>, Tuple2<Ts...>> =
86+
is_compatible_tuple_helper_v<Tuple1<Us...>, Tuple2<Ts...>, sizeof...(Us) == sizeof...(Ts)>;
87+
88+
// is_compatible_tuple_v: recurse via is_compatible_tuple_helper_v to see if the two tuples are compatible
89+
template <typename... Us, typename... Ts>
90+
inline constexpr bool is_compatible_tuple_v<::cuda::std::tuple<Us...>, ::cuda::std::tuple<Ts...>> =
91+
is_compatible_tuple_helper_v<::cuda::std::tuple<Us...>, ::cuda::std::tuple<Ts...>, sizeof...(Us) == sizeof...(Ts)>;
92+
93+
template <class U, class T, class Enable = void>
3294
struct maybe_unwrap_nested
3395
{
3496
_CCCL_HOST_DEVICE U operator()(const T& t) const
@@ -38,7 +100,10 @@ struct maybe_unwrap_nested
38100
};
39101

40102
template <class... Us, class... Ts>
41-
struct maybe_unwrap_nested<::cuda::std::tuple<Us...>, tuple_of_iterator_references<Ts...>>
103+
struct maybe_unwrap_nested<
104+
::cuda::std::tuple<Us...>,
105+
tuple_of_iterator_references<Ts...>,
106+
::cuda::std::enable_if_t<is_compatible_tuple_v<::cuda::std::tuple<Us...>, ::cuda::std::tuple<Ts...>>, int>>
42107
{
43108
_CCCL_HOST_DEVICE ::cuda::std::tuple<Us...> operator()(const tuple_of_iterator_references<Ts...>& t) const
44109
{
@@ -98,7 +163,9 @@ class tuple_of_iterator_references : public ::cuda::std::tuple<Ts...>
98163
return *this;
99164
}
100165

101-
template <class... Us, ::cuda::std::enable_if_t<sizeof...(Us) == sizeof...(Ts), int> = 0>
166+
template <
167+
class... Us,
168+
::cuda::std::enable_if_t<is_compatible_tuple_v<::cuda::std::tuple<Us...>, ::cuda::std::tuple<Ts...>>, int> = 0>
102169
_CCCL_HOST_DEVICE constexpr operator ::cuda::std::tuple<Us...>() const
103170
{
104171
return __to_tuple<Us...>(typename ::cuda::std::__make_tuple_indices<sizeof...(Ts)>::type{});
@@ -112,7 +179,10 @@ class tuple_of_iterator_references : public ::cuda::std::tuple<Ts...>
112179
x.swap(y);
113180
}
114181

115-
template <class... Us, size_t... Id>
182+
template <
183+
class... Us,
184+
size_t... Id,
185+
::cuda::std::enable_if_t<is_compatible_tuple_v<::cuda::std::tuple<Us...>, ::cuda::std::tuple<Ts...>>, int> = 0>
116186
_CCCL_HOST_DEVICE constexpr ::cuda::std::tuple<Us...> __to_tuple(::cuda::std::__tuple_indices<Id...>) const
117187
{
118188
return {maybe_unwrap_nested<Us, Ts>{}(::cuda::std::get<Id>(*this))...};

0 commit comments

Comments
 (0)