@@ -7050,14 +7050,23 @@ static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl,
70507050}
70517051
70527052// Recurse to find a LoadSDNode source and the accumulated ByteOffest.
7053- static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
7054- if (ISD::isNON_EXTLoad(Elt.getNode())) {
7055- auto *BaseLd = cast<LoadSDNode>(Elt);
7056- if (!BaseLd->isSimple())
7057- return false;
7058- Ld = BaseLd;
7059- ByteOffset = 0;
7060- return true;
7053+ template <typename T>
7054+ static bool findEltLoadSrc(SDValue Elt, T *&Ld, int64_t &ByteOffset) {
7055+ if constexpr (std::is_same_v<T, AtomicSDNode>) {
7056+ if (auto *BaseLd = dyn_cast<AtomicSDNode>(Elt); BaseLd && BaseLd->getOpcode() == ISD::ATOMIC_LOAD) {
7057+ Ld = BaseLd;
7058+ ByteOffset = 0;
7059+ return true;
7060+ }
7061+ } else if constexpr (std::is_same_v<T, LoadSDNode>) {
7062+ if (ISD::isNON_EXTLoad(Elt.getNode())) {
7063+ auto *BaseLd = cast<LoadSDNode>(Elt);
7064+ if (!BaseLd->isSimple())
7065+ return false;
7066+ Ld = BaseLd;
7067+ ByteOffset = 0;
7068+ return true;
7069+ }
70617070 }
70627071
70637072 switch (Elt.getOpcode()) {
@@ -7097,6 +7106,7 @@ static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
70977106/// a build_vector or insert_subvector whose loaded operands are 'Elts'.
70987107///
70997108/// Example: <load i32 *a, load i32 *a+4, zero, undef> -> zextload a
7109+ template <typename T>
71007110static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71017111 const SDLoc &DL, SelectionDAG &DAG,
71027112 const X86Subtarget &Subtarget,
@@ -7111,7 +7121,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71117121 APInt ZeroMask = APInt::getZero(NumElems);
71127122 APInt UndefMask = APInt::getZero(NumElems);
71137123
7114- SmallVector<LoadSDNode *, 8> Loads(NumElems, nullptr);
7124+ SmallVector<T *, 8> Loads(NumElems, nullptr);
71157125 SmallVector<int64_t, 8> ByteOffsets(NumElems, 0);
71167126
71177127 // For each element in the initializer, see if we've found a load, zero or an
@@ -7161,7 +7171,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71617171 EVT EltBaseVT = EltBase.getValueType();
71627172 assert(EltBaseVT.getSizeInBits() == EltBaseVT.getStoreSizeInBits() &&
71637173 "Register/Memory size mismatch");
7164- LoadSDNode *LDBase = Loads[FirstLoadedElt];
7174+ T *LDBase = Loads[FirstLoadedElt];
71657175 assert(LDBase && "Did not find base load for merging consecutive loads");
71667176 unsigned BaseSizeInBits = EltBaseVT.getStoreSizeInBits();
71677177 unsigned BaseSizeInBytes = BaseSizeInBits / 8;
@@ -7175,8 +7185,8 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71757185
71767186 // Check to see if the element's load is consecutive to the base load
71777187 // or offset from a previous (already checked) load.
7178- auto CheckConsecutiveLoad = [&](LoadSDNode *Base, int EltIdx) {
7179- LoadSDNode *Ld = Loads[EltIdx];
7188+ auto CheckConsecutiveLoad = [&](T *Base, int EltIdx) {
7189+ T *Ld = Loads[EltIdx];
71807190 int64_t ByteOffset = ByteOffsets[EltIdx];
71817191 if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
71827192 int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
@@ -7204,7 +7214,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
72047214 }
72057215 }
72067216
7207- auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) {
7217+ auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, T *LDBase) {
72087218 auto MMOFlags = LDBase->getMemOperand()->getFlags();
72097219 assert(LDBase->isSimple() &&
72107220 "Cannot merge volatile or atomic loads.");
@@ -7274,7 +7284,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
72747284 EVT HalfVT =
72757285 EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(), HalfNumElems);
72767286 SDValue HalfLD =
7277- EltsFromConsecutiveLoads(HalfVT, Elts.drop_back(HalfNumElems), DL,
7287+ EltsFromConsecutiveLoads<T>(HalfVT, Elts.drop_back(HalfNumElems), DL,
72787288 DAG, Subtarget, IsAfterLegalize);
72797289 if (HalfLD)
72807290 return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT),
@@ -7351,7 +7361,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
73517361 EVT::getVectorVT(*DAG.getContext(), RepeatVT.getScalarType(),
73527362 VT.getSizeInBits() / ScalarSize);
73537363 if (TLI.isTypeLegal(BroadcastVT)) {
7354- if (SDValue RepeatLoad = EltsFromConsecutiveLoads(
7364+ if (SDValue RepeatLoad = EltsFromConsecutiveLoads<T>(
73557365 RepeatVT, RepeatedLoads, DL, DAG, Subtarget, IsAfterLegalize)) {
73567366 SDValue Broadcast = RepeatLoad;
73577367 if (RepeatSize > ScalarSize) {
@@ -7392,7 +7402,7 @@ static SDValue combineToConsecutiveLoads(EVT VT, SDValue Op, const SDLoc &DL,
73927402 return SDValue();
73937403 }
73947404 assert(Elts.size() == VT.getVectorNumElements());
7395- return EltsFromConsecutiveLoads(VT, Elts, DL, DAG, Subtarget,
7405+ return EltsFromConsecutiveLoads<LoadSDNode>(VT, Elts, DL, DAG, Subtarget,
73967406 IsAfterLegalize);
73977407}
73987408
@@ -9247,8 +9257,12 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
92479257 {
92489258 SmallVector<SDValue, 64> Ops(Op->op_begin(), Op->op_begin() + NumElems);
92499259 if (SDValue LD =
9250- EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false))
9260+ EltsFromConsecutiveLoads<LoadSDNode>(VT, Ops, dl, DAG, Subtarget, false)) {
92519261 return LD;
9262+ } else if (SDValue AtomicLd =
9263+ EltsFromConsecutiveLoads<AtomicSDNode>(VT, Ops, dl, DAG, Subtarget, false)) {
9264+ return AtomicLd;
9265+ }
92529266 }
92539267
92549268 // If this is a splat of pairs of 32-bit elements, we can use a narrower
@@ -57934,7 +57948,7 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
5793457948 *FirstLd->getMemOperand(), &Fast) &&
5793557949 Fast) {
5793657950 if (SDValue Ld =
57937- EltsFromConsecutiveLoads(VT, Ops, DL, DAG, Subtarget, false))
57951+ EltsFromConsecutiveLoads<LoadSDNode>(VT, Ops, DL, DAG, Subtarget, false))
5793857952 return Ld;
5793957953 }
5794057954 }
0 commit comments