@@ -7193,15 +7193,19 @@ static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl,
71937193}
71947194
71947195// Recurse to find a MemSDNode source and the accumulated ByteOffset.
7196- static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
7197- if (ISD::isNON_EXTLoad(Elt.getNode())) {
7198- auto *BaseLd = cast<LoadSDNode>(Elt);
7199- if (!BaseLd->isSimple())
7200- return false;
7196+ static bool findEltLoadSrc(SDValue Elt, MemSDNode *&Ld, int64_t &ByteOffset) {
7197+ if (auto *BaseLd = dyn_cast<AtomicSDNode>(Elt)) {
72017198 Ld = BaseLd;
72027199 ByteOffset = 0;
72037200 return true;
7204- }
7201+ } else if (auto *BaseLd = dyn_cast<LoadSDNode>(Elt))
7202+ if (ISD::isNON_EXTLoad(Elt.getNode())) {
7203+ if (!BaseLd->isSimple())
7204+ return false;
7205+ Ld = BaseLd;
7206+ ByteOffset = 0;
7207+ return true;
7208+ }
72057209
72067210 switch (Elt.getOpcode()) {
72077211 case ISD::BITCAST:
@@ -7230,6 +7234,20 @@ static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
72307234 }
72317235 }
72327236 break;
7237+ case ISD::EXTRACT_ELEMENT:
7238+ if (auto *IdxC = dyn_cast<ConstantSDNode>(Elt.getOperand(1))) {
7239+ SDValue Src = Elt.getOperand(0);
7240+ unsigned SrcSizeInBits = Src.getScalarValueSizeInBits();
7241+ unsigned DstSizeInBits = Elt.getScalarValueSizeInBits();
7242+ if (2 * DstSizeInBits == SrcSizeInBits && (SrcSizeInBits % 8) == 0 &&
7243+ findEltLoadSrc(Src, Ld, ByteOffset)) {
7244+ uint64_t Idx = IdxC->getZExtValue();
7245+ if (Idx == 1) // Get the upper half.
7246+ ByteOffset += SrcSizeInBits / 8 / 2;
7247+ return true;
7248+ }
7249+ }
7250+ break;
72337251 }
72347252
72357253 return false;
@@ -7254,7 +7272,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
72547272 APInt ZeroMask = APInt::getZero(NumElems);
72557273 APInt UndefMask = APInt::getZero(NumElems);
72567274
7257- SmallVector<LoadSDNode *, 8> Loads(NumElems, nullptr);
7275+ SmallVector<MemSDNode *, 8> Loads(NumElems, nullptr);
72587276 SmallVector<int64_t, 8> ByteOffsets(NumElems, 0);
72597277
72607278 // For each element in the initializer, see if we've found a load, zero or an
@@ -7304,7 +7322,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
73047322 EVT EltBaseVT = EltBase.getValueType();
73057323 assert(EltBaseVT.getSizeInBits() == EltBaseVT.getStoreSizeInBits() &&
73067324 "Register/Memory size mismatch");
7307- LoadSDNode *LDBase = Loads[FirstLoadedElt];
7325+ MemSDNode *LDBase = Loads[FirstLoadedElt];
73087326 assert(LDBase && "Did not find base load for merging consecutive loads");
73097327 unsigned BaseSizeInBits = EltBaseVT.getStoreSizeInBits();
73107328 unsigned BaseSizeInBytes = BaseSizeInBits / 8;
@@ -7318,8 +7336,8 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
73187336
73197337 // Check to see if the element's load is consecutive to the base load
73207338 // or offset from a previous (already checked) load.
7321- auto CheckConsecutiveLoad = [&](LoadSDNode *Base, int EltIdx) {
7322- LoadSDNode *Ld = Loads[EltIdx];
7339+ auto CheckConsecutiveLoad = [&](MemSDNode *Base, int EltIdx) {
7340+ MemSDNode *Ld = Loads[EltIdx];
73237341 int64_t ByteOffset = ByteOffsets[EltIdx];
73247342 if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
73257343 int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
@@ -7347,7 +7365,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
73477365 }
73487366 }
73497367
7350- auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) {
7368+ auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, MemSDNode *LDBase) {
73517369 auto MMOFlags = LDBase->getMemOperand()->getFlags();
73527370 assert(LDBase->isSimple() &&
73537371 "Cannot merge volatile or atomic loads.");
@@ -9452,8 +9470,9 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
94529470 {
94539471 SmallVector<SDValue, 64> Ops(Op->ops().take_front(NumElems));
94549472 if (SDValue LD =
9455- EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false))
9473+ EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false)) {
94569474 return LD;
9475+ }
94579476 }
94589477
94599478 // If this is a splat of pairs of 32-bit elements, we can use a narrower
@@ -60381,6 +60400,35 @@ static SDValue combineINTRINSIC_VOID(SDNode *N, SelectionDAG &DAG,
6038160400 return SDValue();
6038260401}
6038360402
60403+ static SDValue combineVZEXT_LOAD(SDNode *N, SelectionDAG &DAG,
60404+ TargetLowering::DAGCombinerInfo &DCI) {
60405+ // Find the TokenFactor to locate the associated AtomicLoad.
60406+ SDNode *ALD = nullptr;
60407+ for (auto &TF : DAG.allnodes())
60408+ if (TF.getOpcode() == ISD::TokenFactor) {
60409+ SDValue L = TF.getOperand(0);
60410+ SDValue R = TF.getOperand(1);
60411+ if (L.getNode() == N)
60412+ ALD = R.getNode();
60413+ else if (R.getNode() == N)
60414+ ALD = L.getNode();
60415+ }
60416+
60417+ if (!ALD)
60418+ return SDValue();
60419+ if (!isa<AtomicSDNode>(ALD))
60420+ return SDValue();
60421+
60422+ // Replace the VZEXT_LOAD with the AtomicLoad.
60423+ SDLoc dl(N);
60424+ SDValue SV =
60425+ DAG.getNode(ISD::SCALAR_TO_VECTOR, dl,
60426+ N->getValueType(0).changeTypeToInteger(), SDValue(ALD, 0));
60427+ SDValue BC = DAG.getNode(ISD::BITCAST, dl, N->getValueType(0), SV);
60428+ BC = DCI.CombineTo(N, BC, SDValue(ALD, 1));
60429+ return BC;
60430+ }
60431+
6038460432SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
6038560433 DAGCombinerInfo &DCI) const {
6038660434 SelectionDAG &DAG = DCI.DAG;
@@ -60577,6 +60625,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
6057760625 case ISD::INTRINSIC_VOID: return combineINTRINSIC_VOID(N, DAG, DCI);
6057860626 case ISD::FP_TO_SINT_SAT:
6057960627 case ISD::FP_TO_UINT_SAT: return combineFP_TO_xINT_SAT(N, DAG, Subtarget);
60628+ case X86ISD::VZEXT_LOAD: return combineVZEXT_LOAD(N, DAG, DCI);
6058060629 // clang-format on
6058160630 }
6058260631
0 commit comments