@@ -10732,33 +10732,44 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
1073210732 SDValue Chain = MemSD->getChain();
1073310733 SDValue BasePtr = MemSD->getBasePtr();
1073410734
10735- SDValue Mask, PassThru, VL;
10735+ SDValue Mask, PassThru, LoadVL;
10736+ bool IsExpandingLoad = false;
1073610737 if (const auto *VPLoad = dyn_cast<VPLoadSDNode>(Op)) {
1073710738 Mask = VPLoad->getMask();
1073810739 PassThru = DAG.getUNDEF(VT);
10739- VL = VPLoad->getVectorLength();
10740+ LoadVL = VPLoad->getVectorLength();
1074010741 } else {
1074110742 const auto *MLoad = cast<MaskedLoadSDNode>(Op);
1074210743 Mask = MLoad->getMask();
1074310744 PassThru = MLoad->getPassThru();
10745+ IsExpandingLoad = MLoad->isExpandingLoad();
1074410746 }
1074510747
10746- bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
10748+ bool IsUnmasked =
10749+ ISD::isConstantSplatVectorAllOnes(Mask.getNode()) || IsExpandingLoad;
1074710750
1074810751 MVT XLenVT = Subtarget.getXLenVT();
1074910752
1075010753 MVT ContainerVT = VT;
1075110754 if (VT.isFixedLengthVector()) {
1075210755 ContainerVT = getContainerForFixedLengthVector(VT);
1075310756 PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
10754- if (!IsUnmasked) {
10757+ if (!IsUnmasked || IsExpandingLoad ) {
1075510758 MVT MaskVT = getMaskTypeFor(ContainerVT);
1075610759 Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
1075710760 }
1075810761 }
1075910762
10760- if (!VL)
10761- VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
10763+ if (!LoadVL)
10764+ LoadVL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
10765+
10766+ SDValue ExpandingVL;
10767+ if (IsExpandingLoad) {
10768+ ExpandingVL = LoadVL;
10769+ LoadVL = DAG.getNode(
10770+ RISCVISD::VCPOP_VL, DL, XLenVT, Mask,
10771+ getAllOnesMask(Mask.getSimpleValueType(), LoadVL, DL, DAG), LoadVL);
10772+ }
1076210773
1076310774 unsigned IntID =
1076410775 IsUnmasked ? Intrinsic::riscv_vle : Intrinsic::riscv_vle_mask;
@@ -10770,7 +10781,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
1077010781 Ops.push_back(BasePtr);
1077110782 if (!IsUnmasked)
1077210783 Ops.push_back(Mask);
10773- Ops.push_back(VL );
10784+ Ops.push_back(LoadVL );
1077410785 if (!IsUnmasked)
1077510786 Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));
1077610787
@@ -10779,6 +10790,18 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
1077910790 SDValue Result =
1078010791 DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, MemVT, MMO);
1078110792 Chain = Result.getValue(1);
10793+ if (IsExpandingLoad) {
10794+ MVT IotaVT = ContainerVT;
10795+ if (ContainerVT.isFloatingPoint())
10796+ IotaVT = ContainerVT.changeVectorElementTypeToInteger();
10797+
10798+ SDValue Iota =
10799+ DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IotaVT,
10800+ DAG.getConstant(Intrinsic::riscv_viota, DL, XLenVT),
10801+ DAG.getUNDEF(IotaVT), Mask, ExpandingVL);
10802+ Result = DAG.getNode(RISCVISD::VRGATHER_VV_VL, DL, ContainerVT, Result,
10803+ Iota, PassThru, Mask, ExpandingVL);
10804+ }
1078210805
1078310806 if (VT.isFixedLengthVector())
1078410807 Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
0 commit comments