@@ -53392,34 +53392,38 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
5339253392static SDValue foldToMaskedStore(StoreSDNode *Store, SelectionDAG &DAG,
5339353393 const SDLoc &Dl,
5339453394 const X86Subtarget &Subtarget) {
53395- using namespace llvm::SDPatternMatch;
53396-
5339753395 if (!Subtarget.hasAVX() && !Subtarget.hasAVX2() && !Subtarget.hasAVX512())
5339853396 return SDValue();
5339953397
53400- if (!Store->isSimple() || Store->isTruncatingStore() )
53398+ if (!Store->isSimple())
5340153399 return SDValue();
5340253400
5340353401 SDValue StoredVal = Store->getValue();
5340453402 SDValue StorePtr = Store->getBasePtr();
5340553403 SDValue StoreOffset = Store->getOffset();
53406- EVT VT = Store->getMemoryVT ();
53404+ EVT VT = StoredVal.getValueType ();
5340753405 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
5340853406
5340953407 if (!TLI.isTypeLegal(VT) || !TLI.isOperationLegalOrCustom(ISD::MSTORE, VT))
5341053408 return SDValue();
5341153409
53412- SDValue Mask, TrueVec, LoadCh;
53413- if (!sd_match(StoredVal,
53414- m_VSelect(m_Value(Mask), m_Value(TrueVec),
53415- m_Load(m_Value(LoadCh), m_Specific(StorePtr),
53416- m_Specific(StoreOffset)))))
53410+ if (StoredVal.getOpcode() != ISD::VSELECT)
5341753411 return SDValue();
5341853412
53419- LoadSDNode *Load = cast<LoadSDNode>(StoredVal.getOperand(2));
53413+ SDValue Mask = StoredVal.getOperand(0);
53414+ SDValue TrueVec = StoredVal.getOperand(1);
53415+ SDValue FalseVec = StoredVal.getOperand(2);
53416+
53417+ LoadSDNode *Load = cast<LoadSDNode>(FalseVec.getNode());
5342053418 if (!Load || !Load->isSimple())
5342153419 return SDValue();
5342253420
53421+ SDValue LoadPtr = Load->getBasePtr();
53422+ SDValue LoadOffset = Load->getOffset();
53423+
53424+ if (StorePtr != LoadPtr || StoreOffset != LoadOffset)
53425+ return SDValue();
53426+
5342353427 auto IsSafeToFold = [](StoreSDNode *Store, LoadSDNode *Load) {
5342453428 std::queue<SDValue> Worklist;
5342553429
@@ -53433,13 +53437,13 @@ static SDValue foldToMaskedStore(StoreSDNode *Store, SelectionDAG &DAG,
5343353437 if (!Node)
5343453438 return false;
5343553439
53436- if (Node == Load)
53437- return true;
53438-
5343953440 if (const auto *MemNode = dyn_cast<MemSDNode>(Node))
5344053441 if (!MemNode->isSimple() || MemNode->writeMem())
5344153442 return false;
5344253443
53444+ if (Node == Load)
53445+ return true;
53446+
5344353447 if (Node->getOpcode() == ISD::TokenFactor) {
5344453448 for (unsigned i = 0; i < Node->getNumOperands(); ++i)
5344553449 Worklist.push(Node->getOperand(i));
@@ -53455,8 +53459,8 @@ static SDValue foldToMaskedStore(StoreSDNode *Store, SelectionDAG &DAG,
5345553459 return SDValue();
5345653460
5345753461 return DAG.getMaskedStore(Store->getChain(), Dl, TrueVec, StorePtr,
53458- StoreOffset, Mask, VT, Store->getMemOperand (),
53459- Store->getAddressingMode());
53462+ StoreOffset, Mask, Store->getMemoryVT (),
53463+ Store->getMemOperand(), Store-> getAddressingMode());
5346053464}
5346153465
5346253466static SDValue combineStore(SDNode *N, SelectionDAG &DAG,
@@ -53723,9 +53727,6 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG,
5372353727 St->getMemOperand());
5372453728 }
5372553729
53726- if (SDValue MaskedStore = foldToMaskedStore(St, DAG, dl, Subtarget))
53727- return MaskedStore;
53728-
5372953730 // Turn load->store of MMX types into GPR load/stores. This avoids clobbering
5373053731 // the FP state in cases where an emms may be missing.
5373153732 // A preferable solution to the general problem is to figure out the right
@@ -53787,6 +53788,9 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG,
5378753788 St->getMemOperand()->getFlags());
5378853789 }
5378953790
53791+ if (SDValue MaskedStore = foldToMaskedStore(St, DAG, dl, Subtarget))
53792+ return MaskedStore;
53793+
5379053794 return SDValue();
5379153795}
5379253796
0 commit comments