@@ -28,7 +28,69 @@ namespace detail
2828template <typename ... Ts>
2929class tuple_of_iterator_references ;
3030
31- template <class U , class T >
31+ // is_compatible_tuple_normalize:
32+ // device_reference<T> --> T
33+ // tuple_of_iterator_references<Ts...> --> tuple<Ts...>
34+ // T& --> T
35+
36+ template <typename T>
37+ struct is_compatible_tuple_normalize
38+ {
39+ using type = T;
40+ };
41+
42+ template <typename ... Ts>
43+ struct is_compatible_tuple_normalize <tuple_of_iterator_references<Ts...>>
44+ {
45+ using type = ::cuda::std::tuple<Ts...>;
46+ };
47+
48+ template <typename T>
49+ struct is_compatible_tuple_normalize <thrust::device_reference<T>>
50+ {
51+ using type = T;
52+ };
53+
54+ template <typename T>
55+ struct is_compatible_tuple_normalize <T&>
56+ {
57+ using type = T;
58+ };
59+
60+ template <typename T>
61+ using is_compatible_tuple_normalize_t = typename is_compatible_tuple_normalize<T>::type;
62+
63+ // is_compatible_tuple_v:
64+ // - checks if the tuple structure matches
65+ // - rather than just testing the top-level size, this handles nesting with length-1 tuples,
66+
67+ // is_compatible_tuple_v:
68+ // - case of two non-tuple types are compatible
69+ // - case of mixing tuples is not compatible
70+ template <typename U, typename T>
71+ inline constexpr bool is_compatible_tuple_v = ::cuda::std::__tuple_like<U> == ::cuda::std::__tuple_like<T>;
72+
73+ // is_compatible_tuple_helper_v: verifies that the outer-most tuple_size matches prior to recursing further
74+ // - case1: non-viable, sizes don't even match, do not recurse
75+ template <typename U, typename T, bool TupleSizeMatches>
76+ inline constexpr bool is_compatible_tuple_helper_v = false ;
77+
78+ // is_compatible_tuple_helper_v: viable, sizes match, recurse further but unwrap references
79+ template <template <class ...> class Tuple1 , template <class ...> class Tuple2 , typename ... Ts, typename ... Us>
80+ inline constexpr bool is_compatible_tuple_helper_v<Tuple1<Us...>, Tuple2<Ts...>, true > =
81+ (is_compatible_tuple_v<is_compatible_tuple_normalize_t <Us>, is_compatible_tuple_normalize_t <Ts>> && ...);
82+
83+ // is_compatible_tuple_v: recurse via is_compatible_tuple_helper_v to see if the two tuples are compatible
84+ template <template <class ...> class Tuple1 , template <class ...> class Tuple2 , typename ... Ts, typename ... Us>
85+ inline constexpr bool is_compatible_tuple_v<Tuple1<Us...>, Tuple2<Ts...>> =
86+ is_compatible_tuple_helper_v<Tuple1<Us...>, Tuple2<Ts...>, sizeof ...(Us) == sizeof ...(Ts)>;
87+
88+ // is_compatible_tuple_v: recurse via is_compatible_tuple_helper_v to see if the two tuples are compatible
89+ template <typename ... Us, typename ... Ts>
90+ inline constexpr bool is_compatible_tuple_v<::cuda::std::tuple<Us...>, ::cuda::std::tuple<Ts...>> =
91+ is_compatible_tuple_helper_v<::cuda::std::tuple<Us...>, ::cuda::std::tuple<Ts...>, sizeof ...(Us) == sizeof ...(Ts)>;
92+
93+ template <class U , class T , class Enable = void >
3294struct maybe_unwrap_nested
3395{
3496 _CCCL_HOST_DEVICE U operator ()(const T& t) const
@@ -38,7 +100,10 @@ struct maybe_unwrap_nested
38100};
39101
40102template <class ... Us, class ... Ts>
41- struct maybe_unwrap_nested <::cuda::std::tuple<Us...>, tuple_of_iterator_references<Ts...>>
103+ struct maybe_unwrap_nested <
104+ ::cuda::std::tuple<Us...>,
105+ tuple_of_iterator_references<Ts...>,
106+ ::cuda::std::enable_if_t <is_compatible_tuple_v<::cuda::std::tuple<Us...>, ::cuda::std::tuple<Ts...>>, int >>
42107{
43108 _CCCL_HOST_DEVICE ::cuda::std::tuple<Us...> operator ()(const tuple_of_iterator_references<Ts...>& t) const
44109 {
@@ -98,7 +163,9 @@ class tuple_of_iterator_references : public ::cuda::std::tuple<Ts...>
98163 return *this ;
99164 }
100165
101- template <class ... Us, ::cuda::std::enable_if_t <sizeof ...(Us) == sizeof ...(Ts), int > = 0 >
166+ template <
167+ class ... Us,
168+ ::cuda::std::enable_if_t <is_compatible_tuple_v<::cuda::std::tuple<Us...>, ::cuda::std::tuple<Ts...>>, int > = 0 >
102169 _CCCL_HOST_DEVICE constexpr operator ::cuda::std::tuple<Us...>() const
103170 {
104171 return __to_tuple<Us...>(typename ::cuda::std::__make_tuple_indices<sizeof ...(Ts)>::type{});
@@ -112,7 +179,10 @@ class tuple_of_iterator_references : public ::cuda::std::tuple<Ts...>
112179 x.swap (y);
113180 }
114181
115- template <class ... Us, size_t ... Id>
182+ template <
183+ class ... Us,
184+ size_t ... Id,
185+ ::cuda::std::enable_if_t <is_compatible_tuple_v<::cuda::std::tuple<Us...>, ::cuda::std::tuple<Ts...>>, int > = 0 >
116186 _CCCL_HOST_DEVICE constexpr ::cuda::std::tuple<Us...> __to_tuple (::cuda::std::__tuple_indices<Id...>) const
117187 {
118188 return {maybe_unwrap_nested<Us, Ts>{}(::cuda::std::get<Id>(*this ))...};
0 commit comments