diff --git a/clang/runtime/dpct-rt/include/dpct/group_utils.hpp b/clang/runtime/dpct-rt/include/dpct/group_utils.hpp index 545191b59482..1efc70abe945 100644 --- a/clang/runtime/dpct-rt/include/dpct/group_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/group_utils.hpp @@ -708,11 +708,11 @@ class [[deprecated("Please use group_radix_sort instead")]] radix_sort { /// Load linear segment items into block format across threads /// Helper for Block Load -enum load_algorithm { +enum class load_algorithm { + BLOCK_LOAD_DIRECT, BLOCK_LOAD_STRIPED, }; - // loads a linear segment of workgroup items into a blocked arrangement. template @@ -869,6 +869,34 @@ uninitialized_load_subgroup_striped(const Item &item, InputIteratorT block_itr, new (&items[idx]) InputT(block_itr[initial_offset + (idx * subgroup_size)]); } } + +/// Stores a subgroup-striped arrangement of work-item data into a linear segment of items. +// Created as free function until exchange mechanism is +// implemented. +// To-do: inline this function with BLOCK_STORE_WARP_TRANSPOSE mechanism +template +__dpct_inline__ void +store_subgroup_striped(const ItemT &item, OutputIteratorT block_itr, + T (&data)[ElementsPerWorkItem]) { + + // This implementation does not take into account range storing across + // workgroup items. To-do: Decide whether range storing is required for group + // loading + // This implementation stores a subgroup-striped arrangement into a linear segment. 
+ auto sub_group = item.get_sub_group(); + uint32_t subgroup_offset = sub_group.get_local_linear_id(); + uint32_t subgroup_size = sub_group.get_local_linear_range(); + uint32_t subgroup_idx = sub_group.get_group_linear_id(); + uint32_t initial_offset = + (subgroup_idx * ElementsPerWorkItem * subgroup_size) + subgroup_offset; + OutputIteratorT workitem_itr = block_itr + initial_offset; +#pragma unroll + for (uint32_t idx = 0; idx < ElementsPerWorkItem; idx++) { + workitem_itr[(idx * subgroup_size)] = data[idx]; + } +} + // template parameters : // ITEMS_PER_WORK_ITEM: size_t variable controlling the number of items per // thread/work_item @@ -887,9 +915,9 @@ class [[deprecated( __dpct_inline__ void load(const Item &item, InputIteratorT block_itr, InputT (&items)[ITEMS_PER_WORK_ITEM]) { - if constexpr (ALGORITHM == BLOCK_LOAD_DIRECT) { + if constexpr (ALGORITHM == load_algorithm::BLOCK_LOAD_DIRECT) { load_blocked(item, block_itr, items); - } else if constexpr (ALGORITHM == BLOCK_LOAD_STRIPED) { + } else if constexpr (ALGORITHM == load_algorithm::BLOCK_LOAD_STRIPED) { load_striped(item, block_itr, items); } }