Skip to content

Commit 4f7b179

Browse files
authored
[SYCLLowerIR] Fix a bug in hierarchical parallelism implementation (#20484)
When inside a `parallel_for_work_group` context, calls to functions that contain calls to `parallel_for_work_item` are being lowered incorrectly into IR in that they are being put under the work-group leader branch in the IR which is semantically incorrect as this function should be called in every work item. This manifests when we have an indirect function call to `parallel_for_work_item` together with at least one other direct call to `parallel_for_work_item` in the same `parallel_for_work_group` context and it leads to a program that hangs. This PR fixes the issue and adds a couple of other tests to check this behavior.
1 parent 2f2fb47 commit 4f7b179

File tree

3 files changed

+102
-1
lines changed

3 files changed

+102
-1
lines changed

llvm/lib/SYCLLowerIR/LowerWGScope.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,22 @@ static bool hasCallToAFuncWithWGMetadata(Function &F) {
214214
return false;
215215
}
216216

217+
// Recursively searches for a call to a function with parallel_for_work_item
218+
// metadata inside F.
219+
static bool hasCallToAFuncWithPFWIMetadata(Function &F) {
220+
for (auto &BB : F)
221+
for (auto &I : BB) {
222+
if (isCallToAFuncMarkedWithMD(&I, PFWI_MD))
223+
return true;
224+
const CallInst *Call = dyn_cast<CallInst>(&I);
225+
Function *F = dyn_cast_or_null<Function>(Call ? Call->getCalledFunction()
226+
: nullptr);
227+
if (F && hasCallToAFuncWithPFWIMetadata(*F))
228+
return true;
229+
}
230+
return false;
231+
}
232+
217233
// Checks if this is a call to parallel_for_work_item.
218234
static bool isPFWICall(const Instruction *I) {
219235
return isCallToAFuncMarkedWithMD(I, PFWI_MD);
@@ -835,7 +851,14 @@ PreservedAnalyses SYCLLowerWGScopePass::run(Function &F,
835851
}
836852
continue;
837853
}
838-
if (!mayHaveSideEffects(I))
854+
// In addition to an instruction not having side effects, we end the range
855+
// if the instruction is a call that contains, possibly several layers
856+
// down the stack, a call to a parallel_for_work_item. Such calls should
857+
// not be subject to lowering since they must be executed by every work
858+
// item.
859+
const CallInst *Call = dyn_cast<CallInst>(I);
860+
if (!mayHaveSideEffects(I) ||
861+
(Call && hasCallToAFuncWithPFWIMetadata(*Call->getCalledFunction())))
839862
continue;
840863
LLVM_DEBUG(llvm::dbgs() << "+++ Side effects: " << *I << "\n");
841864
if (!First)

llvm/test/SYCLLowerIR/pfwg_and_pfwi.ll

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
; are properly handled by LowerWGScope pass. Check that WG-shared local "shadow" variables are created
66
; and before each PFWI invocation leader WI stores its private copy of the variable into the shadow,
77
; then all WIs load the shadow value into their private copies ("materialize" the private copy).
8+
; Also check that an indirect call to a function marked with parallel_for_work_item is treated
9+
; the same as a direct call.
810

911
%struct.bar = type { i8 }
1012
%struct.zot = type { %struct.widget, %struct.widget, %struct.widget, %struct.foo }
@@ -54,6 +56,7 @@ define internal spir_func void @wibble(ptr addrspace(4) %arg, ptr byval(%struct.
5456
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrieriii(i32 2, i32 2, i32 272) #[[ATTR0]]
5557
; CHECK-NEXT: [[TMP9:%.*]] = addrspacecast ptr [[ARG1]] to ptr addrspace(4)
5658
; CHECK-NEXT: call spir_func void @bar(ptr addrspace(4) [[TMP9]], ptr byval([[STRUCT_FOO_0]]) align 1 [[TMP1]])
59+
; CHECK-NEXT: call spir_func void @foo(ptr addrspace(4) [[TMP9]], ptr byval([[STRUCT_FOO_0]]) align 1 [[TMP1]])
5760
; CHECK-NEXT: ret void
5861
;
5962
bb:
@@ -62,6 +65,57 @@ bb:
6265
store ptr addrspace(4) %arg, ptr %0, align 8
6366
%2 = addrspacecast ptr %arg1 to ptr addrspace(4)
6467
call spir_func void @bar(ptr addrspace(4) %2, ptr byval(%struct.foo.0) align 1 %1)
68+
call spir_func void @foo(ptr addrspace(4) %2, ptr byval(%struct.foo.0) align 1 %1)
69+
ret void
70+
}
71+
72+
define internal spir_func void @foo(ptr addrspace(4) %arg, ptr byval(%struct.foo.0) align 1 %arg1) align 2 !work_group_scope !0 {
73+
; CHECK: bb:
74+
; CHECK-NEXT: [[TMP0:%.*]] = alloca ptr addrspace(4), align 8
75+
; CHECK-NEXT: [[TMP1:%.*]] = alloca [[STRUCT_FOO_0:%.*]], align 1
76+
; CHECK-NEXT: [[TMP2:%.*]] = load i64, ptr addrspace(1) @__spirv_BuiltInLocalInvocationIndex, align 4
77+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrieriii(i32 2, i32 2, i32 272) #[[ATTR0]]
78+
; CHECK-NEXT: [[CMPZ3:%.*]] = icmp eq i64 [[TMP2]], 0
79+
; CHECK-NEXT: br i1 [[CMPZ3]], label [[LEADER:%.*]], label [[MERGE:%.*]]
80+
; CHECK: leader:
81+
; CHECK-NEXT: call void @llvm.memcpy.p3.p0.i64(ptr addrspace(3) align 8 @ArgShadow.4, ptr align 1 [[ARG1:%.*]], i64 1, i1 false)
82+
; CHECK-NEXT: br label [[MERGE]]
83+
; CHECK: merge:
84+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrieriii(i32 2, i32 2, i32 272) #[[ATTR0]]
85+
; CHECK-NEXT: call void @llvm.memcpy.p0.p3.i64(ptr align 1 [[ARG1]], ptr addrspace(3) align 8 @ArgShadow.4, i64 1, i1 false)
86+
; CHECK-NEXT: [[TMP3:%.*]] = load i64, ptr addrspace(1) @__spirv_BuiltInLocalInvocationIndex, align 4
87+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrieriii(i32 2, i32 2, i32 272) #[[ATTR0]]
88+
; CHECK-NEXT: [[CMPZ:%.*]] = icmp eq i64 [[TMP3]], 0
89+
; CHECK-NEXT: br i1 [[CMPZ]], label [[WG_LEADER:%.*]], label [[WG_CF:%.*]]
90+
; CHECK: wg_leader:
91+
; CHECK-NEXT: store ptr addrspace(4) [[ARG:%.*]], ptr [[TMP0]], align 8
92+
; CHECK-NEXT: br label [[WG_CF]]
93+
; CHECK: wg_cf:
94+
; CHECK-NEXT: [[TMP4:%.*]] = load i64, ptr addrspace(1) @__spirv_BuiltInLocalInvocationIndex, align 4
95+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrieriii(i32 2, i32 2, i32 272) #[[ATTR0]]
96+
; CHECK-NEXT: [[CMPZ2:%.*]] = icmp eq i64 [[TMP4]], 0
97+
; CHECK-NEXT: br i1 [[CMPZ2]], label [[TESTMAT:%.*]], label [[LEADERMAT:%.*]]
98+
; CHECK: TestMat:
99+
; CHECK-NEXT: call void @llvm.memcpy.p3.p0.i64(ptr addrspace(3) align 8 @WGCopy.3, ptr align 1 [[TMP1]], i64 1, i1 false)
100+
; CHECK-NEXT: [[MAT_LD:%.*]] = load ptr addrspace(4), ptr [[TMP0]], align 8
101+
; CHECK-NEXT: store ptr addrspace(4) [[MAT_LD]], ptr addrspace(3) @WGCopy.2, align 8
102+
; CHECK-NEXT: br label [[LEADERMAT]]
103+
; CHECK: LeaderMat:
104+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrieriii(i32 2, i32 2, i32 272) #[[ATTR0]]
105+
; CHECK-NEXT: [[MAT_LD1:%.*]] = load ptr addrspace(4), ptr addrspace(3) @WGCopy.2, align 8
106+
; CHECK-NEXT: store ptr addrspace(4) [[MAT_LD1]], ptr [[TMP0]], align 8
107+
; CHECK-NEXT: call void @llvm.memcpy.p0.p3.i64(ptr align 1 [[TMP1]], ptr addrspace(3) align 8 @WGCopy.3, i64 1, i1 false)
108+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrieriii(i32 2, i32 2, i32 272) #[[ATTR0]]
109+
; CHECK-NEXT: [[TMP5:%.*]] = addrspacecast ptr [[ARG1]] to ptr addrspace(4)
110+
; CHECK-NEXT: call spir_func void @bar(ptr addrspace(4) [[TMP5]], ptr byval([[STRUCT_FOO_0]]) align 1 [[TMP1]])
111+
; CHECK-NEXT: ret void
112+
;
113+
bb:
114+
%1 = alloca ptr addrspace(4), align 8
115+
%2 = alloca %struct.foo.0, align 1
116+
store ptr addrspace(4) %arg, ptr %1, align 8
117+
%3 = addrspacecast ptr %arg1 to ptr addrspace(4)
118+
call spir_func void @bar(ptr addrspace(4) %3, ptr byval(%struct.foo.0) align 1 %2)
65119
ret void
66120
}
67121

sycl/test-e2e/HierPar/hier_par_indirect.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,36 @@ void __attribute__((noinline)) foo(sycl::group<1> work_group) {
1919
work_group.parallel_for_work_item([&](sycl::h_item<1> index) {});
2020
}
2121

22+
void __attribute__((noinline)) bar(sycl::group<1> work_group) {
23+
foo(work_group);
24+
}
25+
2226
int main(int argc, char **argv) {
2327
sycl::queue q;
28+
29+
// Try a single indirect call, two indirect calls and an indirect call
30+
// accompanied by multiple parallel_for_work_item calls in the same work_group
31+
// scope.
2432
q.submit([&](sycl::handler &cgh) {
2533
cgh.parallel_for_work_group(sycl::range<1>{1}, sycl::range<1>{128},
2634
([=](sycl::group<1> wGroup) { foo(wGroup); }));
2735
}).wait();
36+
q.submit([&](sycl::handler &cgh) {
37+
cgh.parallel_for_work_group(
38+
sycl::range<1>{1}, sycl::range<1>{128}, ([=](sycl::group<1> wGroup) {
39+
foo(wGroup); // 1-layer indirect call
40+
bar(wGroup); // 2-layer indirect call since bar calls foo
41+
}));
42+
}).wait();
43+
q.submit([&](sycl::handler &cgh) {
44+
cgh.parallel_for_work_group(
45+
sycl::range<1>{1}, sycl::range<1>{128}, ([=](sycl::group<1> wGroup) {
46+
wGroup.parallel_for_work_item([&](sycl::h_item<1> index) {});
47+
foo(wGroup);
48+
wGroup.parallel_for_work_item([&](sycl::h_item<1> index) {});
49+
}));
50+
}).wait();
51+
2852
std::cout << "test passed" << std::endl;
2953
return 0;
3054
}

0 commit comments

Comments
 (0)