@@ -7193,15 +7193,19 @@ static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl,
71937193}
71947194
71947195// Recurse to find a MemSDNode source and the accumulated ByteOffset.
7196- static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
7197- if (ISD::isNON_EXTLoad(Elt.getNode())) {
7198- auto *BaseLd = cast<LoadSDNode>(Elt);
7199- if (!BaseLd->isSimple())
7200- return false;
7196+ static bool findEltLoadSrc(SDValue Elt, MemSDNode *&Ld, int64_t &ByteOffset) {
7197+ if (auto *BaseLd = dyn_cast<AtomicSDNode>(Elt)) {
72017198 Ld = BaseLd;
72027199 ByteOffset = 0;
72037200 return true;
7204- }
7201+ } else if (auto *BaseLd = dyn_cast<LoadSDNode>(Elt))
7202+ if (ISD::isNON_EXTLoad(Elt.getNode())) {
7203+ if (!BaseLd->isSimple())
7204+ return false;
7205+ Ld = BaseLd;
7206+ ByteOffset = 0;
7207+ return true;
7208+ }
72057209
72067210 switch (Elt.getOpcode()) {
72077211 case ISD::BITCAST:
@@ -7254,7 +7258,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
72547258 APInt ZeroMask = APInt::getZero(NumElems);
72557259 APInt UndefMask = APInt::getZero(NumElems);
72567260
7257- SmallVector<LoadSDNode *, 8> Loads(NumElems, nullptr);
7261+ SmallVector<MemSDNode *, 8> Loads(NumElems, nullptr);
72587262 SmallVector<int64_t, 8> ByteOffsets(NumElems, 0);
72597263
72607264 // For each element in the initializer, see if we've found a load, zero or an
@@ -7304,7 +7308,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
73047308 EVT EltBaseVT = EltBase.getValueType();
73057309 assert(EltBaseVT.getSizeInBits() == EltBaseVT.getStoreSizeInBits() &&
73067310 "Register/Memory size mismatch");
7307- LoadSDNode *LDBase = Loads[FirstLoadedElt];
7311+ MemSDNode *LDBase = Loads[FirstLoadedElt];
73087312 assert(LDBase && "Did not find base load for merging consecutive loads");
73097313 unsigned BaseSizeInBits = EltBaseVT.getStoreSizeInBits();
73107314 unsigned BaseSizeInBytes = BaseSizeInBits / 8;
@@ -7318,15 +7322,18 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
73187322
73197323 // Check to see if the element's load is consecutive to the base load
73207324 // or offset from a previous (already checked) load.
7321- auto CheckConsecutiveLoad = [&](LoadSDNode *Base, int EltIdx) {
7322- LoadSDNode *Ld = Loads[EltIdx];
7325+ auto CheckConsecutiveLoad = [&](MemSDNode *Base, int EltIdx) {
7326+ MemSDNode *Ld = Loads[EltIdx];
73237327 int64_t ByteOffset = ByteOffsets[EltIdx];
73247328 if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
73257329 int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
73267330 return (0 <= BaseIdx && BaseIdx < (int)NumElems && LoadMask[BaseIdx] &&
73277331 Loads[BaseIdx] == Ld && ByteOffsets[BaseIdx] == 0);
73287332 }
7329- return DAG.areNonVolatileConsecutiveLoads(Ld, Base, BaseSizeInBytes,
7333+ auto *L = dyn_cast<LoadSDNode>(Ld);
7334+ auto *B = dyn_cast<LoadSDNode>(Base);
7335+ return L && B &&
7336+ DAG.areNonVolatileConsecutiveLoads(L, B, BaseSizeInBytes,
73307337 EltIdx - FirstLoadedElt);
73317338 };
73327339
@@ -7347,7 +7354,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
73477354 }
73487355 }
73497356
7350- auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) {
7357+ auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, MemSDNode *LDBase) {
73517358 auto MMOFlags = LDBase->getMemOperand()->getFlags();
73527359 assert(LDBase->isSimple() &&
73537360 "Cannot merge volatile or atomic loads.");
@@ -60539,6 +60546,35 @@ static SDValue combineINTRINSIC_VOID(SDNode *N, SelectionDAG &DAG,
6053960546 return SDValue();
6054060547}
6054160548
60549+ static SDValue combineVZEXT_LOAD(SDNode *N, SelectionDAG &DAG,
60550+ TargetLowering::DAGCombinerInfo &DCI) {
60551+ // Find the TokenFactor to locate the associated AtomicLoad.
60552+ SDNode *ALD = nullptr;
60553+ for (auto &TF : N->uses())
60554+ if (TF.getUser()->getOpcode() == ISD::TokenFactor) {
60555+ SDValue L = TF.getUser()->getOperand(0);
60556+ SDValue R = TF.getUser()->getOperand(1);
60557+ if (L.getNode() == N)
60558+ ALD = R.getNode();
60559+ else if (R.getNode() == N)
60560+ ALD = L.getNode();
60561+ }
60562+
60563+ if (!ALD)
60564+ return SDValue();
60565+ if (!isa<AtomicSDNode>(ALD))
60566+ return SDValue();
60567+
60568+ // Replace the VZEXT_LOAD with the AtomicLoad.
60569+ SDLoc dl(N);
60570+ SDValue SV =
60571+ DAG.getNode(ISD::SCALAR_TO_VECTOR, dl,
60572+ N->getValueType(0).changeTypeToInteger(), SDValue(ALD, 0));
60573+ SDValue BC = DAG.getNode(ISD::BITCAST, dl, N->getValueType(0), SV);
60574+ BC = DCI.CombineTo(N, BC, SDValue(ALD, 1));
60575+ return BC;
60576+ }
60577+
6054260578SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
6054360579 DAGCombinerInfo &DCI) const {
6054460580 SelectionDAG &DAG = DCI.DAG;
@@ -60735,6 +60771,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
6073560771 case ISD::INTRINSIC_VOID: return combineINTRINSIC_VOID(N, DAG, DCI);
6073660772 case ISD::FP_TO_SINT_SAT:
6073760773 case ISD::FP_TO_UINT_SAT: return combineFP_TO_xINT_SAT(N, DAG, Subtarget);
60774+ case X86ISD::VZEXT_LOAD: return combineVZEXT_LOAD(N, DAG, DCI);
6073860775 // clang-format on
6073960776 }
6074060777
0 commit comments