Skip to content

Commit f927be0

Browse files
committed
[RISCV] Extract getAllOnesMask helper [nfc]
1 parent 484fcb9 commit f927be0

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,6 +1568,16 @@ static SDValue convertFromScalableVector(EVT VT, SDValue V, SelectionDAG &DAG,
15681568
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, V, Zero);
15691569
}
15701570

1571+
/// Creates an all ones mask suitable for masking a vector of type VecTy with
1572+
/// vector length VL. .
1573+
static SDValue getAllOnesMask(MVT VecVT, SDValue VL, SDLoc DL,
1574+
SelectionDAG &DAG) {
1575+
assert(VecVT.isVector());
1576+
ElementCount EC = VecVT.getVectorElementCount();
1577+
MVT MaskVT = MVT::getVectorVT(MVT::i1, EC);
1578+
return DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
1579+
}
1580+
15711581
// Gets the two common "VL" operands: an all-ones mask and the vector length.
15721582
// VecVT is a vector type, either fixed-length or scalable, and ContainerVT is
15731583
// the vector type that it is contained in.
@@ -1579,8 +1589,7 @@ getDefaultVLOps(MVT VecVT, MVT ContainerVT, SDLoc DL, SelectionDAG &DAG,
15791589
SDValue VL = VecVT.isFixedLengthVector()
15801590
? DAG.getConstant(VecVT.getVectorNumElements(), DL, XLenVT)
15811591
: DAG.getRegister(RISCV::X0, XLenVT);
1582-
MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
1583-
SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
1592+
SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
15841593
return {Mask, VL};
15851594
}
15861595

@@ -2588,9 +2597,7 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
25882597
V2 = DAG.getFreeze(V2);
25892598

25902599
// Recreate TrueMask using the widened type's element count.
2591-
MVT MaskVT =
2592-
MVT::getVectorVT(MVT::i1, HalfContainerVT.getVectorElementCount());
2593-
TrueMask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
2600+
TrueMask = getAllOnesMask(HalfContainerVT, VL, DL, DAG);
25942601

25952602
// Widen V1 and V2 with 0s and add one copy of V2 to V1.
25962603
SDValue Add = DAG.getNode(RISCVISD::VWADDU_VL, DL, WideIntContainerVT, V1,
@@ -4490,8 +4497,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
44904497
if (!isNullConstant(Idx)) {
44914498
// Use a VL of 1 to avoid processing more elements than we need.
44924499
SDValue VL = DAG.getConstant(1, DL, XLenVT);
4493-
MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
4494-
SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
4500+
SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
44954501
Vec = DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, ContainerVT,
44964502
DAG.getUNDEF(ContainerVT), Vec, Idx, Mask, VL);
44974503
}
@@ -4639,8 +4645,7 @@ static SDValue lowerVectorIntrinsicScalars(SDValue Op, SelectionDAG &DAG,
46394645
DAG.getNode(ISD::SHL, DL, XLenVT, VL, DAG.getConstant(1, DL, XLenVT));
46404646
}
46414647

4642-
MVT I32MaskVT = MVT::getVectorVT(MVT::i1, I32VT.getVectorElementCount());
4643-
SDValue I32Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, I32MaskVT, I32VL);
4648+
SDValue I32Mask = getAllOnesMask(I32VT, I32VL, DL, DAG);
46444649

46454650
// Shift the two scalar parts in using SEW=32 slide1up/slide1down
46464651
// instructions.
@@ -4803,7 +4808,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
48034808
DAG.getConstant(0, DL, MVT::i32), VL);
48044809

48054810
MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorElementCount());
4806-
SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
4811+
SDValue Mask = getAllOnesMask(VT, VL, DL, DAG);
48074812
SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, VT, Mask, VL);
48084813
SDValue SelectCond =
48094814
DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT, VID, SplattedIdx,
@@ -5672,8 +5677,7 @@ SDValue RISCVTargetLowering::lowerVECTOR_SPLICE(SDValue Op,
56725677
DownOffset = DAG.getNode(ISD::SUB, DL, XLenVT, VLMax, UpOffset);
56735678
}
56745679

5675-
MVT MaskVT = MVT::getVectorVT(MVT::i1, VecVT.getVectorElementCount());
5676-
SDValue TrueMask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VLMax);
5680+
SDValue TrueMask = getAllOnesMask(VecVT, VLMax, DL, DAG);
56775681

56785682
SDValue SlideDown =
56795683
DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, VecVT, DAG.getUNDEF(VecVT), V1,
@@ -5894,7 +5898,7 @@ RISCVTargetLowering::lowerFixedLengthVectorSetccToRVV(SDValue Op,
58945898
DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT());
58955899

58965900
MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
5897-
SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
5901+
SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
58985902

58995903
SDValue Cmp = DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT, Op1, Op2,
59005904
Op.getOperand(2), Mask, VL);
@@ -7103,9 +7107,8 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
71037107
MVT XLenVT = Subtarget.getXLenVT();
71047108

71057109
// Use a VL of 1 to avoid processing more elements than we need.
7106-
MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
71077110
SDValue VL = DAG.getConstant(1, DL, XLenVT);
7108-
SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
7111+
SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
71097112

71107113
// Unless the index is known to be 0, we must slide the vector down to get
71117114
// the desired element into index 0.
@@ -7225,8 +7228,8 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
72257228
// To extract the upper XLEN bits of the vector element, shift the first
72267229
// element right by 32 bits and re-extract the lower XLEN bits.
72277230
SDValue VL = DAG.getConstant(1, DL, XLenVT);
7228-
MVT MaskVT = MVT::getVectorVT(MVT::i1, VecVT.getVectorElementCount());
7229-
SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
7231+
SDValue Mask = getAllOnesMask(VecVT, VL, DL, DAG);
7232+
72307233
SDValue ThirtyTwoV =
72317234
DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VecVT, DAG.getUNDEF(VecVT),
72327235
DAG.getConstant(32, DL, XLenVT), VL);

0 commit comments

Comments
 (0)