Skip to content

Commit 487cdf1

Browse files
authored
[X86][AMX] Combine constant zero vector and AMX cast to tilezero (#92384)
Found this problem when investigating #91207
1 parent 79d1524 commit 487cdf1

File tree

2 files changed

+48
-60
lines changed

2 files changed

+48
-60
lines changed

llvm/lib/Target/X86/X86LowerAMXType.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,7 @@ class X86LowerAMXCast {
854854
: Func(F), SC(ShapeC), DT(nullptr) {}
855855
bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
856856
bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
857+
bool combineTilezero(IntrinsicInst *Cast);
857858
bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
858859
bool combineAMXcast(TargetLibraryInfo *TLI);
859860
bool transformAMXCast(IntrinsicInst *AMXCast);
@@ -1175,6 +1176,26 @@ bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
11751176
return EraseLoad;
11761177
}
11771178

1179+
// %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer)
1180+
// -->
1181+
// %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
1182+
bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) {
1183+
Value *Row = nullptr, *Col = nullptr;
1184+
Use &U = *(Cast->use_begin());
1185+
unsigned OpNo = U.getOperandNo();
1186+
auto *II = cast<IntrinsicInst>(U.getUser());
1187+
if (!isAMXIntrinsic(II))
1188+
return false;
1189+
1190+
std::tie(Row, Col) = SC->getShape(II, OpNo);
1191+
1192+
IRBuilder<> Builder(Cast);
1193+
Value *NewInst =
1194+
Builder.CreateIntrinsic(Intrinsic::x86_tilezero_internal, {}, {Row, Col});
1195+
Cast->replaceAllUsesWith(NewInst);
1196+
return true;
1197+
}
1198+
11781199
bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
11791200
bool Change = false;
11801201
for (auto *Cast : Casts) {
@@ -1198,6 +1219,14 @@ bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
11981219
for (auto *Store : DeadStores)
11991220
Store->eraseFromParent();
12001221
} else { // x86_cast_vector_to_tile
1222+
// %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer)
1223+
// -->
1224+
// %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
1225+
if (isa<ConstantAggregateZero>(Cast->getOperand(0))) {
1226+
Change |= combineTilezero(cast<IntrinsicInst>(Cast));
1227+
continue;
1228+
}
1229+
12011230
auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
12021231
if (!Load || !Load->hasOneUse())
12031232
continue;
@@ -1210,6 +1239,7 @@ bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
12101239
// Set the operand is null so that load instruction can be erased.
12111240
Cast->setOperand(0, nullptr);
12121241
Load->eraseFromParent();
1242+
Change = true;
12131243
}
12141244
}
12151245
}

llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll

