Skip to content

Commit d8ac7ce

Browse files
committed
[SPIR-V] Add llvm.loop.unroll metadata lowering
llvm.loop.unroll.enable lowers to the Unroll LoopControl; .disable lowers to the DontUnroll LoopControl; .count lowers to the PartialCount LoopControl; .full lowers to the Unroll + PartialCount LoopControls. TODO: enable the structurizer for non-Vulkan targets. Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent 1b455df commit d8ac7ce

File tree

4 files changed

+268
-3
lines changed

4 files changed

+268
-3
lines changed

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2985,10 +2985,11 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
29852985
case Intrinsic::spv_loop_merge: {
29862986
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLoopMerge));
29872987
for (unsigned i = 1; i < I.getNumExplicitOperands(); ++i) {
2988-
assert(I.getOperand(i).isMBB());
2989-
MIB.addMBB(I.getOperand(i).getMBB());
2988+
if (I.getOperand(i).isMBB())
2989+
MIB.addMBB(I.getOperand(i).getMBB());
2990+
else
2991+
MIB.addImm(foldImm(I.getOperand(i), MRI));
29902992
}
2991-
MIB.addImm(SPIRV::SelectionControl::None);
29922993
return MIB.constrainAllUses(TII, TRI, RBI);
29932994
}
29942995
case Intrinsic::spv_selection_merge: {

llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,40 @@ class SPIRVStructurizer : public FunctionPass {
611611
auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
612612
auto ContinueAddress = BlockAddress::get(Continue->getParent(), Continue);
613613
SmallVector<Value *, 2> Args = {MergeAddress, ContinueAddress};
614+
unsigned LC = SPIRV::LoopControl::None;
615+
// Currently used only to store the PartialCount value. When other
616+
// LoopControls are added, this list must be sorted before its values
617+
// become loop_merge operands, to satisfy SPIR-V spec 3.23 "Loop Control".
618+
std::vector<std::pair<unsigned, unsigned>> MaskToValueMap;
619+
if (getBooleanLoopAttribute(L, "llvm.loop.unroll.disable")) {
620+
LC |= SPIRV::LoopControl::DontUnroll;
621+
} else {
622+
if (getBooleanLoopAttribute(L, "llvm.loop.unroll.enable")) {
623+
LC |= SPIRV::LoopControl::Unroll;
624+
}
625+
std::optional<int> Count =
626+
getOptionalIntLoopAttribute(L, "llvm.loop.unroll.count");
627+
if (Count && Count != 1) {
628+
LC |= SPIRV::LoopControl::PartialCount;
629+
MaskToValueMap.emplace_back(
630+
std::make_pair(SPIRV::LoopControl::PartialCount, *Count));
631+
}
632+
if (getBooleanLoopAttribute(L, "llvm.loop.unroll.full")) {
633+
// llvm.loop.unroll.full has no direct counterpart in SPIR-V. The closest
634+
// approximation is to set the Unroll mask and, for the case where the trip
635+
// count is not known at compile time, either disable unrolling by setting
636+
// PartialCount to 1 or reuse an already available PartialCount.
637+
LC |= SPIRV::LoopControl::Unroll;
638+
if ((LC & SPIRV::LoopControl::PartialCount) == 0) {
639+
LC |= SPIRV::LoopControl::PartialCount;
640+
MaskToValueMap.emplace_back(
641+
std::make_pair(SPIRV::LoopControl::PartialCount, 1));
642+
}
643+
}
644+
}
645+
Args.emplace_back(llvm::ConstantInt::get(Builder.getInt32Ty(), LC));
646+
for (auto &[Mask, Val] : MaskToValueMap)
647+
Args.emplace_back(llvm::ConstantInt::get(Builder.getInt32Ty(), Val));
614648

615649
Builder.CreateIntrinsic(Intrinsic::spv_loop_merge, {}, {Args});
616650
Modified = true;

llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ llvm::SplitKnownCriticalEdge(Instruction *TI, unsigned SuccNum,
175175
// Create our unconditional branch.
176176
BranchInst *NewBI = BranchInst::Create(DestBB, NewBB);
177177
NewBI->setDebugLoc(TI->getDebugLoc());
178+
if (auto *LoopMD = TI->getMetadata(LLVMContext::MD_loop))
179+
NewBI->setMetadata(LLVMContext::MD_loop, LoopMD);
178180

179181
// Insert the block into the function... right after the block TI lives in.
180182
Function &F = *TIBB->getParent();
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
; RUN: llc -mtriple=spirv-unknown-vulkan-compute -O0 -verify-machineinstrs %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: OpName %[[#For:]] "for_loop"
5+
; CHECK-DAG: OpName %[[#While:]] "while_loop"
6+
; CHECK-DAG: OpName %[[#DoWhile:]] "do_while_loop"
7+
; CHECK-DAG: OpName %[[#Disable:]] "unroll_disable"
8+
; CHECK-DAG: OpName %[[#Count:]] "unroll_count"
9+
; CHECK-DAG: OpName %[[#Full:]] "unroll_full"
10+
; CHECK-DAG: OpName %[[#FullCount:]] "unroll_full_count"
11+
12+
; CHECK: %[[#For]] = OpFunction
13+
; CHECK: OpLoopMerge %[[#]] %[[#]] Unroll
14+
15+
; CHECK: %[[#While]] = OpFunction
16+
; CHECK: OpLoopMerge %[[#]] %[[#]] Unroll
17+
18+
; CHECK: %[[#DoWhile]] = OpFunction
19+
; CHECK: OpLoopMerge %[[#]] %[[#]] Unroll
20+
21+
; CHECK: %[[#Disable]] = OpFunction
22+
; CHECK: OpLoopMerge %[[#]] %[[#]] DontUnroll
23+
24+
; CHECK: %[[#Count]] = OpFunction
25+
; CHECK: OpLoopMerge %[[#]] %[[#]] PartialCount 4
26+
27+
; CHECK: %[[#Full]] = OpFunction
28+
; CHECK: OpLoopMerge %[[#]] %[[#]] Unroll|PartialCount 1
29+
30+
; CHECK: %[[#FullCount]] = OpFunction
31+
; CHECK: OpLoopMerge %[[#]] %[[#]] Unroll|PartialCount 4
32+
33+
define dso_local void @for_loop(ptr noundef %0, i32 noundef %1) {
34+
%3 = alloca ptr, align 8
35+
%4 = alloca i32, align 4
36+
%5 = alloca i32, align 4
37+
store ptr %0, ptr %3, align 8
38+
store i32 %1, ptr %4, align 4
39+
store i32 0, ptr %5, align 4
40+
br label %6
41+
42+
6: ; preds = %15, %2
43+
%7 = load i32, ptr %5, align 4
44+
%8 = load i32, ptr %4, align 4
45+
%9 = icmp slt i32 %7, %8
46+
br i1 %9, label %10, label %18
47+
48+
10: ; preds = %6
49+
%11 = load i32, ptr %5, align 4
50+
%12 = load ptr, ptr %3, align 8
51+
%13 = load i32, ptr %12, align 4
52+
%14 = add nsw i32 %13, %11
53+
store i32 %14, ptr %12, align 4
54+
br label %15
55+
56+
15: ; preds = %10
57+
%16 = load i32, ptr %5, align 4
58+
%17 = add nsw i32 %16, 1
59+
store i32 %17, ptr %5, align 4
60+
br label %6, !llvm.loop !1
61+
62+
18: ; preds = %6
63+
ret void
64+
}
65+
66+
define dso_local void @while_loop(ptr noundef %0, i32 noundef %1) {
67+
%3 = alloca ptr, align 8
68+
%4 = alloca i32, align 4
69+
%5 = alloca i32, align 4
70+
store ptr %0, ptr %3, align 8
71+
store i32 %1, ptr %4, align 4
72+
store i32 0, ptr %5, align 4
73+
br label %6
74+
75+
6: ; preds = %10, %2
76+
%7 = load i32, ptr %5, align 4
77+
%8 = load i32, ptr %4, align 4
78+
%9 = icmp slt i32 %7, %8
79+
br i1 %9, label %10, label %17
80+
81+
10: ; preds = %6
82+
%11 = load i32, ptr %5, align 4
83+
%12 = load ptr, ptr %3, align 8
84+
%13 = load i32, ptr %12, align 4
85+
%14 = add nsw i32 %13, %11
86+
store i32 %14, ptr %12, align 4
87+
%15 = load i32, ptr %5, align 4
88+
%16 = add nsw i32 %15, 1
89+
store i32 %16, ptr %5, align 4
90+
br label %6, !llvm.loop !3
91+
92+
17: ; preds = %6
93+
ret void
94+
}
95+
96+
define dso_local void @do_while_loop(ptr noundef %0, i32 noundef %1) {
97+
%3 = alloca ptr, align 8
98+
%4 = alloca i32, align 4
99+
%5 = alloca i32, align 4
100+
store ptr %0, ptr %3, align 8
101+
store i32 %1, ptr %4, align 4
102+
store i32 0, ptr %5, align 4
103+
br label %6
104+
105+
6: ; preds = %13, %2
106+
%7 = load i32, ptr %5, align 4
107+
%8 = load ptr, ptr %3, align 8
108+
%9 = load i32, ptr %8, align 4
109+
%10 = add nsw i32 %9, %7
110+
store i32 %10, ptr %8, align 4
111+
%11 = load i32, ptr %5, align 4
112+
%12 = add nsw i32 %11, 1
113+
store i32 %12, ptr %5, align 4
114+
br label %13
115+
116+
13: ; preds = %6
117+
%14 = load i32, ptr %5, align 4
118+
%15 = load i32, ptr %4, align 4
119+
%16 = icmp slt i32 %14, %15
120+
br i1 %16, label %6, label %17, !llvm.loop !4
121+
122+
17: ; preds = %13
123+
ret void
124+
}
125+
126+
define dso_local void @unroll_disable(i32 noundef %0) {
127+
%2 = alloca i32, align 4
128+
%3 = alloca i32, align 4
129+
store i32 %0, ptr %2, align 4
130+
store i32 0, ptr %3, align 4
131+
br label %4
132+
133+
4: ; preds = %7, %1
134+
%5 = load i32, ptr %3, align 4
135+
%6 = add nsw i32 %5, 1
136+
store i32 %6, ptr %3, align 4
137+
br label %7
138+
139+
7: ; preds = %4
140+
%8 = load i32, ptr %3, align 4
141+
%9 = load i32, ptr %2, align 4
142+
%10 = icmp slt i32 %8, %9
143+
br i1 %10, label %4, label %11, !llvm.loop !5
144+
145+
11: ; preds = %7
146+
ret void
147+
}
148+
149+
define dso_local void @unroll_count(i32 noundef %0) {
150+
%2 = alloca i32, align 4
151+
%3 = alloca i32, align 4
152+
store i32 %0, ptr %2, align 4
153+
store i32 0, ptr %3, align 4
154+
br label %4
155+
156+
4: ; preds = %7, %1
157+
%5 = load i32, ptr %3, align 4
158+
%6 = add nsw i32 %5, 1
159+
store i32 %6, ptr %3, align 4
160+
br label %7
161+
162+
7: ; preds = %4
163+
%8 = load i32, ptr %3, align 4
164+
%9 = load i32, ptr %2, align 4
165+
%10 = icmp slt i32 %8, %9
166+
br i1 %10, label %4, label %11, !llvm.loop !7
167+
168+
11: ; preds = %7
169+
ret void
170+
}
171+
172+
define dso_local void @unroll_full(i32 noundef %0) {
173+
%2 = alloca i32, align 4
174+
%3 = alloca i32, align 4
175+
store i32 %0, ptr %2, align 4
176+
store i32 0, ptr %3, align 4
177+
br label %4
178+
179+
4: ; preds = %7, %1
180+
%5 = load i32, ptr %3, align 4
181+
%6 = add nsw i32 %5, 1
182+
store i32 %6, ptr %3, align 4
183+
br label %7
184+
185+
7: ; preds = %4
186+
%8 = load i32, ptr %3, align 4
187+
%9 = load i32, ptr %2, align 4
188+
%10 = icmp slt i32 %8, %9
189+
br i1 %10, label %4, label %11, !llvm.loop !9
190+
191+
11: ; preds = %7
192+
ret void
193+
}
194+
195+
define dso_local void @unroll_full_count(i32 noundef %0) {
196+
%2 = alloca i32, align 4
197+
%3 = alloca i32, align 4
198+
store i32 %0, ptr %2, align 4
199+
store i32 0, ptr %3, align 4
200+
br label %4
201+
202+
4: ; preds = %7, %1
203+
%5 = load i32, ptr %3, align 4
204+
%6 = add nsw i32 %5, 1
205+
store i32 %6, ptr %3, align 4
206+
br label %7
207+
208+
7: ; preds = %4
209+
%8 = load i32, ptr %3, align 4
210+
%9 = load i32, ptr %2, align 4
211+
%10 = icmp slt i32 %8, %9
212+
br i1 %10, label %4, label %11, !llvm.loop !11
213+
214+
11: ; preds = %7
215+
ret void
216+
}
217+
218+
!1 = distinct !{!1, !2}
219+
!2 = !{!"llvm.loop.unroll.enable"}
220+
!3 = distinct !{!3, !2}
221+
!4 = distinct !{!4, !2}
222+
!5 = distinct !{!5, !6}
223+
!6 = !{!"llvm.loop.unroll.disable"}
224+
!7 = distinct !{!7, !8}
225+
!8 = !{!"llvm.loop.unroll.count", i32 4}
226+
!9 = distinct !{!9, !10}
227+
!10 = !{!"llvm.loop.unroll.full"}
228+
!11 = distinct !{!11, !10, !8}

0 commit comments

Comments
 (0)