Skip to content

Commit 016092d

Browse files
committed
Reapply "[X86][AMX] Try to hoist AMX shapes' def"
We require that there be no intersections between AMX instructions and their shapes' defs when we insert ldtilecfg. However, this is not always true, both because users don't follow the AMX API model and because of optimizations. This patch adds a mechanism that tries to hoist AMX shapes' defs as well. It only hoists shapes inside a BB; we can improve it to handle cases across BBs in the future. Currently, it only hoists shapes whose sources' defs are all above the first AMX instruction. We can improve it for the case where the only source below the AMX instruction is one that moves an immediate value to a register. Reviewed By: xiangzhangllvm Differential Revision: https://reviews.llvm.org/D101067
1 parent 9360430 commit 016092d

File tree

2 files changed

+67
-17
lines changed

2 files changed

+67
-17
lines changed

llvm/lib/Target/X86/X86PreTileConfig.cpp

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ struct MIRef {
5757
++I, ++Pos)
5858
MI = &*I;
5959
}
60+
MIRef(MachineInstr *MI)
61+
: MI(MI), MBB(MI->getParent()),
62+
Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
6063
MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
6164
: MI(MI), MBB(MBB),
6265
Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
@@ -66,6 +69,7 @@ struct MIRef {
6669
bool operator==(const MIRef &RHS) const {
6770
return MI == RHS.MI && MBB == RHS.MBB;
6871
}
72+
bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
6973
bool operator<(const MIRef &RHS) const {
7074
return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
7175
}
@@ -77,7 +81,7 @@ struct MIRef {
7781
struct BBInfo {
7882
MIRef FirstAMX;
7983
MIRef LastCall;
80-
MIRef LastShape;
84+
bool HasAMXRegLiveIn = false;
8185
bool TileCfgForbidden = false;
8286
bool NeedTileCfgLiveIn = false;
8387
};
@@ -86,8 +90,8 @@ class X86PreTileConfig : public MachineFunctionPass {
8690
MachineRegisterInfo *MRI;
8791
const MachineLoopInfo *MLI;
8892
SmallSet<MachineInstr *, 8> DefVisited;
89-
SmallSet<MachineBasicBlock *, 8> ShapeBBs;
9093
DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
94+
DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;
9195

9296
/// Check if the callee will clobber AMX registers.
9397
bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
@@ -124,6 +128,32 @@ class X86PreTileConfig : public MachineFunctionPass {
124128
/// Collect the shape def information for later use.
125129
void collectShapeInfo(MachineInstr &MI);
126130

131+
/// Try to hoist shapes defined below AMX instructions.
132+
/// Walks the entries of \p Shapes (kept sorted by position) that sit below
/// this block's first AMX instruction and moves each one to just before that
/// instruction, so ldtilecfg can later be inserted above every shape def.
/// Returns false if any such shape cannot be hoisted — it accesses memory,
/// or one of its register sources is itself defined below the first AMX
/// instruction. NOTE: shapes visited earlier in the loop may already have
/// been moved when false is returned; the caller treats a false return as a
/// configuration failure, so the partial hoist is not relied upon.
bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
133+
// First AMX instruction in the block; hoisted shape defs are inserted
// immediately before it.
MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
134+
// Shapes is sorted by MIRef position, so lower_bound yields the first
// shape def at or below FirstAMX. The caller only invokes us when
// FirstAMX < Shapes.back(), so there is at least one such shape.
auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX);
135+
auto InsertPoint = FirstAMX.MI->getIterator();
136+
for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
137+
// Do not hoist instructions that access memory.
138+
if (I->MI->mayLoadOrStore())
139+
return false;
140+
for (auto &MO : I->MI->operands()) {
141+
if (MO.isDef())
142+
continue;
143+
// Do not hoist an instruction whose source is defined below the first
// AMX instruction: moving it up would place the use before its def.
144+
// TODO: We can handle isMoveImmediate MI here.
145+
if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX)
146+
return false;
147+
// TODO: Maybe need more checks here.
148+
}
149+
MBB->insert(InsertPoint, I->MI->removeFromParent());
150+
}
151+
// We only need to mark the last shape in the BB now.
152+
Shapes.clear();
153+
// After the inserts, --InsertPoint is the instruction directly before
// FirstAMX, i.e. the last hoisted shape def.
Shapes.push_back(MIRef(&*--InsertPoint, MBB));
154+
return true;
155+
}
156+
127157
public:
128158
X86PreTileConfig() : MachineFunctionPass(ID) {}
129159

@@ -165,9 +195,9 @@ INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
165195
void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
166196
auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
167197
MIRef MIR(MI, MBB);
168-
if (BBVisitedInfo[MBB].LastShape < MIR)
169-
BBVisitedInfo[MBB].LastShape = MIR;
170-
ShapeBBs.insert(MBB);
198+
auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
199+
if (I == ShapeBBs[MBB].end() || *I != MIR)
200+
ShapeBBs[MBB].insert(I, MIR);
171201
};
172202

