@@ -7275,11 +7275,26 @@ static EVT getPackedVectorTypeFromPredicateType(LLVMContext &Ctx, EVT PredVT,
72757275// / Return the EVT of the data associated to a memory operation in \p
72767276// / Root. If such EVT cannot be retrived, it returns an invalid EVT.
72777277static EVT getMemVTFromNode (LLVMContext &Ctx, SDNode *Root) {
7278- if (isa<MemSDNode>(Root))
7279- return cast<MemSDNode>(Root)->getMemoryVT ();
7278+ if (auto *MemIntr = dyn_cast<MemIntrinsicSDNode>(Root))
7279+ return MemIntr->getMemoryVT ();
7280+
7281+ if (isa<MemSDNode>(Root)) {
7282+ EVT MemVT = cast<MemSDNode>(Root)->getMemoryVT ();
7283+
7284+ EVT DataVT;
7285+ if (auto *Load = dyn_cast<LoadSDNode>(Root))
7286+ DataVT = Load->getValueType (0 );
7287+ else if (auto *Load = dyn_cast<MaskedLoadSDNode>(Root))
7288+ DataVT = Load->getValueType (0 );
7289+ else if (auto *Store = dyn_cast<StoreSDNode>(Root))
7290+ DataVT = Store->getValue ().getValueType ();
7291+ else if (auto *Store = dyn_cast<MaskedStoreSDNode>(Root))
7292+ DataVT = Store->getValue ().getValueType ();
7293+ else
7294+ llvm_unreachable (" Unexpected MemSDNode!" );
72807295
7281- if (isa<MemIntrinsicSDNode>(Root))
7282- return cast<MemIntrinsicSDNode>(Root)-> getMemoryVT ();
7296+ return DataVT. changeVectorElementType (MemVT. getVectorElementType ());
7297+ }
72837298
72847299 const unsigned Opcode = Root->getOpcode ();
72857300 // For custom ISD nodes, we have to look at them individually to extract the
@@ -7380,12 +7395,23 @@ bool AArch64DAGToDAGISel::SelectAddrModeIndexedSVE(SDNode *Root, SDValue N,
73807395 return false ;
73817396
73827397 SDValue VScale = N.getOperand (1 );
7383- if (VScale.getOpcode () != ISD::VSCALE)
7398+ int64_t MulImm = std::numeric_limits<int64_t >::max ();
7399+ if (VScale.getOpcode () == ISD::VSCALE) {
7400+ MulImm = cast<ConstantSDNode>(VScale.getOperand (0 ))->getSExtValue ();
7401+ } else if (auto C = dyn_cast<ConstantSDNode>(VScale)) {
7402+ int64_t ByteOffset = C->getSExtValue ();
7403+ const auto KnownVScale =
7404+ Subtarget->getSVEVectorSizeInBits () / AArch64::SVEBitsPerBlock;
7405+
7406+ if (!KnownVScale || ByteOffset % KnownVScale != 0 )
7407+ return false ;
7408+
7409+ MulImm = ByteOffset / KnownVScale;
7410+ } else
73847411 return false ;
73857412
73867413 TypeSize TS = MemVT.getSizeInBits ();
73877414 int64_t MemWidthBytes = static_cast <int64_t >(TS.getKnownMinValue ()) / 8 ;
7388- int64_t MulImm = cast<ConstantSDNode>(VScale.getOperand (0 ))->getSExtValue ();
73897415
73907416 if ((MulImm % MemWidthBytes) != 0 )
73917417 return false ;
0 commit comments