|
14 | 14 | #include <sycl/ext/oneapi/experimental/detail/invoke_simd_types.hpp> |
15 | 15 | #include <sycl/ext/oneapi/experimental/uniform.hpp> |
16 | 16 |
|
17 | | -#include <sycl/detail/boost/mp11.hpp> |
18 | 17 | #include <sycl/sub_group.hpp> |
19 | 18 |
|
20 | 19 | #include <functional> |
@@ -71,8 +70,6 @@ namespace ext::oneapi::experimental { |
71 | 70 | // --- Helpers |
72 | 71 | namespace detail { |
73 | 72 |
|
74 | | -namespace __MP11_NS = sycl::detail::boost::mp11; |
75 | | - |
76 | 73 | // This structure performs the SPMD-to-SIMD parameter type conversion as defined |
77 | 74 | // by the spec. |
78 | 75 | template <class T, int N, class = void> struct spmd2simd; |
@@ -154,8 +151,7 @@ struct is_simd_or_mask_type<simd_mask<T, N>> : std::true_type {}; |
154 | 151 | // Checks if all the types in the parameter pack are uniform<T>. |
155 | 152 | template <class... SpmdArgs> struct all_uniform_types { |
156 | 153 | constexpr operator bool() { |
157 | | - using TypeList = __MP11_NS::mp_list<SpmdArgs...>; |
158 | | - return __MP11_NS::mp_all_of<TypeList, is_uniform_type>::value; |
| 154 | + return ((is_uniform_type<SpmdArgs>::value && ...)); |
159 | 155 | } |
160 | 156 | }; |
161 | 157 |
|
@@ -193,26 +189,32 @@ constexpr void verify_return_type_matches_sg_size() { |
193 | 189 | // as prescribed by the spec assuming this subgroup size. One and only one |
194 | 190 | // subgroup size should conform. |
195 | 191 | template <class SimdCallable, class... SpmdArgs> struct sg_size { |
196 | | - template <class N> |
197 | | - using IsInvocableSgSize = __MP11_NS::mp_bool<std::is_invocable_v< |
198 | | - SimdCallable, typename spmd2simd<SpmdArgs, N::value>::type...>>; |
199 | | - |
200 | 192 | __DPCPP_SYCL_EXTERNAL constexpr operator int() { |
201 | | - using SupportedSgSizes = __MP11_NS::mp_list_c<int, 1, 2, 4, 8, 16, 32>; |
202 | | - using InvocableSgSizes = |
203 | | - __MP11_NS::mp_copy_if<SupportedSgSizes, IsInvocableSgSize>; |
204 | | - constexpr auto found_invoke_simd_target = |
205 | | - __MP11_NS::mp_empty<InvocableSgSizes>::value != 1; |
206 | | - if constexpr (found_invoke_simd_target) { |
207 | | - static_assert((__MP11_NS::mp_size<InvocableSgSizes>::value == 1) && |
208 | | - "multiple invoke_simd targets found"); |
209 | | - return __MP11_NS::mp_front<InvocableSgSizes>::value; |
210 | | - } |
211 | | - static_assert( |
212 | | - found_invoke_simd_target, |
213 | | - "No callable invoke_simd target found. Confirm the " |
214 | | - "invoke_simd invocation argument types are convertible to the " |
215 | | - "invoke_simd target argument types"); |
| 193 | + constexpr auto x = []() constexpr { |
| 194 | + constexpr int supported_sg_sizes[] = {1, 2, 4, 8, 16, 32}; |
| 195 | + int num_found = 0; |
| 196 | + int found_sg_size = 0; |
| 197 | + sycl::detail::loop<std::size(supported_sg_sizes)>([&](auto idx) { |
| 198 | + constexpr auto sg_size = supported_sg_sizes[idx]; |
| 199 | + if (std::is_invocable_v< |
| 200 | + SimdCallable, typename spmd2simd<SpmdArgs, sg_size>::type...>) { |
| 201 | + ++num_found; |
| 202 | + found_sg_size = sg_size; |
| 203 | + } |
| 204 | + }); |
| 205 | + return std::pair{num_found, found_sg_size}; |
| 206 | + }(); |
| 207 | + |
| 208 | + constexpr auto num_found = x.first; |
| 209 | + constexpr auto found_sg_size = x.second; |
| 210 | + |
| 211 | + static_assert(num_found != 0, |
| 212 | + "No callable invoke_simd target found. Confirm the " |
| 213 | + "invoke_simd invocation argument types are convertible to " |
| 214 | + "the invoke_simd target argument types"); |
| 215 | + static_assert(num_found == 1, "Multiple invoke_simd targets found!"); |
| 216 | + |
| 217 | + return found_sg_size; |
216 | 218 | } |
217 | 219 | }; |
218 | 220 |
|
|
0 commit comments