Skip to content

Commit 6449bea

Browse files
committed
[RISCV] Select unmasked RVV pseudos in a DAG post-process
This patch drops TableGen patterns matching all-ones masked RVV pseudos in the case where there are fallback patterns matching the generic masked forms to "_MASK" pseudos. This optimization is now performed with a SelectionDAG post-processing step which peephole-optimizes these same pseudos with all-ones masks and swaps them out to their unmasked pseudos. This cuts our generated ISel table down by around ~5% (~110kB) in lieu of a far smaller auto-generated table to help with the peephole. This only targets our custom RISCVISD::*_VL binary operator nodes, which use the one form for both masked and unmasked variants. A similar approach could be used for our intrinsics but we'd need to do some work, e.g., to represent unmasked intrinsics as true-masked intrinsics at the IR or ISel level. At a rough estimate, this could save us a further 9% on the size of our ISel table for the binary intrinsic patterns alone. There is no observable impact on our tests. Reviewed By: craig.topper Differential Revision: https://reviews.llvm.org/D118810
1 parent 4db88a5 commit 6449bea

File tree

4 files changed

+122
-44
lines changed

4 files changed

+122
-44
lines changed

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ namespace RISCV {
3737
#define GET_RISCVVSETable_IMPL
3838
#define GET_RISCVVLXTable_IMPL
3939
#define GET_RISCVVSXTable_IMPL
40+
#define GET_RISCVMaskedPseudosTable_IMPL
4041
#include "RISCVGenSearchableTables.inc"
4142
} // namespace RISCV
4243
} // namespace llvm
@@ -123,6 +124,7 @@ void RISCVDAGToDAGISel::PostprocessISelDAG() {
123124

124125
MadeChange |= doPeepholeSExtW(N);
125126
MadeChange |= doPeepholeLoadStoreADDI(N);
127+
MadeChange |= doPeepholeMaskedRVV(N);
126128
}
127129

128130
if (MadeChange)
@@ -2133,6 +2135,102 @@ bool RISCVDAGToDAGISel::doPeepholeSExtW(SDNode *N) {
21332135
return false;
21342136
}
21352137

