Skip to content

Commit abffc54

Browse files
authored
[X86][MemFold] Allow masked load folding if masks are equal (#161074)
Inspired by #160920#issuecomment-3341816198
1 parent cac0635 commit abffc54

File tree

3 files changed

+37
-5
lines changed

3 files changed

+37
-5
lines changed

llvm/lib/Target/X86/X86InstrAVX512.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3238,6 +3238,7 @@ multiclass avx512_load<bits<8> opc, string OpcodeStr, string Name,
32383238
(_.VT _.RC:$src1),
32393239
(_.VT _.RC:$src0))))], _.ExeDomain>,
32403240
EVEX, EVEX_K, Sched<[Sched.RR]>;
3241+
let mayLoad = 1, canFoldAsLoad = 1 in
32413242
def rmk : AVX512PI<opc, MRMSrcMem, (outs _.RC:$dst),
32423243
(ins _.RC:$src0, _.KRCWM:$mask, _.MemOp:$src1),
32433244
!strconcat(OpcodeStr, "\t{$src1, ${dst} {${mask}}|",
@@ -3248,6 +3249,7 @@ multiclass avx512_load<bits<8> opc, string OpcodeStr, string Name,
32483249
(_.VT _.RC:$src0))))], _.ExeDomain>,
32493250
EVEX, EVEX_K, Sched<[Sched.RM]>;
32503251
}
3252+
let mayLoad = 1, canFoldAsLoad = 1 in
32513253
def rmkz : AVX512PI<opc, MRMSrcMem, (outs _.RC:$dst),
32523254
(ins _.KRCWM:$mask, _.MemOp:$src),
32533255
OpcodeStr #"\t{$src, ${dst} {${mask}} {z}|"#

llvm/lib/Target/X86/X86InstrInfo.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8113,6 +8113,39 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
81138113
MachineBasicBlock::iterator InsertPt, MachineInstr &LoadMI,
81148114
LiveIntervals *LIS) const {
81158115

8116+
// If LoadMI is a masked load, check MI having the same mask.
8117+
const MCInstrDesc &MCID = get(LoadMI.getOpcode());
8118+
unsigned NumOps = MCID.getNumOperands();
8119+
if (NumOps >= 3) {
8120+
Register MaskReg;
8121+
const MachineOperand &Op1 = LoadMI.getOperand(1);
8122+
const MachineOperand &Op2 = LoadMI.getOperand(2);
8123+
8124+
auto IsVKWMClass = [](const TargetRegisterClass *RC) {
8125+
return RC == &X86::VK2WMRegClass || RC == &X86::VK4WMRegClass ||
8126+
RC == &X86::VK8WMRegClass || RC == &X86::VK16WMRegClass ||
8127+
RC == &X86::VK32WMRegClass || RC == &X86::VK64WMRegClass;
8128+
};
8129+
8130+
if (Op1.isReg() && IsVKWMClass(getRegClass(MCID, 1, &RI)))
8131+
MaskReg = Op1.getReg();
8132+
else if (Op2.isReg() && IsVKWMClass(getRegClass(MCID, 2, &RI)))
8133+
MaskReg = Op2.getReg();
8134+
8135+
if (MaskReg) {
8136+
bool HasSameMask = false;
8137+
for (unsigned I = 1, E = MI.getDesc().getNumOperands(); I < E; ++I) {
8138+
const MachineOperand &Op = MI.getOperand(I);
8139+
if (Op.isReg() && Op.getReg() == MaskReg) {
8140+
HasSameMask = true;
8141+
break;
8142+
}
8143+
}
8144+
if (!HasSameMask)
8145+
return nullptr;
8146+
}
8147+
}
8148+
81168149
// TODO: Support the case where LoadMI loads a wide register, but MI
81178150
// only uses a subreg.
81188151
for (auto Op : Ops) {
@@ -8121,7 +8154,6 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
81218154
}
81228155

81238156
// If loading from a FrameIndex, fold directly from the FrameIndex.
8124-
unsigned NumOps = LoadMI.getDesc().getNumOperands();
81258157
int FrameIndex;
81268158
if (isLoadFromStackSlot(LoadMI, FrameIndex)) {
81278159
if (isNonFoldablePartialRegisterLoad(LoadMI, MI, MF))

llvm/test/CodeGen/X86/avx512-mask-op.ll

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2119,8 +2119,7 @@ define void @ktest_1(<8 x double> %in, ptr %base) {
21192119
; KNL-LABEL: ktest_1:
21202120
; KNL: ## %bb.0:
21212121
; KNL-NEXT: vcmpgtpd (%rdi), %zmm0, %k1
2122-
; KNL-NEXT: vmovupd 8(%rdi), %zmm1 {%k1} {z}
2123-
; KNL-NEXT: vcmpltpd %zmm1, %zmm0, %k0 {%k1}
2122+
; KNL-NEXT: vcmpltpd 8(%rdi), %zmm0, %k0 {%k1}
21242123
; KNL-NEXT: kmovw %k0, %eax
21252124
; KNL-NEXT: testb %al, %al
21262125
; KNL-NEXT: je LBB44_2
@@ -2152,8 +2151,7 @@ define void @ktest_1(<8 x double> %in, ptr %base) {
21522151
; AVX512BW-LABEL: ktest_1:
21532152
; AVX512BW: ## %bb.0:
21542153
; AVX512BW-NEXT: vcmpgtpd (%rdi), %zmm0, %k1
2155-
; AVX512BW-NEXT: vmovupd 8(%rdi), %zmm1 {%k1} {z}
2156-
; AVX512BW-NEXT: vcmpltpd %zmm1, %zmm0, %k0 {%k1}
2154+
; AVX512BW-NEXT: vcmpltpd 8(%rdi), %zmm0, %k0 {%k1}
21572155
; AVX512BW-NEXT: kmovd %k0, %eax
21582156
; AVX512BW-NEXT: testb %al, %al
21592157
; AVX512BW-NEXT: je LBB44_2

0 commit comments

Comments
 (0)