173203
SmallVector<Register, 8> WorkList(
@@ -229,6 +259,10 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
229259
else
230260
CfgLiveInBBs.push_back(&MBB);
231261
}
262+
if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
263+
for (auto *Succ : MBB.successors())
264+
if (!isLoopBackEdge(Succ, &MBB))
265+
BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
232266
}
233267

234268
// Update NeedTileCfgLiveIn for predecessors.
@@ -252,8 +286,17 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
252286
return false;
253287

254288
// Avoid to insert ldtilecfg before any shape defs.
255-
SmallVector<MachineBasicBlock *, 8> WorkList(
256-
make_range(ShapeBBs.begin(), ShapeBBs.end()));
289+
SmallVector<MachineBasicBlock *, 8> WorkList;
290+
for (auto &I : ShapeBBs) {
291+
// TODO: We can hoist shapes across BBs here.
292+
if (BBVisitedInfo[I.first].HasAMXRegLiveIn)
293+
REPORT_CONFIG_FAIL
294+
if (BBVisitedInfo[I.first].FirstAMX &&
295+
BBVisitedInfo[I.first].FirstAMX < I.second.back() &&
296+
!hoistShapesInBB(I.first, I.second))
297+
REPORT_CONFIG_FAIL
298+
WorkList.push_back(I.first);
299+
}
257300
while (!WorkList.empty()) {
258301
MachineBasicBlock *MBB = WorkList.pop_back_val();
259302
for (auto *Pred : MBB->predecessors()) {
@@ -282,9 +325,6 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
282325
} else {
283326
// Avoid the BB to be multi visited.
284327
VisitedOrInserted.insert(I);
285-
// We cannot sink it across any AMX instruction.
286-
if (BBVisitedInfo[I.MBB].FirstAMX)
287-
REPORT_CONFIG_FAIL;
288328
// Sink the inserting point along the chain with NeedTileCfgLiveIn =
289329
// true when MBB isn't all shapes reachable.
290330
for (auto *Succ : I.MBB->successors())
@@ -296,14 +336,9 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
296336

297337
// A given point might be forked due to shape conditions are not met.
298338
for (MIRef I : InsertPoints) {
299-
// Even MBB is all shapes reachable, we still need to check if there's
300-
// AMX that intersects with shapes in the same MBB.
301-
if (BBVisitedInfo[I.MBB].FirstAMX &&
302-
BBVisitedInfo[I.MBB].FirstAMX < BBVisitedInfo[I.MBB].LastShape)
303-
REPORT_CONFIG_FAIL;
304339
// Make sure we insert ldtilecfg after the last shape def in MBB.
305-
if (I < BBVisitedInfo[I.MBB].LastShape)
306-
I = BBVisitedInfo[I.MBB].LastShape;
340+
if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
341+
I = ShapeBBs[I.MBB].back();
307342
// There're chances the MBB is sunk more than once. Record it to avoid
308343
// multi insert.
309344
if (VisitedOrInserted.insert(I).second) {

llvm/test/CodeGen/X86/AMX/amx-sched.ll

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
define <256 x i32> @test_shape_sched(i16 %m, i16 %n, i16 %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b) nounwind {
44
; Just to make sure shape def is not scheduled across ldtilecfg.
5+
; CHECK-LABEL: test_shape_sched:
56
; CHECK: ldtilecfg
67
; CHECK-NOT: movw
78
%c1 = bitcast <256 x i32> %c to x86_amx
@@ -12,5 +13,19 @@ define <256 x i32> @test_shape_sched(i16 %m, i16 %n, i16 %k, <256 x i32> %c, <25
1213
ret <256 x i32> %res
1314
}
1415

16+
define <256 x i32> @test_shape_sched2(i16 %m, i16 %n, i16 %k, i8* %c, i8* %a, i8* %b) nounwind {
17+
; Just to make sure shape def is not scheduled across ldtilecfg.
18+
; %aa (computed in-block via lshr) feeds the shape of the third tileloadd;
; presumably this exercises the shape-hoisting path — the pass must keep all
; shape defs (movw) above the ldtilecfg it inserts, hence CHECK-NOT below.
; CHECK-LABEL: test_shape_sched2:
19+
; CHECK: ldtilecfg
20+
; CHECK-NOT: movw
21+
%aa = lshr i16 %k, 2
22+
%c1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %c, i64 64)
23+
%a1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %k, i8* %a, i64 64)
24+
%b1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %aa, i16 %n, i8* %b, i64 64)
25+
%t = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k, x86_amx %c1, x86_amx %a1, x86_amx %b1)
26+
%res = bitcast x86_amx %t to <256 x i32>
27+
ret <256 x i32> %res
28+
}
1529

30+
declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
1631
declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)

0 commit comments

Comments
 (0)