@@ -7061,14 +7061,23 @@ static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl,
70617061}
70627062
70637063// Recurse to find a LoadSDNode source and the accumulated ByteOffset.
7064- static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
7065- if (ISD::isNON_EXTLoad(Elt.getNode())) {
7066- auto *BaseLd = cast<LoadSDNode>(Elt);
7067- if (!BaseLd->isSimple())
7068- return false;
7069- Ld = BaseLd;
7070- ByteOffset = 0;
7071- return true;
7064+ template <typename T>
7065+ static bool findEltLoadSrc(SDValue Elt, T *&Ld, int64_t &ByteOffset) {
7066+ if constexpr (std::is_same_v<T, AtomicSDNode>) {
7067+ if (auto *BaseLd = dyn_cast<AtomicSDNode>(Elt)) {
7068+ Ld = BaseLd;
7069+ ByteOffset = 0;
7070+ return true;
7071+ }
7072+ } else if constexpr (std::is_same_v<T, LoadSDNode>) {
7073+ if (ISD::isNON_EXTLoad(Elt.getNode())) {
7074+ auto *BaseLd = cast<LoadSDNode>(Elt);
7075+ if (!BaseLd->isSimple())
7076+ return false;
7077+ Ld = BaseLd;
7078+ ByteOffset = 0;
7079+ return true;
7080+ }
70727081 }
70737082
70747083 switch (Elt.getOpcode()) {
@@ -7108,6 +7117,7 @@ static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
71087117/// a build_vector or insert_subvector whose loaded operands are 'Elts'.
71097118///
71107119/// Example: <load i32 *a, load i32 *a+4, zero, undef> -> zextload a
7120+ template <typename T>
71117121static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71127122 const SDLoc &DL, SelectionDAG &DAG,
71137123 const X86Subtarget &Subtarget,
@@ -7122,7 +7132,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71227132 APInt ZeroMask = APInt::getZero(NumElems);
71237133 APInt UndefMask = APInt::getZero(NumElems);
71247134
7125- SmallVector<LoadSDNode *, 8> Loads(NumElems, nullptr);
7135+ SmallVector<T *, 8> Loads(NumElems, nullptr);
71267136 SmallVector<int64_t, 8> ByteOffsets(NumElems, 0);
71277137
71287138 // For each element in the initializer, see if we've found a load, zero or an
@@ -7172,7 +7182,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71727182 EVT EltBaseVT = EltBase.getValueType();
71737183 assert(EltBaseVT.getSizeInBits() == EltBaseVT.getStoreSizeInBits() &&
71747184 "Register/Memory size mismatch");
7175- LoadSDNode *LDBase = Loads[FirstLoadedElt];
7185+ T *LDBase = Loads[FirstLoadedElt];
71767186 assert(LDBase && "Did not find base load for merging consecutive loads");
71777187 unsigned BaseSizeInBits = EltBaseVT.getStoreSizeInBits();
71787188 unsigned BaseSizeInBytes = BaseSizeInBits / 8;
@@ -7186,8 +7196,8 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71867196
71877197 // Check to see if the element's load is consecutive to the base load
71887198 // or offset from a previous (already checked) load.
7189- auto CheckConsecutiveLoad = [&](LoadSDNode *Base, int EltIdx) {
7190- LoadSDNode *Ld = Loads[EltIdx];
7199+ auto CheckConsecutiveLoad = [&](T *Base, int EltIdx) {
7200+ T *Ld = Loads[EltIdx];
71917201 int64_t ByteOffset = ByteOffsets[EltIdx];
71927202 if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
71937203 int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
@@ -7215,7 +7225,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
72157225 }
72167226 }
72177227
7218- auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) {
7228+ auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, T *LDBase) {
72197229 auto MMOFlags = LDBase->getMemOperand()->getFlags();
72207230 assert(LDBase->isSimple() &&
72217231 "Cannot merge volatile or atomic loads.");
@@ -7285,7 +7295,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
72857295 EVT HalfVT =
72867296 EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(), HalfNumElems);
72877297 SDValue HalfLD =
7288- EltsFromConsecutiveLoads(HalfVT, Elts.drop_back(HalfNumElems), DL,
7298+ EltsFromConsecutiveLoads<T> (HalfVT, Elts.drop_back(HalfNumElems), DL,
72897299 DAG, Subtarget, IsAfterLegalize);
72907300 if (HalfLD)
72917301 return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT),
@@ -7362,7 +7372,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
73627372 EVT::getVectorVT(*DAG.getContext(), RepeatVT.getScalarType(),
73637373 VT.getSizeInBits() / ScalarSize);
73647374 if (TLI.isTypeLegal(BroadcastVT)) {
7365- if (SDValue RepeatLoad = EltsFromConsecutiveLoads(
7375+ if (SDValue RepeatLoad = EltsFromConsecutiveLoads<T> (
73667376 RepeatVT, RepeatedLoads, DL, DAG, Subtarget, IsAfterLegalize)) {
73677377 SDValue Broadcast = RepeatLoad;
73687378 if (RepeatSize > ScalarSize) {
@@ -7403,7 +7413,7 @@ static SDValue combineToConsecutiveLoads(EVT VT, SDValue Op, const SDLoc &DL,
74037413 return SDValue();
74047414 }
74057415 assert(Elts.size() == VT.getVectorNumElements());
7406- return EltsFromConsecutiveLoads(VT, Elts, DL, DAG, Subtarget,
7416+ return EltsFromConsecutiveLoads<LoadSDNode> (VT, Elts, DL, DAG, Subtarget,
74077417 IsAfterLegalize);
74087418}
74097419
@@ -9258,8 +9268,12 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
92589268 {
92599269 SmallVector<SDValue, 64> Ops(Op->op_begin(), Op->op_begin() + NumElems);
92609270 if (SDValue LD =
9261- EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false))
9271+ EltsFromConsecutiveLoads<LoadSDNode> (VT, Ops, dl, DAG, Subtarget, false)) {
92629272 return LD;
9273+ } else if (SDValue LD =
9274+ EltsFromConsecutiveLoads<AtomicSDNode>(VT, Ops, dl, DAG, Subtarget, false)) {
9275+ return LD;
9276+ }
92639277 }
92649278
92659279 // If this is a splat of pairs of 32-bit elements, we can use a narrower
@@ -57979,7 +57993,7 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
5797957993 *FirstLd->getMemOperand(), &Fast) &&
5798057994 Fast) {
5798157995 if (SDValue Ld =
57982- EltsFromConsecutiveLoads(VT, Ops, DL, DAG, Subtarget, false))
57996+ EltsFromConsecutiveLoads<LoadSDNode> (VT, Ops, DL, DAG, Subtarget, false))
5798357997 return Ld;
5798457998 }
5798557999 }
0 commit comments