2138+
// Optimize masked RVV pseudo instructions with a known all-ones mask to their
2139+
// corresponding "unmasked" pseudo versions. The mask we're interested in will
2140+
// take the form of a V0 physical register operand, with a glued
2141+
// register-setting instruction.
2142+
bool RISCVDAGToDAGISel::doPeepholeMaskedRVV(SDNode *N) {
2143+
const RISCV::RISCVMaskedPseudoInfo *I =
2144+
RISCV::getMaskedPseudoInfo(N->getMachineOpcode());
2145+
if (!I)
2146+
return false;
2147+
2148+
unsigned MaskOpIdx = I->MaskOpIdx;
2149+
2150+
// Check that we're using V0 as a mask register.
2151+
if (!isa<RegisterSDNode>(N->getOperand(MaskOpIdx)) ||
2152+
cast<RegisterSDNode>(N->getOperand(MaskOpIdx))->getReg() != RISCV::V0)
2153+
return false;
2154+
2155+
// The glued user defines V0.
2156+
const auto *Glued = N->getGluedNode();
2157+
2158+
if (!Glued || Glued->getOpcode() != ISD::CopyToReg)
2159+
return false;
2160+
2161+
// Check that we're defining V0 as a mask register.
2162+
if (!isa<RegisterSDNode>(Glued->getOperand(1)) ||
2163+
cast<RegisterSDNode>(Glued->getOperand(1))->getReg() != RISCV::V0)
2164+
return false;
2165+
2166+
// Check the instruction defining V0; it needs to be a VMSET pseudo.
2167+
SDValue MaskSetter = Glued->getOperand(2);
2168+
2169+
const auto IsVMSet = [](unsigned Opc) {
2170+
return Opc == RISCV::PseudoVMSET_M_B1 || Opc == RISCV::PseudoVMSET_M_B16 ||
2171+
Opc == RISCV::PseudoVMSET_M_B2 || Opc == RISCV::PseudoVMSET_M_B32 ||
2172+
Opc == RISCV::PseudoVMSET_M_B4 || Opc == RISCV::PseudoVMSET_M_B64 ||
2173+
Opc == RISCV::PseudoVMSET_M_B8;
2174+
};
2175+
2176+
// TODO: Check that the VMSET is the expected bitwidth? The pseudo has
2177+
// undefined behaviour if it's the wrong bitwidth, so we could choose to
2178+
// assume that it's all-ones? Same applies to its VL.
2179+
if (!MaskSetter->isMachineOpcode() || !IsVMSet(MaskSetter.getMachineOpcode()))
2180+
return false;
2181+
2182+
// Retrieve the tail policy operand index, if any.
2183+
Optional<unsigned> TailPolicyOpIdx;
2184+
const RISCVInstrInfo *TII = static_cast<const RISCVInstrInfo *>(
2185+
CurDAG->getSubtarget().getInstrInfo());
2186+
2187+
const MCInstrDesc &MaskedMCID = TII->get(N->getMachineOpcode());
2188+
2189+
if (RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags)) {
2190+
// The last operand of the pseudo is the policy op, but we're expecting a
2191+
// Glue operand last. We may also have a chain.
2192+
TailPolicyOpIdx = N->getNumOperands() - 1;
2193+
if (N->getOperand(*TailPolicyOpIdx).getValueType() == MVT::Glue)
2194+
(*TailPolicyOpIdx)--;
2195+
if (N->getOperand(*TailPolicyOpIdx).getValueType() == MVT::Other)
2196+
(*TailPolicyOpIdx)--;
2197+
2198+
// If the policy isn't TAIL_AGNOSTIC we can't perform this optimization.
2199+
if (N->getConstantOperandVal(*TailPolicyOpIdx) != RISCVII::TAIL_AGNOSTIC)
2200+
return false;
2201+
}
2202+
2203+
const MCInstrDesc &UnmaskedMCID = TII->get(I->UnmaskedPseudo);
2204+
2205+
// Check that we're dropping the merge operand, the mask operand, and any
2206+
// policy operand when we transform to this unmasked pseudo.
2207+
assert(!RISCVII::hasMergeOp(UnmaskedMCID.TSFlags) &&
2208+
RISCVII::hasDummyMaskOp(UnmaskedMCID.TSFlags) &&
2209+
!RISCVII::hasVecPolicyOp(UnmaskedMCID.TSFlags) &&
2210+
"Unexpected pseudo to transform to");
2211+
2212+
SmallVector<SDValue, 8> Ops;
2213+
// Skip the merge operand at index 0.
2214+
for (unsigned I = 1, E = N->getNumOperands(); I != E; I++) {
2215+
// Skip the mask, the policy, and the Glue.
2216+
SDValue Op = N->getOperand(I);
2217+
if (I == MaskOpIdx || I == TailPolicyOpIdx ||
2218+
Op.getValueType() == MVT::Glue)
2219+
continue;
2220+
Ops.push_back(Op);
2221+
}
2222+
2223+
// Transitively apply any node glued to our new node.
2224+
if (auto *TGlued = Glued->getGluedNode())
2225+
Ops.push_back(SDValue(TGlued, TGlued->getNumValues() - 1));
2226+
2227+
SDNode *Result =
2228+
CurDAG->getMachineNode(I->UnmaskedPseudo, SDLoc(N), N->getVTList(), Ops);
2229+
ReplaceUses(N, Result);
2230+
2231+
return true;
2232+
}
2233+
21362234
// This pass converts a legalized DAG into a RISCV-specific DAG, ready
21372235
// for instruction scheduling.
21382236
FunctionPass *llvm::createRISCVISelDag(RISCVTargetMachine &TM) {

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ class RISCVDAGToDAGISel : public SelectionDAGISel {
117117
private:
118118
bool doPeepholeLoadStoreADDI(SDNode *Node);
119119
bool doPeepholeSExtW(SDNode *Node);
120+
bool doPeepholeMaskedRVV(SDNode *Node);
120121
};
121122

122123
namespace RISCV {
@@ -187,6 +188,12 @@ struct VLX_VSXPseudo {
187188
uint16_t Pseudo;
188189
};
189190

191+
struct RISCVMaskedPseudoInfo {
192+
uint16_t MaskedPseudo;
193+
uint16_t UnmaskedPseudo;
194+
uint8_t MaskOpIdx;
195+
};
196+
190197
#define GET_RISCVVSSEGTable_DECL
191198
#define GET_RISCVVLSEGTable_DECL
192199
#define GET_RISCVVLXSEGTable_DECL
@@ -195,6 +202,7 @@ struct VLX_VSXPseudo {
195202
#define GET_RISCVVSETable_DECL
196203
#define GET_RISCVVLXTable_DECL
197204
#define GET_RISCVVSXTable_DECL
205+
#define GET_RISCVMaskedPseudosTable_DECL
198206
#include "RISCVGenSearchableTables.inc"
199207
} // namespace RISCV
200208

llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,20 @@ def RISCVVIntrinsicsTable : GenericTable {
424424
let PrimaryKeyName = "getRISCVVIntrinsicInfo";
425425
}
426426

427+
class RISCVMaskedPseudo<bits<4> MaskIdx> {
428+
Pseudo MaskedPseudo = !cast<Pseudo>(NAME);
429+
Pseudo UnmaskedPseudo = !cast<Pseudo>(!subst("_MASK", "", NAME));
430+
bits<4> MaskOpIdx = MaskIdx;
431+
}
432+
433+
def RISCVMaskedPseudosTable : GenericTable {
434+
let FilterClass = "RISCVMaskedPseudo";
435+
let CppTypeName = "RISCVMaskedPseudoInfo";
436+
let Fields = ["MaskedPseudo", "UnmaskedPseudo", "MaskOpIdx"];
437+
let PrimaryKey = ["MaskedPseudo"];
438+
let PrimaryKeyName = "getMaskedPseudoInfo";
439+
}
440+
427441
class RISCVVLE<bit M, bit TU, bit Str, bit F, bits<3> S, bits<3> L> {
428442
bits<1> Masked = M;
429443
bits<1> IsTU = TU;
@@ -1639,7 +1653,8 @@ multiclass VPseudoBinary<VReg RetClass,
16391653
def "_" # MInfo.MX : VPseudoBinaryNoMask<RetClass, Op1Class, Op2Class,
16401654
Constraint>;
16411655
def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMaskTA<RetClass, Op1Class, Op2Class,
1642-
Constraint>;
1656+
Constraint>,
1657+
RISCVMaskedPseudo</*MaskOpIdx*/ 3>;
16431658
}
16441659
}
16451660

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -309,15 +309,6 @@ multiclass VPatBinaryVL_V<SDNode vop,
309309
LMULInfo vlmul,
310310
VReg op1_reg_class,
311311
VReg op2_reg_class> {
312-
def : Pat<(result_type (vop
313-
(op1_type op1_reg_class:$rs1),
314-
(op2_type op2_reg_class:$rs2),
315-
(mask_type true_mask),
316-
VLOpFrag)),
317-
(!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX)
318-
op1_reg_class:$rs1,
319-
op2_reg_class:$rs2,
320-
GPR:$vl, sew)>;
321312
def : Pat<(result_type (vop
322313
(op1_type op1_reg_class:$rs1),
323314
(op2_type op2_reg_class:$rs2),
@@ -342,15 +333,6 @@ multiclass VPatBinaryVL_XI<SDNode vop,
342333
VReg vop_reg_class,
343334
ComplexPattern SplatPatKind,
344335
DAGOperand xop_kind> {
345-
def : Pat<(result_type (vop
346-
(vop1_type vop_reg_class:$rs1),
347-
(vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
348-
(mask_type true_mask),
349-
VLOpFrag)),
350-
(!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX)
351-
vop_reg_class:$rs1,
352-
xop_kind:$rs2,
353-
GPR:$vl, sew)>;
354336
def : Pat<(result_type (vop
355337
(vop1_type vop_reg_class:$rs1),
356338
(vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
@@ -422,14 +404,6 @@ multiclass VPatBinaryVL_VF<SDNode vop,
422404
LMULInfo vlmul,
423405
VReg vop_reg_class,
424406
RegisterClass scalar_reg_class> {
425-
def : Pat<(result_type (vop (vop_type vop_reg_class:$rs1),
426-
(vop_type (SplatFPOp scalar_reg_class:$rs2)),
427-
(mask_type true_mask),
428-
VLOpFrag)),
429-
(!cast<Instruction>(instruction_name#"_"#vlmul.MX)
430-
vop_reg_class:$rs1,
431-
scalar_reg_class:$rs2,
432-
GPR:$vl, sew)>;
433407
def : Pat<(result_type (vop (vop_type vop_reg_class:$rs1),
434408
(vop_type (SplatFPOp scalar_reg_class:$rs2)),
435409
(mask_type V0),
@@ -454,13 +428,6 @@ multiclass VPatBinaryFPVL_VV_VF<SDNode vop, string instruction_name> {
454428

455429
multiclass VPatBinaryFPVL_R_VF<SDNode vop, string instruction_name> {
456430
foreach fvti = AllFloatVectors in {
457-
def : Pat<(fvti.Vector (vop (SplatFPOp fvti.ScalarRegClass:$rs2),
458-
fvti.RegClass:$rs1,
459-
(fvti.Mask true_mask),
460-
VLOpFrag)),
461-
(!cast<Instruction>(instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
462-
fvti.RegClass:$rs1, fvti.ScalarRegClass:$rs2,
463-
GPR:$vl, fvti.Log2SEW)>;
464431
def : Pat<(fvti.Vector (vop (SplatFPOp fvti.ScalarRegClass:$rs2),
465432
fvti.RegClass:$rs1,
466433
(fvti.Mask V0),
@@ -747,22 +714,12 @@ defm : VPatBinaryVL_VV_VX<riscv_sub_vl, "PseudoVSUB">;
747714
// Handle VRSUB specially since it's the only integer binary op with reversed
748715
// pattern operands
749716
foreach vti = AllIntegerVectors in {
750-
def : Pat<(riscv_sub_vl (vti.Vector (SplatPat (XLenVT GPR:$rs2))),
751-
(vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask),
752-
VLOpFrag),
753-
(!cast<Instruction>("PseudoVRSUB_VX_"# vti.LMul.MX)
754-
vti.RegClass:$rs1, GPR:$rs2, GPR:$vl, vti.Log2SEW)>;
755717
def : Pat<(riscv_sub_vl (vti.Vector (SplatPat (XLenVT GPR:$rs2))),
756718
(vti.Vector vti.RegClass:$rs1), (vti.Mask V0),
757719
VLOpFrag),
758720
(!cast<Instruction>("PseudoVRSUB_VX_"# vti.LMul.MX#"_MASK")
759721
(vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, GPR:$rs2,
760722
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
761-
def : Pat<(riscv_sub_vl (vti.Vector (SplatPat_simm5 simm5:$rs2)),
762-
(vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask),
763-
VLOpFrag),
764-
(!cast<Instruction>("PseudoVRSUB_VI_"# vti.LMul.MX)
765-
vti.RegClass:$rs1, simm5:$rs2, GPR:$vl, vti.Log2SEW)>;
766723
def : Pat<(riscv_sub_vl (vti.Vector (SplatPat_simm5 simm5:$rs2)),
767724
(vti.Vector vti.RegClass:$rs1), (vti.Mask V0),
768725
VLOpFrag),

0 commit comments

Comments
 (0)