Skip to content

Commit 41bd229

Browse files
author
Yihan Wang
authored
[SYCLomatic] Refine the implementation of dpct::group::load_direct_{blocked, striped} (#2248)
Signed-off-by: Wang, Yihan <[email protected]>
1 parent cfe9b51 commit 41bd229

File tree

3 files changed

+98
-33
lines changed

3 files changed

+98
-33
lines changed

clang/lib/DPCT/Rewriters/CUB/RewriterUtilityFunctions.cpp

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,53 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "AnalysisInfo.h"
910
#include "CallExprRewriterCUB.h"
1011
#include "CallExprRewriterCommon.h"
1112

1213
using namespace clang::dpct;
1314

15+
namespace {
16+
class PrettyTemplatedFunctionNamePrinter {
17+
std::string Name;
18+
std::vector<TemplateArgumentInfo> Args;
19+
20+
public:
21+
PrettyTemplatedFunctionNamePrinter(StringRef Name,
22+
std::vector<TemplateArgumentInfo> &&Args)
23+
: Name(Name.str()), Args(std::move(Args)) {}
24+
template <class StreamT> void print(StreamT &Stream) const {
25+
dpct::print(Stream, Name);
26+
if (!Args.empty()) {
27+
Stream << '<';
28+
ArgsPrinter<false, std::vector<TemplateArgumentInfo>>(Args).print(Stream);
29+
Stream << '>';
30+
}
31+
}
32+
};
33+
34+
std::function<PrettyTemplatedFunctionNamePrinter(const CallExpr *)>
35+
makePrettyTemplatedCalleeCreator(std::string CalleeName,
36+
std::vector<size_t> Indexes) {
37+
return PrinterCreator<
38+
PrettyTemplatedFunctionNamePrinter, std::string,
39+
std::function<std::vector<TemplateArgumentInfo>(const CallExpr *)>>(
40+
CalleeName, [=](const CallExpr *C) -> std::vector<TemplateArgumentInfo> {
41+
std::vector<TemplateArgumentInfo> Ret;
42+
auto List = getTemplateArgsList(C);
43+
for (auto Idx : Indexes) {
44+
if (Idx < List.size()) {
45+
Ret.emplace_back(List[Idx]);
46+
}
47+
}
48+
return Ret;
49+
});
50+
}
51+
} // namespace
52+
53+
#define PRETTY_TEMPLATED_CALLEE(FuncName, ...) \
54+
makePrettyTemplatedCalleeCreator(FuncName, {__VA_ARGS__})
55+
1456
RewriterMap dpct::createUtilityFunctionsRewriterMap() {
1557
return RewriterMap{
1658
// cub::IADD3
@@ -120,13 +162,17 @@ RewriterMap dpct::createUtilityFunctionsRewriterMap() {
120162
HeaderType::HT_DPCT_GROUP_Utils,
121163
CALL_FACTORY_ENTRY(
122164
"cub::LoadDirectBlocked",
123-
CALL(MapNames::getDpctNamespace() + "group::load_direct_blocked",
124-
NDITEM, ARG(1), ARG(2))))
165+
CALL(PRETTY_TEMPLATED_CALLEE(MapNames::getDpctNamespace() +
166+
"group::load_direct_blocked",
167+
0, 1, 2),
168+
ARG(0), ARG(1), ARG(2))))
125169
// cub::LoadDirectStriped
126170
HEADER_INSERT_FACTORY(
127171
HeaderType::HT_DPCT_GROUP_Utils,
128172
CALL_FACTORY_ENTRY(
129173
"cub::LoadDirectStriped",
130-
CALL(MapNames::getDpctNamespace() + "group::load_direct_striped",
131-
NDITEM, ARG(1), ARG(2))))};
174+
CALL(PRETTY_TEMPLATED_CALLEE(MapNames::getDpctNamespace() +
175+
"group::load_direct_striped",
176+
0, 1, 2, 3),
177+
ARG(0), ARG(1), ARG(2))))};
132178
}

clang/runtime/dpct-rt/include/dpct/group_utils.hpp

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef __DPCT_GROUP_UTILS_HPP__
1010
#define __DPCT_GROUP_UTILS_HPP__
1111

12+
#include <iterator>
1213
#include <stdexcept>
1314
#include <sycl/sycl.hpp>
1415

@@ -476,41 +477,59 @@ __dpct_inline__ void load_striped(const Item &item, InputIteratorT block_itr,
476477
}
477478
}
478479

