@@ -7049,15 +7049,23 @@ static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl,
70497049 return SDValue();
70507050}
70517051
7052- // Recurse to find a LoadSDNode source and the accumulated ByteOffset.
7053- static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
7054- if (ISD::isNON_EXTLoad(Elt.getNode())) {
7055- auto *BaseLd = cast<LoadSDNode>(Elt);
7056- if (!BaseLd->isSimple())
7057- return false;
7058- Ld = BaseLd;
7059- ByteOffset = 0;
7060- return true;
7052+ template <typename T>
7053+ static bool findEltLoadSrc(SDValue Elt, T *&Ld, int64_t &ByteOffset) {
7054+ if constexpr (std::is_same_v<T, AtomicSDNode>) {
7055+ if (auto *BaseLd = dyn_cast<AtomicSDNode>(Elt)) {
7056+ Ld = BaseLd;
7057+ ByteOffset = 0;
7058+ return true;
7059+ }
7060+ } else if constexpr (std::is_same_v<T, LoadSDNode>) {
7061+ if (ISD::isNON_EXTLoad(Elt.getNode())) {
7062+ auto *BaseLd = cast<LoadSDNode>(Elt);
7063+ if (!BaseLd->isSimple())
7064+ return false;
7065+ Ld = BaseLd;
7066+ ByteOffset = 0;
7067+ return true;
7068+ }
70617069 }
70627070
70637071 switch (Elt.getOpcode()) {
@@ -7097,6 +7105,7 @@ static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
70977105/// a build_vector or insert_subvector whose loaded operands are 'Elts'.
70987106///
70997107/// Example: <load i32 *a, load i32 *a+4, zero, undef> -> zextload a
7108+ template <typename T>
71007109static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71017110 const SDLoc &DL, SelectionDAG &DAG,
71027111 const X86Subtarget &Subtarget,
@@ -7111,7 +7120,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71117120 APInt ZeroMask = APInt::getZero(NumElems);
71127121 APInt UndefMask = APInt::getZero(NumElems);
71137122
7114- SmallVector<LoadSDNode *, 8> Loads(NumElems, nullptr);
7123+ SmallVector<T *, 8> Loads(NumElems, nullptr);
71157124 SmallVector<int64_t, 8> ByteOffsets(NumElems, 0);
71167125
71177126 // For each element in the initializer, see if we've found a load, zero or an
@@ -7161,7 +7170,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71617170 EVT EltBaseVT = EltBase.getValueType();
71627171 assert(EltBaseVT.getSizeInBits() == EltBaseVT.getStoreSizeInBits() &&
71637172 "Register/Memory size mismatch");
7164- LoadSDNode *LDBase = Loads[FirstLoadedElt];
7173+ T *LDBase = Loads[FirstLoadedElt];
71657174 assert(LDBase && "Did not find base load for merging consecutive loads");
71667175 unsigned BaseSizeInBits = EltBaseVT.getStoreSizeInBits();
71677176 unsigned BaseSizeInBytes = BaseSizeInBits / 8;
@@ -7175,8 +7184,8 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71757184
71767185 // Check to see if the element's load is consecutive to the base load
71777186 // or offset from a previous (already checked) load.
7178- auto CheckConsecutiveLoad = [&](LoadSDNode *Base, int EltIdx) {
7179- LoadSDNode *Ld = Loads[EltIdx];
7187+ auto CheckConsecutiveLoad = [&](T *Base, int EltIdx) {
7188+ T *Ld = Loads[EltIdx];
71807189 int64_t ByteOffset = ByteOffsets[EltIdx];
71817190 if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
71827191 int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
@@ -7204,7 +7213,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
72047213 }
72057214 }
72067215
7207- auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) {
7216+ auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, T *LDBase) {
72087217 auto MMOFlags = LDBase->getMemOperand()->getFlags();
72097218 assert(LDBase->isSimple() &&
72107219 "Cannot merge volatile or atomic loads.");
@@ -7274,7 +7283,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
72747283 EVT HalfVT =
72757284 EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(), HalfNumElems);
72767285 SDValue HalfLD =
7277- EltsFromConsecutiveLoads(HalfVT, Elts.drop_back(HalfNumElems), DL,
7286+ EltsFromConsecutiveLoads<T> (HalfVT, Elts.drop_back(HalfNumElems), DL,
72787287 DAG, Subtarget, IsAfterLegalize);
72797288 if (HalfLD)
72807289 return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT),
@@ -7351,7 +7360,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
73517360 EVT::getVectorVT(*DAG.getContext(), RepeatVT.getScalarType(),
73527361 VT.getSizeInBits() / ScalarSize);
73537362 if (TLI.isTypeLegal(BroadcastVT)) {
7354- if (SDValue RepeatLoad = EltsFromConsecutiveLoads(
7363+ if (SDValue RepeatLoad = EltsFromConsecutiveLoads<T> (
73557364 RepeatVT, RepeatedLoads, DL, DAG, Subtarget, IsAfterLegalize)) {
73567365 SDValue Broadcast = RepeatLoad;
73577366 if (RepeatSize > ScalarSize) {
@@ -7392,7 +7401,7 @@ static SDValue combineToConsecutiveLoads(EVT VT, SDValue Op, const SDLoc &DL,
73927401 return SDValue();
73937402 }
73947403 assert(Elts.size() == VT.getVectorNumElements());
7395- return EltsFromConsecutiveLoads(VT, Elts, DL, DAG, Subtarget,
7404+ return EltsFromConsecutiveLoads<LoadSDNode> (VT, Elts, DL, DAG, Subtarget,
73967405 IsAfterLegalize);
73977406}
73987407
@@ -9247,8 +9256,12 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
92479256 {
92489257 SmallVector<SDValue, 64> Ops(Op->op_begin(), Op->op_begin() + NumElems);
92499258 if (SDValue LD =
9250- EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false))
9259+ EltsFromConsecutiveLoads<LoadSDNode> (VT, Ops, dl, DAG, Subtarget, false)) {
92519260 return LD;
9261+ } else if (SDValue LD =
9262+ EltsFromConsecutiveLoads<AtomicSDNode>(VT, Ops, dl, DAG, Subtarget, false)) {
9263+ return LD;
9264+ }
92529265 }
92539266
92549267 // If this is a splat of pairs of 32-bit elements, we can use a narrower
@@ -57934,7 +57947,7 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
5793457947 *FirstLd->getMemOperand(), &Fast) &&
5793557948 Fast) {
5793657949 if (SDValue Ld =
57937- EltsFromConsecutiveLoads(VT, Ops, DL, DAG, Subtarget, false))
57950+ EltsFromConsecutiveLoads<LoadSDNode> (VT, Ops, DL, DAG, Subtarget, false))
5793857951 return Ld;
5793957952 }
5794057953 }
0 commit comments