Lines changed: 18 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,9 @@ define void @PR90954(ptr %0, ptr %1, i32 %2) nounwind {
5656
; CHECK-LABEL: PR90954:
5757
; CHECK: # %bb.0:
5858
; CHECK-NEXT: pushq %rbp
59-
; CHECK-NEXT: movq %rsp, %rbp
60-
; CHECK-NEXT: pushq %r15
6159
; CHECK-NEXT: pushq %r14
62-
; CHECK-NEXT: pushq %r13
63-
; CHECK-NEXT: pushq %r12
6460
; CHECK-NEXT: pushq %rbx
65-
; CHECK-NEXT: andq $-1024, %rsp # imm = 0xFC00
66-
; CHECK-NEXT: subq $5120, %rsp # imm = 0x1400
61+
; CHECK-NEXT: subq $2912, %rsp # imm = 0xB60
6762
; CHECK-NEXT: vxorps %xmm0, %xmm0, %xmm0
6863
; CHECK-NEXT: vmovups %zmm0, {{[0-9]+}}(%rsp)
6964
; CHECK-NEXT: movb $1, {{[0-9]+}}(%rsp)
@@ -79,29 +74,26 @@ define void @PR90954(ptr %0, ptr %1, i32 %2) nounwind {
7974
; CHECK-NEXT: movw $64, %cx
8075
; CHECK-NEXT: movw $16, %di
8176
; CHECK-NEXT: movb $1, %r8b
82-
; CHECK-NEXT: movl $64, %r9d
83-
; CHECK-NEXT: leaq {{[0-9]+}}(%rsp), %r10
84-
; CHECK-NEXT: leaq {{[0-9]+}}(%rsp), %r11
85-
; CHECK-NEXT: xorl %ebx, %ebx
86-
; CHECK-NEXT: xorl %r14d, %r14d
77+
; CHECK-NEXT: xorl %r9d, %r9d
78+
; CHECK-NEXT: xorl %r10d, %r10d
8779
; CHECK-NEXT: jmp .LBB1_1
8880
; CHECK-NEXT: .p2align 4
8981
; CHECK-NEXT: .LBB1_5: # in Loop: Header=BB1_1 Depth=1
90-
; CHECK-NEXT: incq %r14
91-
; CHECK-NEXT: addl %edx, %ebx
82+
; CHECK-NEXT: incq %r10
83+
; CHECK-NEXT: addl %edx, %r9d
9284
; CHECK-NEXT: .LBB1_1: # =>This Loop Header: Depth=1
9385
; CHECK-NEXT: # Child Loop BB1_2 Depth 2
94-
; CHECK-NEXT: movslq %ebx, %r15
95-
; CHECK-NEXT: leaq (%rsi,%r15,4), %r15
96-
; CHECK-NEXT: xorl %r12d, %r12d
97-
; CHECK-NEXT: xorl %r13d, %r13d
86+
; CHECK-NEXT: movslq %r9d, %r11
87+
; CHECK-NEXT: leaq (%rsi,%r11,4), %r11
88+
; CHECK-NEXT: xorl %ebx, %ebx
89+
; CHECK-NEXT: xorl %r14d, %r14d
9890
; CHECK-NEXT: jmp .LBB1_2
9991
; CHECK-NEXT: .p2align 4
10092
; CHECK-NEXT: .LBB1_4: # in Loop: Header=BB1_2 Depth=2
101-
; CHECK-NEXT: tilestored %tmm1, (%r15,%rax)
102-
; CHECK-NEXT: incq %r13
103-
; CHECK-NEXT: addq $64, %r15
104-
; CHECK-NEXT: decq %r12
93+
; CHECK-NEXT: tilestored %tmm1, (%r11,%rax)
94+
; CHECK-NEXT: incq %r14
95+
; CHECK-NEXT: addq $64, %r11
96+
; CHECK-NEXT: decq %rbx
10597
; CHECK-NEXT: je .LBB1_5
10698
; CHECK-NEXT: .LBB1_2: # Parent Loop BB1_1 Depth=1
10799
; CHECK-NEXT: # => This Inner Loop Header: Depth=2
@@ -110,46 +102,12 @@ define void @PR90954(ptr %0, ptr %1, i32 %2) nounwind {
110102
; CHECK-NEXT: testb %r8b, %r8b
111103
; CHECK-NEXT: jne .LBB1_4
112104
; CHECK-NEXT: # %bb.3: # in Loop: Header=BB1_2 Depth=2
113-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
114-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
115-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
116-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
117-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
118-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
119-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
120-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
121-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
122-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
123-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
124-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
125-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
126-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
127-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
128-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
129-
; CHECK-NEXT: tileloadd (%r10,%r9), %tmm1
130-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
131-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
132-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
133-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
134-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
135-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
136-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
137-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
138-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
139-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
140-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
141-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
142-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
143-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
144-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
145-
; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp)
146-
; CHECK-NEXT: tileloadd (%r11,%r9), %tmm2
105+
; CHECK-NEXT: tilezero %tmm1
106+
; CHECK-NEXT: tilezero %tmm2
147107
; CHECK-NEXT: tdpbf16ps %tmm2, %tmm1, %tmm0
148-
; CHECK-NEXT: movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill
149-
; CHECK-NEXT: movabsq $64, %rax
150-
; CHECK-NEXT: tilestored %tmm0, 3072(%rsp,%rax) # 1024-byte Folded Spill
151-
; CHECK-NEXT: tileloadd 3072(%rsp,%rax), %tmm1 # 1024-byte Folded Reload
152-
; CHECK-NEXT: movq {{[-0-9]+}}(%r{{[sb]}}p), %rax # 8-byte Reload
108+
; CHECK-NEXT: movabsq $64, %rbp
109+
; CHECK-NEXT: tilestored %tmm0, 896(%rsp,%rbp) # 1024-byte Folded Spill
110+
; CHECK-NEXT: tileloadd 896(%rsp,%rbp), %tmm1 # 1024-byte Folded Reload
153111
; CHECK-NEXT: jmp .LBB1_4
154112
%4 = shl i32 %2, 4
155113
%5 = icmp eq i64 0, 0

0 commit comments

Comments
 (0)