Skip to content

Commit f7db283

Browse files
[SYCL] Allow work group scratch memory to be used with free function kernels (intel#19978)
This is a cherry-pick of intel#19837 When a kernel uses implicit local memory such as by way of the `get_work_group_scratch_memory` function, the library is supposed to mark the kernel with the appropriate attribute `WORK_GROUP_STATIC_ATTR` to get things to work at runtime. This is done through the properties passed to the kernel invocation call. For free function kernels however, the infrastructure is not there to do this marking process and usage of the above mentioned function typically results in a UR error. This PR makes some changes at the middle-end level to traverse the call graph wherever the compiler built-in functions `__sycl_allocateLocalMemory` and `__sycl_dynamicLocalMemoryPlaceholder` are used and mark each of the kernels found during this traversal , including free function kernels, with the `WORK_GROUP_STATIC_ATTR` attribute if not already present. Patch-by: Lorenc Bushi <[email protected]>
1 parent 5a32796 commit f7db283

File tree

3 files changed

+163
-0
lines changed

3 files changed

+163
-0
lines changed

llvm/lib/SYCLLowerIR/LowerWGLocalMemory.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,40 @@ lowerDynamicLocalMemCallDirect(CallInst *CI, Triple TT,
184184

185185
static void lowerLocalMemCall(Function *LocalMemAllocFunc,
186186
std::function<void(CallInst *CI)> TransformCall) {
187+
static SmallPtrSet<Function *, 16> FuncsCache;
187188
SmallVector<CallInst *, 4> DelCalls;
188189
for (User *U : LocalMemAllocFunc->users()) {
189190
auto *CI = cast<CallInst>(U);
190191
TransformCall(CI);
191192
DelCalls.push_back(CI);
193+
// Now, take each kernel that calls the builtins that allocate local memory,
194+
// either directly or through a series of function calls that eventually end
195+
// up in a direct call to the builtin, and attach the
196+
// work-group-memory-static attribute to the kernel if not already attached.
197+
// This is needed because free function kernels do not have the attribute
198+
// added by the library as is the case with other types of kernels.
199+
if (!FuncsCache.insert(CI->getFunction()).second)
200+
continue; // We have already traversed call graph from this function.
201+
202+
SmallVector<Function *, 8> WorkList;
203+
WorkList.push_back(CI->getFunction());
204+
while (!WorkList.empty()) {
205+
Function *F = WorkList.back();
206+
WorkList.pop_back();
207+
208+
// Mark kernel as using scratch memory if it isn't marked already.
209+
if (F->getCallingConv() == CallingConv::SPIR_KERNEL &&
210+
!F->hasFnAttribute(WORK_GROUP_STATIC_ATTR))
211+
F->addFnAttr(WORK_GROUP_STATIC_ATTR);
212+
213+
for (auto *FU : F->users()) {
214+
if (auto *UCI = dyn_cast<CallInst>(FU)) {
215+
if (FuncsCache.insert(UCI->getFunction()).second)
216+
WorkList.push_back(UCI->getFunction());
217+
} // Even though there could be other uses of a Function, we don't
218+
// care about them because we are only concerned about call graph.
219+
}
220+
}
192221
}
193222

194223
for (auto *CI : DelCalls) {

llvm/test/SYCLLowerIR/work_group_static.ll

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,29 @@ entry:
2222
ret void
2323
}
2424

25+
; Function Attrs: convergent norecurse
26+
; CHECK: @__sycl_kernel_B{{.*}} #[[ATTRS:[0-9]+]]
27+
define weak_odr dso_local spir_kernel void @__sycl_kernel_B(ptr addrspace(1) %0) local_unnamed_addr #1 !kernel_arg_addr_space !5 {
28+
entry:
29+
%1 = tail call spir_func ptr addrspace(3) @__sycl_dynamicLocalMemoryPlaceholder(i64 128) #1
30+
ret void
31+
}
32+
33+
; Function Attrs: convergent norecurse
34+
; CHECK: @__sycl_kernel_C{{.*}} #[[ATTRS]]
35+
define weak_odr dso_local spir_kernel void @__sycl_kernel_C(ptr addrspace(1) %0) local_unnamed_addr #1 !kernel_arg_addr_space !5 {
36+
entry:
37+
%1 = tail call spir_func ptr addrspace(3) @__sycl_allocateLocalMemory(i64 128, i64 4) #1
38+
ret void
39+
}
40+
41+
; Function Attrs: convergent
42+
declare dso_local spir_func ptr addrspace(3) @__sycl_allocateLocalMemory(i64, i64) local_unnamed_addr #1
43+
2544
; Function Attrs: convergent
2645
declare dso_local spir_func ptr addrspace(3) @__sycl_dynamicLocalMemoryPlaceholder(i64) local_unnamed_addr #1
2746

47+
; CHECK: #[[ATTRS]] = {{.*}} "sycl-work-group-static"
2848
attributes #0 = { convergent norecurse "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "uniform-work-group-size"="true" "unsafe-fp-math"="false" "use-soft-float"="false" "sycl-work-group-static"="1" }
2949
attributes #1 = { convergent norecurse }
3050

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
// REQUIRES: aspect-usm_shared_allocations
2+
// UNSUPPORTED: target-amd
3+
// UNSUPPORTED-TRACKER: https://github.com/intel/llvm/issues/16072
4+
5+
// RUN: %{build} -o %t.out
6+
// RUN: %{run} %t.out
7+
8+
// This test verifies that we can compile, run and get correct results when
9+
// using a free function kernel that allocates shared local memory in a kernel
10+
// either by way of the work group scratch memory extension or the work group
11+
// static memory extension.
12+
13+
#include "helpers.hpp"
14+
15+
#include <cassert>
16+
#include <sycl/ext/oneapi/experimental/enqueue_functions.hpp>
17+
#include <sycl/ext/oneapi/free_function_queries.hpp>
18+
#include <sycl/ext/oneapi/work_group_static.hpp>
19+
#include <sycl/group_barrier.hpp>
20+
#include <sycl/usm.hpp>
21+
22+
namespace syclext = sycl::ext::oneapi;
23+
namespace syclexp = sycl::ext::oneapi::experimental;
24+
25+
constexpr int SIZE = 16;
26+
27+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
28+
void scratchKernel(float *Src, float *Dst) {
29+
size_t Lid = syclext::this_work_item::get_nd_item<1>().get_local_linear_id();
30+
float *LocalMem =
31+
reinterpret_cast<float *>(syclexp::get_work_group_scratch_memory());
32+
LocalMem[Lid] = 2 * Src[Lid];
33+
Dst[Lid] = LocalMem[Lid];
34+
}
35+
36+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
37+
void staticKernel(float *Src, float *Dst) {
38+
sycl::nd_item<1> Item = syclext::this_work_item::get_nd_item<1>();
39+
size_t Lid = Item.get_local_linear_id();
40+
syclexp::work_group_static<float[SIZE]> LocalMem;
41+
LocalMem[Lid] = Src[Lid] * Src[Lid];
42+
sycl::group_barrier(Item.get_group());
43+
if (Item.get_group().leader()) { // Check that memory is indeed shared between
44+
// the work group.
45+
for (int I = 0; I < SIZE; ++I)
46+
assert(LocalMem[I] == Src[I] * Src[I]);
47+
}
48+
Dst[Lid] = LocalMem[Lid];
49+
}
50+
51+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
52+
void scratchStaticKernel(float *Src, float *Dst) {
53+
size_t Lid = syclext::this_work_item::get_nd_item<1>().get_local_linear_id();
54+
float *ScratchMem =
55+
reinterpret_cast<float *>(syclexp::get_work_group_scratch_memory());
56+
syclexp::work_group_static<float[SIZE]> StaticMem;
57+
ScratchMem[Lid] = Src[Lid];
58+
StaticMem[Lid] = Src[Lid];
59+
Dst[Lid] = ScratchMem[Lid] + StaticMem[Lid];
60+
}
61+
62+
int main() {
63+
sycl::queue Q;
64+
float *Src = sycl::malloc_shared<float>(SIZE, Q);
65+
float *Dst = sycl::malloc_shared<float>(SIZE, Q);
66+
67+
for (int I = 0; I < SIZE; I++) {
68+
Src[I] = I;
69+
}
70+
71+
auto ScratchBndl =
72+
syclexp::get_kernel_bundle<scratchKernel, sycl::bundle_state::executable>(
73+
Q.get_context());
74+
auto StaticBndl =
75+
syclexp::get_kernel_bundle<staticKernel, sycl::bundle_state::executable>(
76+
Q.get_context());
77+
auto ScratchStaticBndl = syclexp::get_kernel_bundle<
78+
scratchStaticKernel, sycl::bundle_state::executable>(Q.get_context());
79+
80+
sycl::kernel ScratchKrn =
81+
ScratchBndl.template ext_oneapi_get_kernel<scratchKernel>();
82+
sycl::kernel StaticKrn =
83+
StaticBndl.template ext_oneapi_get_kernel<staticKernel>();
84+
sycl::kernel ScratchStaticKrn =
85+
ScratchStaticBndl.template ext_oneapi_get_kernel<scratchStaticKernel>();
86+
syclexp::launch_config ScratchKernelcfg{
87+
::sycl::nd_range<1>(::sycl::range<1>(SIZE), ::sycl::range<1>(SIZE)),
88+
syclexp::properties{
89+
syclexp::work_group_scratch_size(SIZE * sizeof(float))}};
90+
syclexp::launch_config StaticKernelcfg{
91+
::sycl::nd_range<1>(::sycl::range<1>(SIZE), ::sycl::range<1>(SIZE))};
92+
93+
syclexp::nd_launch(Q, ScratchKernelcfg, ScratchKrn, Src, Dst);
94+
Q.wait();
95+
for (int I = 0; I < SIZE; I++) {
96+
assert(Dst[I] == 2 * Src[I]);
97+
}
98+
99+
syclexp::nd_launch(Q, StaticKernelcfg, StaticKrn, Src, Dst);
100+
Q.wait();
101+
for (int I = 0; I < SIZE; I++) {
102+
assert(Dst[I] == Src[I] * Src[I]);
103+
}
104+
105+
syclexp::nd_launch(Q, ScratchKernelcfg, ScratchStaticKrn, Src, Dst);
106+
Q.wait();
107+
for (int I = 0; I < SIZE; I++) {
108+
assert(Dst[I] == 2 * Src[I]);
109+
}
110+
111+
sycl::free(Src, Q);
112+
sycl::free(Dst, Q);
113+
return 0;
114+
}

0 commit comments

Comments
 (0)