@@ -58,6 +58,13 @@ struct naive_key : detail::compile_time_property_key<detail::PropKind::Naive> {
5858 using value_t = property_value<naive_key>;
5959};
6060inline constexpr naive_key::value_t naive;
61+
62+ struct native_local_block_io_key
63+ : detail::compile_time_property_key<detail::PropKind::NativeLocalBlockIO> {
64+ using value_t = property_value<native_local_block_io_key>;
65+ };
66+ inline constexpr native_local_block_io_key::value_t native_local_block_io;
67+
6168using namespace sycl ::detail;
6269} // namespace detail
6370
@@ -154,7 +161,6 @@ template <typename BlockInfoTy> struct BlockTypeInfo;
154161template <typename IteratorT, std::size_t ElementsPerWorkItem, bool Blocked>
155162struct BlockTypeInfo <BlockInfo<IteratorT, ElementsPerWorkItem, Blocked>> {
156163 using BlockInfoTy = BlockInfo<IteratorT, ElementsPerWorkItem, Blocked>;
157- static_assert (BlockInfoTy::has_builtin);
158164
159165 using block_type = detail::fixed_width_unsigned<BlockInfoTy::block_size>;
160166
@@ -163,15 +169,23 @@ struct BlockTypeInfo<BlockInfo<IteratorT, ElementsPerWorkItem, Blocked>> {
163169 typename std::iterator_traits<IteratorT>::reference>>,
164170 std::add_const_t <block_type>, block_type>;
165171
166- using block_pointer_type = typename detail::DecoratedType<
167- block_pointer_elem_type, access::address_space::global_space>::type *;
172+ static constexpr auto deduced_address_space =
173+ detail::deduce_AS<std::remove_cv_t <IteratorT>>::value;
174+
175+ using block_pointer_type =
176+ typename detail::DecoratedType<block_pointer_elem_type,
177+ deduced_address_space>::type *;
178+
168179 using block_op_type = std::conditional_t <
169180 BlockInfoTy::num_blocks == 1 , block_type,
170181 detail::ConvertToOpenCLType_t<vec<block_type, BlockInfoTy::num_blocks>>>;
171182};
172183
173- // Returns either a pointer suitable to use in a block read/write builtin or
174- // nullptr if some legality conditions aren't satisfied.
184+ // Returns either a pointer decorated with the deduced address space, suitable
185+ // to use in a block read/write builtin, or nullptr if some legality conditions
186+ // aren't satisfied. If deduced address space is generic then returned pointer
187+ // will have generic address space and has to be dynamically casted to global or
188+ // local space before using in a builtin.
175189template <int RequiredAlign, std::size_t ElementsPerWorkItem,
176190 typename IteratorT, typename Properties>
177191auto get_block_op_ptr (IteratorT iter, [[maybe_unused]] Properties props) {
@@ -211,16 +225,17 @@ auto get_block_op_ptr(IteratorT iter, [[maybe_unused]] Properties props) {
211225 bool is_aligned = alignof (value_type) >= RequiredAlign ||
212226 reinterpret_cast <uintptr_t >(iter) % RequiredAlign == 0 ;
213227
214- constexpr auto AS = detail::deduce_AS<iter_no_cv>::value;
215228 using block_pointer_type =
216229 typename BlockTypeInfo<BlkInfo>::block_pointer_type;
217- if constexpr (AS == access::address_space::global_space) {
230+
231+ static constexpr auto deduced_address_space =
232+ BlockTypeInfo<BlkInfo>::deduced_address_space;
233+ if constexpr (deduced_address_space ==
234+ access::address_space::generic_space ||
235+ deduced_address_space ==
236+ access::address_space::global_space ||
237+ deduced_address_space == access::address_space::local_space) {
218238 return is_aligned ? reinterpret_cast <block_pointer_type>(iter) : nullptr ;
219- } else if constexpr (AS == access::address_space::generic_space) {
220- return is_aligned ? reinterpret_cast <block_pointer_type>(
221- detail::dynamic_address_cast<
222- access::address_space::global_space>(iter))
223- : nullptr ;
224239 } else {
225240 return nullptr ;
226241 }
@@ -261,11 +276,37 @@ group_load(Group g, InputIteratorT in_ptr,
261276 // Do optimized load.
262277 using value_type = remove_decoration_t <
263278 typename std::iterator_traits<InputIteratorT>::value_type>;
264-
265- auto load = __spirv_SubgroupBlockReadINTEL<
266- typename detail::BlockTypeInfo<detail::BlockInfo<
267- InputIteratorT, ElementsPerWorkItem, blocked>>::block_op_type>(
268- ptr);
279+ using block_info = typename detail::BlockTypeInfo<
280+ detail::BlockInfo<InputIteratorT, ElementsPerWorkItem, blocked>>;
281+ static constexpr auto deduced_address_space =
282+ block_info::deduced_address_space;
283+ using block_op_type = typename block_info::block_op_type;
284+
285+ if constexpr (deduced_address_space ==
286+ access::address_space::local_space &&
287+ !props.template has_property <
288+ detail::native_local_block_io_key>())
289+ return group_load (g, in_ptr, out, use_naive{});
290+
291+ block_op_type load;
292+ if constexpr (deduced_address_space ==
293+ access::address_space::generic_space) {
294+ if (auto local_ptr = detail::dynamic_address_cast<
295+ access::address_space::local_space>(ptr)) {
296+ if constexpr (props.template has_property <
297+ detail::native_local_block_io_key>())
298+ load = __spirv_SubgroupBlockReadINTEL<block_op_type>(local_ptr);
299+ else
300+ return group_load (g, in_ptr, out, use_naive{});
301+ } else if (auto global_ptr = detail::dynamic_address_cast<
302+ access::address_space::global_space>(ptr)) {
303+ load = __spirv_SubgroupBlockReadINTEL<block_op_type>(global_ptr);
304+ } else {
305+ return group_load (g, in_ptr, out, use_naive{});
306+ }
307+ } else {
308+ load = __spirv_SubgroupBlockReadINTEL<block_op_type>(ptr);
309+ }
269310
270311 // TODO: accessor_iterator's value_type is weird, so we need
271312 // `std::remove_const_t` below:
@@ -331,6 +372,16 @@ group_store(Group g, const span<InputT, ElementsPerWorkItem> in,
331372 return group_store (g, in, out_ptr, use_naive{});
332373
333374 if constexpr (!std::is_same_v<std::nullptr_t , decltype (ptr)>) {
375+ using block_info = typename detail::BlockTypeInfo<
376+ detail::BlockInfo<OutputIteratorT, ElementsPerWorkItem, blocked>>;
377+ static constexpr auto deduced_address_space =
378+ block_info::deduced_address_space;
379+ if constexpr (deduced_address_space ==
380+ access::address_space::local_space &&
381+ !props.template has_property <
382+ detail::native_local_block_io_key>())
383+ return group_store (g, in, out_ptr, use_naive{});
384+
334385 // Do optimized store.
335386 std::remove_const_t <remove_decoration_t <
336387 typename std::iterator_traits<OutputIteratorT>::value_type>>
@@ -341,11 +392,28 @@ group_store(Group g, const span<InputT, ElementsPerWorkItem> in,
341392 values[i] = in[i];
342393 }
343394
344- __spirv_SubgroupBlockWriteINTEL (
345- ptr,
346- sycl::bit_cast<typename detail::BlockTypeInfo<detail::BlockInfo<
347- OutputIteratorT, ElementsPerWorkItem, blocked>>::block_op_type>(
348- values));
395+ using block_op_type = typename block_info::block_op_type;
396+ if constexpr (deduced_address_space ==
397+ access::address_space::generic_space) {
398+ if (auto local_ptr = detail::dynamic_address_cast<
399+ access::address_space::local_space>(ptr)) {
400+ if constexpr (props.template has_property <
401+ detail::native_local_block_io_key>())
402+ __spirv_SubgroupBlockWriteINTEL (
403+ local_ptr, sycl::bit_cast<block_op_type>(values));
404+ else
405+ return group_store (g, in, out_ptr, use_naive{});
406+ } else if (auto global_ptr = detail::dynamic_address_cast<
407+ access::address_space::global_space>(ptr)) {
408+ __spirv_SubgroupBlockWriteINTEL (
409+ global_ptr, sycl::bit_cast<block_op_type>(values));
410+ } else {
411+ return group_store (g, in, out_ptr, use_naive{});
412+ }
413+ } else {
414+ __spirv_SubgroupBlockWriteINTEL (ptr,
415+ sycl::bit_cast<block_op_type>(values));
416+ }
349417 }
350418 }
351419}
0 commit comments