@@ -566,6 +566,103 @@ class radix_sort {
   uint8_t *_local_memory;
 };
 
+/// Loads a linear segment of items into a blocked arrangement across threads.
+/// Helper enum for workgroup (block) load.
+enum load_algorithm {
+
+  BLOCK_LOAD_DIRECT,
+  BLOCK_LOAD_STRIPED,
+  // To-do: BLOCK_LOAD_WARP_TRANSPOSE
+
+};
+
+// Loads a linear segment of workgroup items into a blocked arrangement.
+template <size_t ITEMS_PER_WORK_ITEM, typename InputT, typename InputIteratorT,
+          typename Item>
+__dpct_inline__ void load_blocked(const Item &item, InputIteratorT block_itr,
+                                  InputT (&items)[ITEMS_PER_WORK_ITEM]) {
+
+  // This implementation does not take into account range loading across
+  // workgroup items.
+  // To-do: Decide whether range loading is required for group loading.
+  size_t linear_tid = item.get_local_linear_id();
+  uint32_t workgroup_offset = linear_tid * ITEMS_PER_WORK_ITEM;
+#pragma unroll
+  for (size_t idx = 0; idx < ITEMS_PER_WORK_ITEM; idx++) {
+    items[idx] = block_itr[workgroup_offset + idx];
+  }
+}
+
+// Loads a linear segment of workgroup items into a striped arrangement.
+template <size_t ITEMS_PER_WORK_ITEM, typename InputT, typename InputIteratorT,
+          typename Item>
+__dpct_inline__ void load_striped(const Item &item, InputIteratorT block_itr,
+                                  InputT (&items)[ITEMS_PER_WORK_ITEM]) {
+
+  // This implementation does not take into account range loading across
+  // workgroup items.
+  // To-do: Decide whether range loading is required for group loading.
+  size_t linear_tid = item.get_local_linear_id();
+  size_t group_work_items = item.get_local_range().size();
+#pragma unroll
+  for (size_t idx = 0; idx < ITEMS_PER_WORK_ITEM; idx++) {
+    items[idx] = block_itr[linear_tid + (idx * group_work_items)];
+  }
+}
+
+// Loads a linear segment of workgroup items into a subgroup-striped
+// arrangement. Created as a free function until the exchange mechanism is
+// implemented.
+// To-do: inline this function with the BLOCK_LOAD_WARP_TRANSPOSE mechanism.
+template <size_t ITEMS_PER_WORK_ITEM, typename InputT, typename InputIteratorT,
+          typename Item>
+__dpct_inline__ void
+uninitialized_load_subgroup_striped(const Item &item, InputIteratorT block_itr,
+                                    InputT (&items)[ITEMS_PER_WORK_ITEM]) {
+
+  // This implementation does not take into account range loading across
+  // workgroup items.
+  // To-do: Decide whether range loading is required for group loading.
+  // This implementation uses uninitialized memory for loading linear segments
+  // into a subgroup-striped arrangement.
+  uint32_t subgroup_offset = item.get_sub_group().get_local_linear_id();
+  uint32_t subgroup_size = item.get_sub_group().get_local_linear_range();
+  uint32_t subgroup_idx = item.get_sub_group().get_group_linear_id();
+  uint32_t initial_offset =
+      (subgroup_idx * ITEMS_PER_WORK_ITEM * subgroup_size) + subgroup_offset;
+#pragma unroll
+  for (size_t idx = 0; idx < ITEMS_PER_WORK_ITEM; idx++) {
+    new (&items[idx]) InputT(block_itr[initial_offset + (idx * subgroup_size)]);
+  }
+}
+// Template parameters:
+// ITEMS_PER_WORK_ITEM : size_t value controlling the number of items per
+//                       work-item.
+// ALGORITHM : load_algorithm value controlling the type of load operation.
+// InputT : type of the input sequence.
+// InputIteratorT : input iterator type.
+// Item : typename parameter resembling sycl::nd_item<3>.
+template <size_t ITEMS_PER_WORK_ITEM, load_algorithm ALGORITHM, typename InputT,
+          typename InputIteratorT, typename Item>
+class workgroup_load {
+public:
+  static size_t get_local_memory_size(size_t group_work_items) { return 0; }
+  workgroup_load(uint8_t *local_memory) : _local_memory(local_memory) {}
+
+  __dpct_inline__ void load(const Item &item, InputIteratorT block_itr,
+                            InputT (&items)[ITEMS_PER_WORK_ITEM]) {
+
+    if constexpr (ALGORITHM == BLOCK_LOAD_DIRECT) {
+      load_blocked<ITEMS_PER_WORK_ITEM>(item, block_itr, items);
+    } else if constexpr (ALGORITHM == BLOCK_LOAD_STRIPED) {
+      load_striped<ITEMS_PER_WORK_ITEM>(item, block_itr, items);
+    }
+  }
+
+private:
+  uint8_t *_local_memory;
+};
+
 
 /// Perform a reduction of the data elements assigned to all threads in the
 /// group.
 ///
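
The three loaders added above differ only in how a work-item's slot index is mapped to a global offset. Below is a minimal host-side sketch of that arithmetic, illustrative only and not part of the patch: the group and subgroup sizes are made up, and the linear work-item-to-subgroup mapping is an assumption, since the real assignment is implementation-defined.

#include <cstddef>
#include <cstdio>

int main() {
  const size_t group_work_items = 8;    // hypothetical work-items per group
  const size_t items_per_work_item = 4; // ITEMS_PER_WORK_ITEM
  const size_t subgroup_size = 4;       // hypothetical subgroup width

  for (size_t tid = 0; tid < group_work_items; ++tid) {
    printf("work-item %zu:\n", tid);

    // load_blocked: each work-item reads a contiguous run of elements.
    printf("  blocked          :");
    for (size_t i = 0; i < items_per_work_item; ++i)
      printf(" %3zu", tid * items_per_work_item + i);

    // load_striped: consecutive work-items read consecutive elements.
    printf("\n  striped          :");
    for (size_t i = 0; i < items_per_work_item; ++i)
      printf(" %3zu", tid + i * group_work_items);

    // uninitialized_load_subgroup_striped: striping is confined to each
    // subgroup's tile of items_per_work_item * subgroup_size elements.
    size_t sg_idx = tid / subgroup_size; // assumed linear mapping
    size_t sg_off = tid % subgroup_size;
    size_t base = sg_idx * items_per_work_item * subgroup_size + sg_off;
    printf("\n  subgroup-striped :");
    for (size_t i = 0; i < items_per_work_item; ++i)
      printf(" %3zu", base + i * subgroup_size);
    printf("\n");
  }
  return 0;
}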
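
A hypothetical end-to-end use of workgroup_load in a SYCL kernel, loading four items per work-item in a striped arrangement. This is a sketch under stated assumptions, not part of the commit: it assumes the utilities live in the dpct::group namespace alongside radix_sort, that in is a USM device pointer, and that n is a multiple of the tile size; load_example and the launch geometry are made up for illustration.

#include <algorithm>
#include <cstdint>
#include <sycl/sycl.hpp>

namespace dg = dpct::group; // assumed namespace of the new utilities

void load_example(sycl::queue &q, const int *in, size_t n) {
  constexpr size_t ITEMS = 4;   // ITEMS_PER_WORK_ITEM
  constexpr size_t GROUP = 128; // work-items per group
  using loader_t = dg::workgroup_load<ITEMS, dg::BLOCK_LOAD_STRIPED, int,
                                      const int *, sycl::nd_item<3>>;
  // get_local_memory_size() currently returns 0; clamp so the local
  // accessor is never zero-sized.
  size_t scratch_bytes =
      std::max<size_t>(1, loader_t::get_local_memory_size(GROUP));

  q.submit([&](sycl::handler &cgh) {
    sycl::local_accessor<uint8_t, 1> scratch(sycl::range<1>(scratch_bytes),
                                             cgh);
    cgh.parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, n / ITEMS),
                          sycl::range<3>(1, 1, GROUP)),
        [=](sycl::nd_item<3> item) {
          int items[ITEMS]; // per-work-item private destination
          // Start of this group's tile of GROUP * ITEMS input elements.
          const int *block_itr = in + item.get_group(2) * GROUP * ITEMS;
          loader_t(&scratch[0]).load(item, block_itr, items);
          // items[] now holds this work-item's elements, striped across
          // the group: items[i] == block_itr[tid + i * GROUP].
        });
  });
}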