diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp index 278ae46b8a5f5..0ba71ada8638e 100644 --- a/llvm/lib/Target/X86/X86LowerAMXType.cpp +++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp @@ -854,6 +854,7 @@ class X86LowerAMXCast { : Func(F), SC(ShapeC), DT(nullptr) {} bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST); bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD); + bool combineTilezero(IntrinsicInst *Cast); bool combineLdSt(SmallVectorImpl &Casts); bool combineAMXcast(TargetLibraryInfo *TLI); bool transformAMXCast(IntrinsicInst *AMXCast); @@ -1175,6 +1176,26 @@ bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) { return EraseLoad; } +// %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer) +// --> +// %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col) +bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) { + Value *Row = nullptr, *Col = nullptr; + Use &U = *(Cast->use_begin()); + unsigned OpNo = U.getOperandNo(); + auto *II = cast(U.getUser()); + if (!isAMXIntrinsic(II)) + return false; + + std::tie(Row, Col) = SC->getShape(II, OpNo); + + IRBuilder<> Builder(Cast); + Value *NewInst = + Builder.CreateIntrinsic(Intrinsic::x86_tilezero_internal, {}, {Row, Col}); + Cast->replaceAllUsesWith(NewInst); + return true; +} + bool X86LowerAMXCast::combineLdSt(SmallVectorImpl &Casts) { bool Change = false; for (auto *Cast : Casts) { @@ -1198,6 +1219,14 @@ bool X86LowerAMXCast::combineLdSt(SmallVectorImpl &Casts) { for (auto *Store : DeadStores) Store->eraseFromParent(); } else { // x86_cast_vector_to_tile + // %19 = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> zeroinitializer) + // --> + // %19 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col) + if (isa(Cast->getOperand(0))) { + Change |= combineTilezero(cast(Cast)); + continue; + } + auto *Load = dyn_cast(Cast->getOperand(0)); if (!Load || !Load->hasOneUse()) continue; @@ -1210,6 +1239,7 @@ bool X86LowerAMXCast::combineLdSt(SmallVectorImpl &Casts) { // Set the operand is null so that load instruction can be erased. Cast->setOperand(0, nullptr); Load->eraseFromParent(); + Change = true; } } } diff --git a/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll b/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll index 6ef7219cfebdb..9cf7aab0b3655 100644 --- a/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll @@ -56,14 +56,9 @@ define void @PR90954(ptr %0, ptr %1, i32 %2) nounwind { ; CHECK-LABEL: PR90954: ; CHECK: # %bb.0: ; CHECK-NEXT: pushq %rbp -; CHECK-NEXT: movq %rsp, %rbp -; CHECK-NEXT: pushq %r15 ; CHECK-NEXT: pushq %r14 -; CHECK-NEXT: pushq %r13 -; CHECK-NEXT: pushq %r12 ; CHECK-NEXT: pushq %rbx -; CHECK-NEXT: andq $-1024, %rsp # imm = 0xFC00 -; CHECK-NEXT: subq $5120, %rsp # imm = 0x1400 +; CHECK-NEXT: subq $2912, %rsp # imm = 0xB60 ; CHECK-NEXT: vxorps %xmm0, %xmm0, %xmm0 ; CHECK-NEXT: vmovups %zmm0, {{[0-9]+}}(%rsp) ; CHECK-NEXT: movb $1, {{[0-9]+}}(%rsp) @@ -79,29 +74,26 @@ define void @PR90954(ptr %0, ptr %1, i32 %2) nounwind { ; CHECK-NEXT: movw $64, %cx ; CHECK-NEXT: movw $16, %di ; CHECK-NEXT: movb $1, %r8b -; CHECK-NEXT: movl $64, %r9d -; CHECK-NEXT: leaq {{[0-9]+}}(%rsp), %r10 -; CHECK-NEXT: leaq {{[0-9]+}}(%rsp), %r11 -; CHECK-NEXT: xorl %ebx, %ebx -; CHECK-NEXT: xorl %r14d, %r14d +; CHECK-NEXT: xorl %r9d, %r9d +; CHECK-NEXT: xorl %r10d, %r10d ; CHECK-NEXT: jmp .LBB1_1 ; CHECK-NEXT: .p2align 4 ; CHECK-NEXT: .LBB1_5: # in Loop: Header=BB1_1 Depth=1 -; CHECK-NEXT: incq %r14 -; CHECK-NEXT: addl %edx, %ebx +; CHECK-NEXT: incq %r10 +; CHECK-NEXT: addl %edx, %r9d ; CHECK-NEXT: .LBB1_1: # =>This Loop Header: Depth=1 ; CHECK-NEXT: # Child Loop BB1_2 Depth 2 -; CHECK-NEXT: movslq %ebx, %r15 -; CHECK-NEXT: leaq (%rsi,%r15,4), %r15 -; CHECK-NEXT: xorl %r12d, %r12d -; CHECK-NEXT: xorl %r13d, %r13d +; CHECK-NEXT: movslq %r9d, %r11 +; CHECK-NEXT: leaq (%rsi,%r11,4), %r11 +; CHECK-NEXT: xorl %ebx, %ebx +; CHECK-NEXT: xorl %r14d, %r14d ; CHECK-NEXT: jmp .LBB1_2 ; CHECK-NEXT: .p2align 4 ; CHECK-NEXT: .LBB1_4: # in Loop: Header=BB1_2 Depth=2 -; CHECK-NEXT: tilestored %tmm1, (%r15,%rax) -; CHECK-NEXT: incq %r13 -; CHECK-NEXT: addq $64, %r15 -; CHECK-NEXT: decq %r12 +; CHECK-NEXT: tilestored %tmm1, (%r11,%rax) +; CHECK-NEXT: incq %r14 +; CHECK-NEXT: addq $64, %r11 +; CHECK-NEXT: decq %rbx ; CHECK-NEXT: je .LBB1_5 ; CHECK-NEXT: .LBB1_2: # Parent Loop BB1_1 Depth=1 ; CHECK-NEXT: # => This Inner Loop Header: Depth=2 @@ -110,46 +102,12 @@ define void @PR90954(ptr %0, ptr %1, i32 %2) nounwind { ; CHECK-NEXT: testb %r8b, %r8b ; CHECK-NEXT: jne .LBB1_4 ; CHECK-NEXT: # %bb.3: # in Loop: Header=BB1_2 Depth=2 -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: tileloadd (%r10,%r9), %tmm1 -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: vmovaps %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: tileloadd (%r11,%r9), %tmm2 +; CHECK-NEXT: tilezero %tmm1 +; CHECK-NEXT: tilezero %tmm2 ; CHECK-NEXT: tdpbf16ps %tmm2, %tmm1, %tmm0 -; CHECK-NEXT: movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill -; CHECK-NEXT: movabsq $64, %rax -; CHECK-NEXT: tilestored %tmm0, 3072(%rsp,%rax) # 1024-byte Folded Spill -; CHECK-NEXT: tileloadd 3072(%rsp,%rax), %tmm1 # 1024-byte Folded Reload -; CHECK-NEXT: movq {{[-0-9]+}}(%r{{[sb]}}p), %rax # 8-byte Reload +; CHECK-NEXT: movabsq $64, %rbp +; CHECK-NEXT: tilestored %tmm0, 896(%rsp,%rbp) # 1024-byte Folded Spill +; CHECK-NEXT: tileloadd 896(%rsp,%rbp), %tmm1 # 1024-byte Folded Reload ; CHECK-NEXT: jmp .LBB1_4 %4 = shl i32 %2, 4 %5 = icmp eq i64 0, 0