-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[X86][MemFold] Allow masked load folding if masks are equal #161074
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-backend-x86 Author: Phoebe Wang (phoebewang) Changes: Inspired by #160920#issuecomment-3341816198. Full diff: https://github.com/llvm/llvm-project/pull/161074.diff 3 Files Affected:
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index b8f299965faa3..2371ed4ed14a1 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -3238,6 +3238,7 @@ multiclass avx512_load<bits<8> opc, string OpcodeStr, string Name,
(_.VT _.RC:$src1),
(_.VT _.RC:$src0))))], _.ExeDomain>,
EVEX, EVEX_K, Sched<[Sched.RR]>;
+ let mayLoad = 1, canFoldAsLoad = 1 in
def rmk : AVX512PI<opc, MRMSrcMem, (outs _.RC:$dst),
(ins _.RC:$src0, _.KRCWM:$mask, _.MemOp:$src1),
!strconcat(OpcodeStr, "\t{$src1, ${dst} {${mask}}|",
@@ -3248,6 +3249,7 @@ multiclass avx512_load<bits<8> opc, string OpcodeStr, string Name,
(_.VT _.RC:$src0))))], _.ExeDomain>,
EVEX, EVEX_K, Sched<[Sched.RM]>;
}
+ let mayLoad = 1, canFoldAsLoad = 1 in
def rmkz : AVX512PI<opc, MRMSrcMem, (outs _.RC:$dst),
(ins _.KRCWM:$mask, _.MemOp:$src),
OpcodeStr #"\t{$src, ${dst} {${mask}} {z}|"#
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
index 03ac1d3ca5d89..f8066ab12080b 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -8113,6 +8113,39 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
MachineBasicBlock::iterator InsertPt, MachineInstr &LoadMI,
LiveIntervals *LIS) const {
+ Register MaskReg;
+ const MCInstrDesc &MCID = get(LoadMI.getOpcode());
+ unsigned NumOps = MCID.getNumOperands();
+ if (NumOps >= 3) {
+ const MachineOperand &Op1 = LoadMI.getOperand(1);
+ const MachineOperand &Op2 = LoadMI.getOperand(2);
+
+ auto IsVKWMClass = [](const TargetRegisterClass *RC) {
+ return RC == &X86::VK2WMRegClass || RC == &X86::VK4WMRegClass ||
+ RC == &X86::VK8WMRegClass || RC == &X86::VK16WMRegClass ||
+ RC == &X86::VK32WMRegClass || RC == &X86::VK64WMRegClass;
+ };
+
+ if (Op1.isReg() && IsVKWMClass(getRegClass(MCID, 1, &RI)))
+ MaskReg = Op1.getReg();
+ else if (Op2.isReg() && IsVKWMClass(getRegClass(MCID, 2, &RI)))
+ MaskReg = Op2.getReg();
+
+ if (MaskReg) {
+ errs() << "MaskReg = " << MaskReg << '\n';
+ bool HasSameMask = false;
+ for (unsigned I = 1, E = MI.getDesc().getNumOperands(); I < E; ++I) {
+ const MachineOperand &Op = MI.getOperand(I);
+ if (Op.isReg() && Op.getReg() == MaskReg) {
+ HasSameMask = true;
+ break;
+ }
+ }
+ if (!HasSameMask)
+ return nullptr;
+ }
+ }
+
// TODO: Support the case where LoadMI loads a wide register, but MI
// only uses a subreg.
for (auto Op : Ops) {
@@ -8121,7 +8154,6 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
}
// If loading from a FrameIndex, fold directly from the FrameIndex.
- unsigned NumOps = LoadMI.getDesc().getNumOperands();
int FrameIndex;
if (isLoadFromStackSlot(LoadMI, FrameIndex)) {
if (isNonFoldablePartialRegisterLoad(LoadMI, MI, MF))
diff --git a/llvm/test/CodeGen/X86/avx512-mask-op.ll b/llvm/test/CodeGen/X86/avx512-mask-op.ll
index 8aa898f3ec576..da0cef0e4e99b 100644
--- a/llvm/test/CodeGen/X86/avx512-mask-op.ll
+++ b/llvm/test/CodeGen/X86/avx512-mask-op.ll
@@ -2119,8 +2119,7 @@ define void @ktest_1(<8 x double> %in, ptr %base) {
; KNL-LABEL: ktest_1:
; KNL: ## %bb.0:
; KNL-NEXT: vcmpgtpd (%rdi), %zmm0, %k1
-; KNL-NEXT: vmovupd 8(%rdi), %zmm1 {%k1} {z}
-; KNL-NEXT: vcmpltpd %zmm1, %zmm0, %k0 {%k1}
+; KNL-NEXT: vcmpltpd 8(%rdi), %zmm0, %k0 {%k1}
; KNL-NEXT: kmovw %k0, %eax
; KNL-NEXT: testb %al, %al
; KNL-NEXT: je LBB44_2
@@ -2152,8 +2151,7 @@ define void @ktest_1(<8 x double> %in, ptr %base) {
; AVX512BW-LABEL: ktest_1:
; AVX512BW: ## %bb.0:
; AVX512BW-NEXT: vcmpgtpd (%rdi), %zmm0, %k1
-; AVX512BW-NEXT: vmovupd 8(%rdi), %zmm1 {%k1} {z}
-; AVX512BW-NEXT: vcmpltpd %zmm1, %zmm0, %k0 {%k1}
+; AVX512BW-NEXT: vcmpltpd 8(%rdi), %zmm0, %k0 {%k1}
; AVX512BW-NEXT: kmovd %k0, %eax
; AVX512BW-NEXT: testb %al, %al
; AVX512BW-NEXT: je LBB44_2
|
Inspired by llvm#160920#issuecomment-3341816198
6a6d188 to
2be4ef0
Compare
RKSimon
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with one minor comment.
llvm/lib/Target/X86/X86InstrInfo.cpp
Outdated
| MachineBasicBlock::iterator InsertPt, MachineInstr &LoadMI, | ||
| LiveIntervals *LIS) const { | ||
|
|
||
| Register MaskReg; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a suitable summary comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Inspired by llvm#160920#issuecomment-3341816198
Inspired by #160920#issuecomment-3341816198