@@ -125,8 +125,9 @@ int get_mem_idx(GroupTy g, int vec_or_array_idx) {
125125// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_char.html
126126// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_long.html
127127// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_short.html
128- // Reads require 4-byte alignment, writes 16-byte alignment. Supported
129- // sizes:
128+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_local_block_io.html
129+ // Reads require 4-byte alignment for global pointers and 16-byte alignment for
130+ // local pointers, writes require 16-byte alignment. Supported sizes:
130131//
131132// +------------+-------------+
132133// | block type | # of blocks |
@@ -156,6 +157,21 @@ struct BlockInfo {
156157 (num_blocks <= 8 || (num_blocks == 16 && block_size <= 2 ));
157158};
158159
160+ enum class operation_type { load, store };
161+
162+ template <operation_type OpType, access::address_space Space>
163+ struct RequiredAlignment {};
164+
165+ template <operation_type OpType>
166+ struct RequiredAlignment <OpType, access::address_space::global_space> {
167+ static constexpr int value = (OpType == operation_type::load) ? 4 : 16 ;
168+ };
169+
170+ template <operation_type OpType>
171+ struct RequiredAlignment <OpType, access::address_space::local_space> {
172+ static constexpr int value = 16 ;
173+ };
174+
159175template <typename BlockInfoTy> struct BlockTypeInfo ;
160176
161177template <typename IteratorT, std::size_t ElementsPerWorkItem, bool Blocked>
@@ -186,11 +202,10 @@ struct BlockTypeInfo<BlockInfo<IteratorT, ElementsPerWorkItem, Blocked>> {
186202// aren't satisfied. If deduced address space is generic then returned pointer
187203// will have generic address space and has to be dynamically casted to global or
188204// local space before using in a builtin.
189- template <int RequiredAlign, std::size_t ElementsPerWorkItem,
190- typename IteratorT, typename Properties>
191- auto get_block_op_ptr (IteratorT iter, [[maybe_unused]] Properties props) {
192- using value_type =
193- remove_decoration_t <typename std::iterator_traits<IteratorT>::value_type>;
205+ template <std::size_t ElementsPerWorkItem, typename IteratorT,
206+ typename Properties>
207+ constexpr auto get_block_op_ptr (IteratorT iter,
208+ [[maybe_unused]] Properties props) {
194209 using iter_no_cv = std::remove_cv_t <IteratorT>;
195210
196211 constexpr bool blocked = detail::isBlocked (props);
@@ -208,39 +223,46 @@ auto get_block_op_ptr(IteratorT iter, [[maybe_unused]] Properties props) {
208223 } else if constexpr (!props.template has_property <full_group_key>()) {
209224 return nullptr ;
210225 } else if constexpr (detail::is_multi_ptr_v<IteratorT>) {
211- return get_block_op_ptr<RequiredAlign, ElementsPerWorkItem>(
212- iter.get_decorated (), props);
226+ return get_block_op_ptr<ElementsPerWorkItem>(iter.get_decorated (), props);
213227 } else if constexpr (!std::is_pointer_v<iter_no_cv>) {
214228 if constexpr (props.template has_property <contiguous_memory_key>())
215- return get_block_op_ptr<RequiredAlign, ElementsPerWorkItem>(&*iter,
216- props);
229+ return get_block_op_ptr<ElementsPerWorkItem>(&*iter, props);
217230 else
218231 return nullptr ;
219232 } else {
220233 // Load/store to/from nullptr would be an UB, this assume allows the
221234 // compiler to optimize the IR further.
222235 __builtin_assume (iter != nullptr );
223236
224- // No early return as that would mess up return type deduction.
225- bool is_aligned = alignof (value_type) >= RequiredAlign ||
226- reinterpret_cast <uintptr_t >(iter) % RequiredAlign == 0 ;
227-
228237 using block_pointer_type =
229238 typename BlockTypeInfo<BlkInfo>::block_pointer_type;
230239
231- static constexpr auto deduced_address_space =
240+ constexpr auto deduced_address_space =
232241 BlockTypeInfo<BlkInfo>::deduced_address_space;
242+
233243 if constexpr (deduced_address_space ==
234244 access::address_space::generic_space ||
235245 deduced_address_space ==
236246 access::address_space::global_space ||
237- deduced_address_space == access::address_space::local_space) {
238- return is_aligned ? reinterpret_cast <block_pointer_type>(iter) : nullptr ;
247+ (deduced_address_space ==
248+ access::address_space::local_space &&
249+ props.template has_property <
250+ detail::native_local_block_io_key>())) {
251+ return reinterpret_cast <block_pointer_type>(iter);
239252 } else {
240253 return nullptr ;
241254 }
242255 }
243256}
257+
258+ template <int RequiredAlign, typename IteratorType>
259+ bool is_aligned (IteratorType iter) {
260+ using value_type = remove_decoration_t <
261+ typename std::iterator_traits<IteratorType>::value_type>;
262+ return alignof (value_type) >= RequiredAlign ||
263+ reinterpret_cast <uintptr_t >(&*iter) % RequiredAlign == 0 ;
264+ }
265+
244266} // namespace detail
245267
246268// Load API span overload.
@@ -266,78 +288,72 @@ group_load(Group g, InputIteratorT in_ptr,
266288 } else if constexpr (!std::is_same_v<Group, sycl::sub_group>) {
267289 return group_load (g, in_ptr, out, use_naive{});
268290 } else {
269- auto ptr =
270- detail::get_block_op_ptr<4 /* load align */ , ElementsPerWorkItem>(
271- in_ptr, props);
272- if (!ptr)
273- return group_load (g, in_ptr, out, use_naive{});
291+ auto ptr = detail::get_block_op_ptr<ElementsPerWorkItem>(in_ptr, props);
292+ static constexpr auto deduced_address_space =
293+ detail::deduce_AS<std::remove_cv_t <decltype (ptr)>>::value;
274294
275295 if constexpr (!std::is_same_v<std::nullptr_t , decltype (ptr)>) {
276- // Do optimized load.
277- using value_type = remove_decoration_t <
278- typename std::iterator_traits<InputIteratorT>::value_type>;
279- using block_info = typename detail::BlockTypeInfo<
280- detail::BlockInfo<InputIteratorT, ElementsPerWorkItem, blocked>>;
281- static constexpr auto deduced_address_space =
282- block_info::deduced_address_space;
283- using block_op_type = typename block_info::block_op_type;
284-
285- if constexpr (deduced_address_space ==
286- access::address_space::local_space &&
287- !props.template has_property <
288- detail::native_local_block_io_key>())
289- return group_load (g, in_ptr, out, use_naive{});
290-
291- block_op_type load;
292296 if constexpr (deduced_address_space ==
293297 access::address_space::generic_space) {
294298 if (auto local_ptr = detail::dynamic_address_cast<
295299 access::address_space::local_space>(ptr)) {
296- if constexpr (props.template has_property <
297- detail::native_local_block_io_key>())
298- load = __spirv_SubgroupBlockReadINTEL<block_op_type>(local_ptr);
299- else
300- return group_load (g, in_ptr, out, use_naive{});
300+ return group_load (g, local_ptr, out, props);
301301 } else if (auto global_ptr = detail::dynamic_address_cast<
302302 access::address_space::global_space>(ptr)) {
303- load = __spirv_SubgroupBlockReadINTEL<block_op_type>( global_ptr);
303+ return group_load (g, global_ptr, out, props );
304304 } else {
305305 return group_load (g, in_ptr, out, use_naive{});
306306 }
307307 } else {
308- load = __spirv_SubgroupBlockReadINTEL<block_op_type>(ptr);
309- }
308+ using value_type = remove_decoration_t <
309+ typename std::iterator_traits<InputIteratorT>::value_type>;
310+ using block_info = typename detail::BlockTypeInfo<
311+ detail::BlockInfo<InputIteratorT, ElementsPerWorkItem, blocked>>;
312+ using block_op_type = typename block_info::block_op_type;
313+ // Alignment checks of the pointer.
314+ constexpr int ReqAlign =
315+ detail::RequiredAlignment<detail::operation_type::load,
316+ deduced_address_space>::value;
317+ if (!detail::is_aligned<ReqAlign>(in_ptr))
318+ return group_load (g, in_ptr, out, use_naive{});
310319
311- // TODO: accessor_iterator's value_type is weird, so we need
312- // `std::remove_const_t` below:
313- //
314- // static_assert(
315- // std::is_same_v<
316- // typename std::iterator_traits<
317- // sycl::detail::accessor_iterator<const int, 1>>::value_type,
318- // const int>);
319- //
320- // yet
321- //
322- // static_assert(
323- // std::is_same_v<
324- // typename std::iterator_traits<const int *>::value_type, int>);
325-
326- if constexpr (std::is_same_v<std::remove_const_t <value_type>, OutputT>) {
327- static_assert (sizeof (load) == out.size_bytes ());
328- sycl::detail::memcpy_no_adl (out.begin (), &load, out.size_bytes ());
329- } else {
330- std::remove_const_t <value_type> values[ElementsPerWorkItem];
331- static_assert (sizeof (load) == sizeof (values));
332- sycl::detail::memcpy_no_adl (values, &load, sizeof (values));
333-
334- // Note: can't `memcpy` directly into `out` because that might bypass
335- // an implicit conversion required by the specification.
336- for (int i = 0 ; i < ElementsPerWorkItem; ++i)
337- out[i] = values[i];
320+ // We know the pointer is aligned and the address space is known. Do the
321+ // optimized load.
322+ auto load = __spirv_SubgroupBlockReadINTEL<block_op_type>(ptr);
323+
324+ // TODO: accessor_iterator's value_type is weird, so we need
325+ // `std::remove_const_t` below:
326+ //
327+ // static_assert(
328+ // std::is_same_v<
329+ // typename std::iterator_traits<
330+ // sycl::detail::accessor_iterator<const int,
331+ // 1>>::value_type,
332+ // const int>);
333+ //
334+ // yet
335+ //
336+ // static_assert(
337+ // std::is_same_v<
338+ // typename std::iterator_traits<const int *>::value_type,
339+ // int>);
340+ if constexpr (std::is_same_v<std::remove_const_t <value_type>,
341+ OutputT>) {
342+ static_assert (sizeof (load) == out.size_bytes ());
343+ sycl::detail::memcpy_no_adl (out.begin (), &load, out.size_bytes ());
344+ } else {
345+ std::remove_const_t <value_type> values[ElementsPerWorkItem];
346+ static_assert (sizeof (load) == sizeof (values));
347+ sycl::detail::memcpy_no_adl (values, &load, sizeof (values));
348+
349+ // Note: can't `memcpy` directly into `out` because that might bypass
350+ // an implicit conversion required by the specification.
351+ for (int i = 0 ; i < ElementsPerWorkItem; ++i)
352+ out[i] = values[i];
353+ }
338354 }
339-
340- return ;
355+ } else {
356+ return group_load (g, in_ptr, out, use_naive{}) ;
341357 }
342358 }
343359}
@@ -365,55 +381,50 @@ group_store(Group g, const span<InputT, ElementsPerWorkItem> in,
365381 } else if constexpr (!std::is_same_v<Group, sycl::sub_group>) {
366382 return group_store (g, in, out_ptr, use_naive{});
367383 } else {
368- auto ptr =
369- detail::get_block_op_ptr<16 /* store align */ , ElementsPerWorkItem>(
370- out_ptr, props);
371- if (!ptr)
372- return group_store (g, in, out_ptr, use_naive{});
384+ auto ptr = detail::get_block_op_ptr<ElementsPerWorkItem>(out_ptr, props);
373385
374386 if constexpr (!std::is_same_v<std::nullptr_t , decltype (ptr)>) {
375- using block_info = typename detail::BlockTypeInfo<
376- detail::BlockInfo<OutputIteratorT, ElementsPerWorkItem, blocked>>;
377387 static constexpr auto deduced_address_space =
378- block_info::deduced_address_space;
379- if constexpr (deduced_address_space ==
380- access::address_space::local_space &&
381- !props.template has_property <
382- detail::native_local_block_io_key>())
383- return group_store (g, in, out_ptr, use_naive{});
384-
385- // Do optimized store.
386- std::remove_const_t <remove_decoration_t <
387- typename std::iterator_traits<OutputIteratorT>::value_type>>
388- values[ElementsPerWorkItem];
389-
390- for (int i = 0 ; i < ElementsPerWorkItem; ++i) {
391- // Including implicit conversion.
392- values[i] = in[i];
393- }
394-
395- using block_op_type = typename block_info::block_op_type;
388+ detail::deduce_AS<std::remove_cv_t <decltype (ptr)>>::value;
396389 if constexpr (deduced_address_space ==
397390 access::address_space::generic_space) {
398391 if (auto local_ptr = detail::dynamic_address_cast<
399392 access::address_space::local_space>(ptr)) {
400- if constexpr (props.template has_property <
401- detail::native_local_block_io_key>())
402- __spirv_SubgroupBlockWriteINTEL (
403- local_ptr, sycl::bit_cast<block_op_type>(values));
404- else
405- return group_store (g, in, out_ptr, use_naive{});
393+ return group_store (g, in, local_ptr, props);
406394 } else if (auto global_ptr = detail::dynamic_address_cast<
407395 access::address_space::global_space>(ptr)) {
408- __spirv_SubgroupBlockWriteINTEL (
409- global_ptr, sycl::bit_cast<block_op_type>(values));
396+ return group_store (g, in, global_ptr, props);
410397 } else {
411398 return group_store (g, in, out_ptr, use_naive{});
412399 }
413400 } else {
401+ using block_info = typename detail::BlockTypeInfo<
402+ detail::BlockInfo<OutputIteratorT, ElementsPerWorkItem, blocked>>;
403+ using block_op_type = typename block_info::block_op_type;
404+
405+ // Alignment checks of the pointer.
406+ constexpr int ReqAlign =
407+ detail::RequiredAlignment<detail::operation_type::store,
408+ deduced_address_space>::value;
409+ if (!detail::is_aligned<ReqAlign>(out_ptr))
410+ return group_store (g, in, out_ptr, use_naive{});
411+
412+ std::remove_const_t <remove_decoration_t <
413+ typename std::iterator_traits<OutputIteratorT>::value_type>>
414+ values[ElementsPerWorkItem];
415+
416+ for (int i = 0 ; i < ElementsPerWorkItem; ++i) {
417+ // Including implicit conversion.
418+ values[i] = in[i];
419+ }
420+
421+ // We know the pointer is aligned and the address space is known. Do the
422+ // optimized store.
414423 __spirv_SubgroupBlockWriteINTEL (ptr,
415424 sycl::bit_cast<block_op_type>(values));
416425 }
426+ } else {
427+ return group_store (g, in, out_ptr, use_naive{});
417428 }
418429 }
419430}
0 commit comments