479-
// loads a linear segment of workgroup items into a blocked arrangement.
480-
template <typename InputT, size_t ITEMS_PER_WORK_ITEM, typename InputIteratorT,
481-
typename Item>
482-
__dpct_inline__ void load_direct_blocked(const Item &item, InputIteratorT block_itr,
483-
InputT (&items)[ITEMS_PER_WORK_ITEM]) {
484-
485-
// This implementation does not take in account range loading across
486-
// workgroup items To-do: Decide whether range loading is required for group
487-
// loading
488-
size_t linear_tid = item.get_local_linear_id();
489-
uint32_t workgroup_offset = linear_tid * ITEMS_PER_WORK_ITEM;
480+
/// Load a linear segment of elements into a blocked arrangement across the
481+
/// work-group.
482+
///
483+
/// \tparam InputT The data type to load.
484+
///
485+
/// \tparam ElementsPerWorkItem The number of consecutive elements partitioned
486+
/// onto each work-item.
487+
///
488+
/// \tparam InputIteratorT The random-access iterator type for the input.
489+
///
490+
/// \param linear_tid A suitable linear identifier for the calling work-item.
491+
///
492+
/// \param block_itr The work-group's base input iterator for loading from.
493+
///
494+
/// \param items The array that receives the loaded elements.
495+
template <typename InputT, size_t ElementsPerWorkItem, typename InputIteratorT>
496+
__dpct_inline__ void load_direct_blocked(size_t linear_tid,
497+
InputIteratorT block_itr,
498+
InputT (&items)[ElementsPerWorkItem]) {
490499
#pragma unroll
491-
for (size_t idx = 0; idx < ITEMS_PER_WORK_ITEM; idx++) {
492-
items[idx] = block_itr[workgroup_offset + idx];
500+
for (size_t i = 0; i < ElementsPerWorkItem; i++) {
501+
items[i] = block_itr[(linear_tid * ElementsPerWorkItem) + i];
493502
}
494503
}
495504

496-
// loads a linear segment of workgroup items into a striped arrangement.
497-
template <typename InputT, size_t ITEMS_PER_WORK_ITEM, typename InputIteratorT,
498-
typename Item>
499-
__dpct_inline__ void load_direct_striped(const Item &item, InputIteratorT block_itr,
500-
InputT (&items)[ITEMS_PER_WORK_ITEM]) {
501-
502-
// This implementation does not take in account range loading across
503-
// workgroup items To-do: Decide whether range loading is required for group
504-
// loading
505-
size_t linear_tid = item.get_local_linear_id();
506-
size_t group_work_items = item.get_local_range().size();
505+
/// Load a linear segment of elements into a striped arrangement across the
506+
/// work-group.
507+
///
508+
/// \tparam WorkGroupSize The work-group size.
509+
///
510+
/// \tparam InputT The data type to load.
511+
///
512+
/// \tparam ElementsPerWorkItem The number of elements assigned to each
513+
/// work-item; successive elements are WorkGroupSize apart in the input.
514+
///
515+
/// \tparam InputIteratorT The random-access iterator type for the input.
516+
///
517+
/// \param linear_tid A suitable linear identifier for the calling work-item.
518+
///
519+
/// \param block_itr The work-group's base input iterator for loading from.
520+
///
521+
/// \param items The array that receives the loaded elements.
522+
template <size_t WorkGroupSize, typename InputT, int ElementsPerWorkItem,
523+
typename InputIteratorT>
524+
__dpct_inline__ void load_direct_striped(size_t linear_tid,
525+
InputIteratorT block_itr,
526+
InputT (&items)[ElementsPerWorkItem]) {
507527
#pragma unroll
508-
for (size_t idx = 0; idx < ITEMS_PER_WORK_ITEM; idx++) {
509-
items[idx] = block_itr[linear_tid + (idx * group_work_items)];
528+
for (size_t i = 0; i < ElementsPerWorkItem; i++) {
529+
items[i] = block_itr[linear_tid + i * WorkGroupSize];
510530
}
511531
}
512532

513-
514533
// loads a linear segment of workgroup items into a subgroup striped
515534
// arrangement. Created as free function until exchange mechanism is
516535
// implemented.

clang/test/dpct/cub/intrinsic/load.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111

1212
__global__ void TestLoadStriped(int *d_data) {
1313
int thread_data[4];
14-
// CHECK: dpct::group::load_direct_striped(item_ct1, d_data, thread_data);
14+
// CHECK: dpct::group::load_direct_striped<128>(item_ct1.get_local_id(2), d_data, thread_data);
1515
cub::LoadDirectStriped<128>(threadIdx.x, d_data, thread_data);
1616
}
1717

1818
__global__ void BlockedToStripedKernel(int *d_data) {
1919
int thread_data[4];
20-
// CHECK: dpct::group::load_direct_blocked(item_ct1, d_data, thread_data);
20+
// CHECK: dpct::group::load_direct_blocked(item_ct1.get_local_id(2), d_data, thread_data);
2121
cub::LoadDirectBlocked(threadIdx.x, d_data, thread_data);
2222
}

0 commit comments

Comments
 (0)