@@ -1568,13 +1568,20 @@ static SDValue convertFromScalableVector(EVT VT, SDValue V, SelectionDAG &DAG,
15681568 return DAG.getNode (ISD::EXTRACT_SUBVECTOR, DL, VT, V, Zero);
15691569}
15701570
1571+ // / Return the type of the mask type suitable for masking the provided
1572+ // / vector type. This is simply an i1 element type vector of the same
1573+ // / (possibly scalable) length.
1574+ static MVT getMaskTypeFor (EVT VecVT) {
1575+ assert (VecVT.isVector ());
1576+ ElementCount EC = VecVT.getVectorElementCount ();
1577+ return MVT::getVectorVT (MVT::i1, EC);
1578+ }
1579+
15711580// / Creates an all ones mask suitable for masking a vector of type VecTy with
15721581// / vector length VL. .
15731582static SDValue getAllOnesMask (MVT VecVT, SDValue VL, SDLoc DL,
15741583 SelectionDAG &DAG) {
1575- assert (VecVT.isVector ());
1576- ElementCount EC = VecVT.getVectorElementCount ();
1577- MVT MaskVT = MVT::getVectorVT (MVT::i1, EC);
1584+ MVT MaskVT = getMaskTypeFor (VecVT);
15781585 return DAG.getNode (RISCVISD::VMSET_VL, DL, MaskVT, VL);
15791586}
15801587
@@ -4237,8 +4244,7 @@ SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op,
42374244 ContainerVT = getContainerForFixedLengthVector (SrcVT);
42384245 Src = convertToScalableVector (ContainerVT, Src, DAG, Subtarget);
42394246 if (IsVPTrunc) {
4240- MVT MaskVT =
4241- MVT::getVectorVT (MVT::i1, ContainerVT.getVectorElementCount ());
4247+ MVT MaskVT = getMaskTypeFor (ContainerVT);
42424248 Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
42434249 }
42444250 }
@@ -4298,8 +4304,7 @@ SDValue RISCVTargetLowering::lowerVectorFPRoundLike(SDValue Op,
42984304 SrcContainerVT.changeVectorElementType (VT.getVectorElementType ());
42994305 Src = convertToScalableVector (SrcContainerVT, Src, DAG, Subtarget);
43004306 if (IsVPFPTrunc) {
4301- MVT MaskVT =
4302- MVT::getVectorVT (MVT::i1, ContainerVT.getVectorElementCount ());
4307+ MVT MaskVT = getMaskTypeFor (ContainerVT);
43034308 Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
43044309 }
43054310 }
@@ -4807,7 +4812,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
48074812 DAG.getNode (RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF (VT),
48084813 DAG.getConstant (0 , DL, MVT::i32 ), VL);
48094814
4810- MVT MaskVT = MVT::getVectorVT (MVT::i1, VT. getVectorElementCount () );
4815+ MVT MaskVT = getMaskTypeFor (VT );
48114816 SDValue Mask = getAllOnesMask (VT, VL, DL, DAG);
48124817 SDValue VID = DAG.getNode (RISCVISD::VID_VL, DL, VT, Mask, VL);
48134818 SDValue SelectCond =
@@ -4841,8 +4846,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
48414846
48424847 SDValue PassThru = Op.getOperand (2 );
48434848 if (!IsUnmasked) {
4844- MVT MaskVT =
4845- MVT::getVectorVT (MVT::i1, ContainerVT.getVectorElementCount ());
4849+ MVT MaskVT = getMaskTypeFor (ContainerVT);
48464850 Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
48474851 PassThru = convertToScalableVector (ContainerVT, PassThru, DAG, Subtarget);
48484852 }
@@ -4939,8 +4943,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
49394943
49404944 Val = convertToScalableVector (ContainerVT, Val, DAG, Subtarget);
49414945 if (!IsUnmasked) {
4942- MVT MaskVT =
4943- MVT::getVectorVT (MVT::i1, ContainerVT.getVectorElementCount ());
4946+ MVT MaskVT = getMaskTypeFor (ContainerVT);
49444947 Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
49454948 }
49464949
@@ -5791,8 +5794,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
57915794 ContainerVT = getContainerForFixedLengthVector (VT);
57925795 PassThru = convertToScalableVector (ContainerVT, PassThru, DAG, Subtarget);
57935796 if (!IsUnmasked) {
5794- MVT MaskVT =
5795- MVT::getVectorVT (MVT::i1, ContainerVT.getVectorElementCount ());
5797+ MVT MaskVT = getMaskTypeFor (ContainerVT);
57965798 Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
57975799 }
57985800 }
@@ -5858,8 +5860,7 @@ SDValue RISCVTargetLowering::lowerMaskedStore(SDValue Op,
58585860
58595861 Val = convertToScalableVector (ContainerVT, Val, DAG, Subtarget);
58605862 if (!IsUnmasked) {
5861- MVT MaskVT =
5862- MVT::getVectorVT (MVT::i1, ContainerVT.getVectorElementCount ());
5863+ MVT MaskVT = getMaskTypeFor (ContainerVT);
58635864 Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
58645865 }
58655866 }
@@ -5897,7 +5898,7 @@ RISCVTargetLowering::lowerFixedLengthVectorSetccToRVV(SDValue Op,
58975898 SDValue VL =
58985899 DAG.getConstant (VT.getVectorNumElements (), DL, Subtarget.getXLenVT ());
58995900
5900- MVT MaskVT = MVT::getVectorVT (MVT::i1, ContainerVT. getVectorElementCount () );
5901+ MVT MaskVT = getMaskTypeFor ( ContainerVT);
59015902 SDValue Mask = getAllOnesMask (ContainerVT, VL, DL, DAG);
59025903
59035904 SDValue Cmp = DAG.getNode (RISCVISD::SETCC_VL, DL, MaskVT, Op1, Op2,
@@ -6200,7 +6201,7 @@ SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG,
62006201 DstVT = getContainerForFixedLengthVector (DstVT);
62016202 SrcVT = getContainerForFixedLengthVector (SrcVT);
62026203 Src = convertToScalableVector (SrcVT, Src, DAG, Subtarget);
6203- MVT MaskVT = MVT::getVectorVT (MVT::i1, DstVT. getVectorElementCount () );
6204+ MVT MaskVT = getMaskTypeFor ( DstVT);
62046205 Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
62056206 }
62066207
@@ -6413,8 +6414,7 @@ SDValue RISCVTargetLowering::lowerMaskedGather(SDValue Op,
64136414 Index = convertToScalableVector (IndexVT, Index, DAG, Subtarget);
64146415
64156416 if (!IsUnmasked) {
6416- MVT MaskVT =
6417- MVT::getVectorVT (MVT::i1, ContainerVT.getVectorElementCount ());
6417+ MVT MaskVT = getMaskTypeFor (ContainerVT);
64186418 Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
64196419 PassThru = convertToScalableVector (ContainerVT, PassThru, DAG, Subtarget);
64206420 }
@@ -6525,8 +6525,7 @@ SDValue RISCVTargetLowering::lowerMaskedScatter(SDValue Op,
65256525 Val = convertToScalableVector (ContainerVT, Val, DAG, Subtarget);
65266526
65276527 if (!IsUnmasked) {
6528- MVT MaskVT =
6529- MVT::getVectorVT (MVT::i1, ContainerVT.getVectorElementCount ());
6528+ MVT MaskVT = getMaskTypeFor (ContainerVT);
65306529 Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
65316530 }
65326531 }
@@ -8813,7 +8812,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
88138812 // The memory VT and the element type must match.
88148813 if (VecVT.getVectorElementType () == MemVT) {
88158814 SDLoc DL (N);
8816- MVT MaskVT = MVT::getVectorVT (MVT::i1, VecVT. getVectorElementCount () );
8815+ MVT MaskVT = getMaskTypeFor ( VecVT);
88178816 return DAG.getStoreVP (
88188817 Store->getChain (), DL, Src, Store->getBasePtr (), Store->getOffset (),
88198818 DAG.getConstant (1 , DL, MaskVT),
0 commit comments