@@ -566,6 +566,103 @@ class radix_sort {
566566 uint8_t *_local_memory;
567567};
568568
/// Load linear segment items into block format across threads
/// Helper for Block Load
/// Selects how workgroup_load distributes a linear segment of items across
/// the work-items of a group.
enum load_algorithm {

  BLOCK_LOAD_DIRECT,  // each work-item loads a contiguous run of items
  BLOCK_LOAD_STRIPED, // consecutive items go to consecutive work-items
  // To-do: BLOCK_LOAD_WARP_TRANSPOSE

};
578+
579+ // loads a linear segment of workgroup items into a blocked arrangement.
580+ template <size_t ITEMS_PER_WORK_ITEM, typename InputT, typename InputIteratorT,
581+ typename Item>
582+ __dpct_inline__ void load_blocked (const Item &item, InputIteratorT block_itr,
583+ InputT (&items)[ITEMS_PER_WORK_ITEM]) {
584+
585+ // This implementation does not take in account range loading across
586+ // workgroup items To-do: Decide whether range loading is required for group
587+ // loading
588+ size_t linear_tid = item.get_local_linear_id ();
589+ uint32_t workgroup_offset = linear_tid * ITEMS_PER_WORK_ITEM;
590+ #pragma unroll
591+ for (size_t idx = 0 ; idx < ITEMS_PER_WORK_ITEM; idx++) {
592+ items[idx] = block_itr[workgroup_offset + idx];
593+ }
594+ }
595+
596+ // loads a linear segment of workgroup items into a striped arrangement.
597+ template <size_t ITEMS_PER_WORK_ITEM, typename InputT, typename InputIteratorT,
598+ typename Item>
599+ __dpct_inline__ void load_striped (const Item &item, InputIteratorT block_itr,
600+ InputT (&items)[ITEMS_PER_WORK_ITEM]) {
601+
602+ // This implementation does not take in account range loading across
603+ // workgroup items To-do: Decide whether range loading is required for group
604+ // loading
605+ size_t linear_tid = item.get_local_linear_id ();
606+ size_t group_work_items = item.get_local_range ().size ();
607+ #pragma unroll
608+ for (size_t idx = 0 ; idx < ITEMS_PER_WORK_ITEM; idx++) {
609+ items[idx] = block_itr[linear_tid + (idx * group_work_items)];
610+ }
611+ }
612+
613+ // loads a linear segment of workgroup items into a subgroup striped
614+ // arrangement. Created as free function until exchange mechanism is
615+ // implemented.
616+ // To-do: inline this function with BLOCK_LOAD_WARP_TRANSPOSE mechanism
617+ template <size_t ITEMS_PER_WORK_ITEM, typename InputT, typename InputIteratorT,
618+ typename Item>
619+ __dpct_inline__ void
620+ uninitialized_load_subgroup_striped (const Item &item, InputIteratorT block_itr,
621+ InputT (&items)[ITEMS_PER_WORK_ITEM]) {
622+
623+ // This implementation does not take in account range loading across
624+ // workgroup items To-do: Decide whether range loading is required for group
625+ // loading
626+ // This implementation uses unintialized memory for loading linear segments
627+ // into warp striped arrangement.
628+ uint32_t subgroup_offset = item.get_sub_group ().get_local_linear_id ();
629+ uint32_t subgroup_size = item.get_sub_group ().get_local_linear_range ();
630+ uint32_t subgroup_idx = item.get_sub_group ().get_group_linear_id ();
631+ uint32_t initial_offset =
632+ (subgroup_idx * ITEMS_PER_WORK_ITEM * subgroup_size) + subgroup_offset;
633+ #pragma unroll
634+ for (size_t idx = 0 ; idx < ITEMS_PER_WORK_ITEM; idx++) {
635+ new (&items[idx]) InputT (block_itr[initial_offset + (idx * subgroup_size)]);
636+ }
637+ }
638+ // template parameters :
639+ // ITEMS_PER_WORK_ITEM: size_t variable controlling the number of items per
640+ // thread/work_item
641+ // ALGORITHM: load_algorithm variable controlling the type of load operation.
642+ // InputT: type for input sequence.
643+ // InputIteratorT: input iterator type
644+ // Item : typename parameter resembling sycl::nd_item<3> .
645+ template <size_t ITEMS_PER_WORK_ITEM, load_algorithm ALGORITHM, typename InputT,
646+ typename InputIteratorT, typename Item>
647+ class workgroup_load {
648+ public:
649+ static size_t get_local_memory_size (size_t group_work_items) { return 0 ; }
650+ workgroup_load (uint8_t *local_memory) : _local_memory(local_memory) {}
651+
652+ __dpct_inline__ void load (const Item &item, InputIteratorT block_itr,
653+ InputT (&items)[ITEMS_PER_WORK_ITEM]) {
654+
655+ if constexpr (ALGORITHM == BLOCK_LOAD_DIRECT) {
656+ load_blocked<ITEMS_PER_WORK_ITEM>(item, block_itr, items);
657+ } else if constexpr (ALGORITHM == BLOCK_LOAD_STRIPED) {
658+ load_striped<ITEMS_PER_WORK_ITEM>(item, block_itr, items);
659+ }
660+ }
661+
662+ private:
663+ uint8_t *_local_memory;
664+ };
665+
569666// / Perform a reduction of the data elements assigned to all threads in the
570667// / group.
571668// /
0 commit comments