@@ -238,18 +238,11 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
       return std::nullopt;
     LLVM_FALLTHROUGH;
   case MVT::v2i8:
-  case MVT::v2i16:
   case MVT::v2i32:
   case MVT::v2i64:
-  case MVT::v2f16:
-  case MVT::v2bf16:
   case MVT::v2f32:
   case MVT::v2f64:
-  case MVT::v4i8:
-  case MVT::v4i16:
   case MVT::v4i32:
-  case MVT::v4f16:
-  case MVT::v4bf16:
   case MVT::v4f32:
     // This is a "native" vector type
     return std::pair(NumElts, EltVT);
@@ -262,6 +255,13 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
     if (!CanLowerTo256Bit)
       return std::nullopt;
     LLVM_FALLTHROUGH;
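+  // The comment on each case gives the lowering shape <NumRegs x RegTy>:
+  // how many registers the vector occupies and the packed type per register.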
+  case MVT::v2i16:  // <1 x i16x2>
+  case MVT::v2f16:  // <1 x f16x2>
+  case MVT::v2bf16: // <1 x bf16x2>
+  case MVT::v4i8:   // <1 x i8x4>
+  case MVT::v4i16:  // <2 x i16x2>
+  case MVT::v4f16:  // <2 x f16x2>
+  case MVT::v4bf16: // <2 x bf16x2>
   case MVT::v8i8:   // <2 x i8x4>
   case MVT::v8f16:  // <4 x f16x2>
   case MVT::v8bf16: // <4 x bf16x2>
@@ -845,7 +845,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
                        ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
-                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
+                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
+                       ISD::STORE});
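+  // ISD::LOAD and ISD::STORE are included so that packing/unpacking movs can
+  // be folded into them (see combineUnpackingMovIntoLoad and
+  // combinePackingMovIntoStore below).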
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -3464,19 +3465,16 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
   unsigned I = 0;
   for (const unsigned NumElts : VectorInfo) {
     const EVT EltVT = VTs[I];
-    const EVT LoadVT = [&]() -> EVT {
-      // i1 is loaded/stored as i8.
-      if (EltVT == MVT::i1)
-        return MVT::i8;
-      // getLoad needs a vector type, but it can't handle
-      // vectors which contain v2f16 or v2bf16 elements. So we must load
-      // using i32 here and then bitcast back.
-      if (EltVT.isVector())
-        return MVT::getIntegerVT(EltVT.getFixedSizeInBits());
-      return EltVT;
-    }();
+    // i1 is loaded/stored as i8.
+    const EVT LoadVT = EltVT == MVT::i1 ? MVT::i8 : EltVT;
+    // The element may be a packed type (ex. v2f16, v4i8, etc.) holding
+    // multiple scalars; PackingAmt is how many.
+    const unsigned PackingAmt =
+        LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
+
+    const EVT VecVT = EVT::getVectorVT(
+        F->getContext(), LoadVT.getScalarType(), NumElts * PackingAmt);
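+    // e.g. a piece with NumElts = 2 elements of type v2f16 (PackingAmt = 2)
+    // is loaded as one v4f16 and split back into v2f16 values below.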
 
-    const EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
     SDValue VecAddr = DAG.getObjectPtrOffset(
         dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
 
@@ -3496,8 +3494,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     if (P.getNode())
       P.getNode()->setIROrder(Arg.getArgNo() + 1);
     for (const unsigned J : llvm::seq(NumElts)) {
-      SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
-                                DAG.getIntPtrConstant(J, dl));
+      SDValue Elt = DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
+                                                  : ISD::EXTRACT_VECTOR_ELT,
+                                dl, LoadVT, P,
+                                DAG.getIntPtrConstant(J * PackingAmt, dl));
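+      // For a packed LoadVT, index J * PackingAmt selects the J-th group of
+      // PackingAmt scalars in the wide load as one packed value.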
 
       // Extend or truncate the element if necessary (e.g. an i8 is loaded
       // into an i16 register)
@@ -3506,15 +3506,11 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
               (ExpactedVT.isScalarInteger() &&
                Elt.getValueType().isScalarInteger())) &&
              "Non-integer argument type size mismatch");
-      if (ExpactedVT.bitsGT(Elt.getValueType())) {
+      if (ExpactedVT.bitsGT(Elt.getValueType()))
         Elt = DAG.getNode(getExtOpcode(ArgIns[I + J].Flags), dl, ExpactedVT,
                           Elt);
-      } else if (ExpactedVT.bitsLT(Elt.getValueType())) {
+      else if (ExpactedVT.bitsLT(Elt.getValueType()))
         Elt = DAG.getNode(ISD::TRUNCATE, dl, ExpactedVT, Elt);
-      } else {
-        // v2f16 was loaded as an i32. Now we must bitcast it back.
-        Elt = DAG.getBitcast(EltVT, Elt);
-      }
       InVals.push_back(Elt);
     }
     I += NumElts;
@@ -5047,26 +5043,243 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
   return SDValue();
 }
 
-static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
-                                         std::size_t Back) {
+/// Fold extractelts into a load by increasing the number of return values.
+///
+/// ex:
+/// L: v2f16,ch = load <p>
+/// a: f16 = extractelt L:0, 0
+/// b: f16 = extractelt L:0, 1
+/// use(a, b)
+///
+/// ...is turned into...
+/// L: f16,f16,ch = LoadV2 <p>
+/// use(L:0, L:1)
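+///
+/// The same applies to the wider NVPTX load nodes, ex:
+/// L: v2f16,v2f16,ch = LoadV2 <p>
+/// ...is turned into...
+/// L: f16,f16,f16,f16,ch = LoadV4 <p>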
+static SDValue
+combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+  // Don't run this optimization before the legalizer.
+  if (!DCI.isAfterLegalizeDAG())
+    return SDValue();
+
+  EVT ElemVT = N->getValueType(0);
+  if (!Isv2x16VT(ElemVT))
+    return SDValue();
+
+  // Check whether all outputs are either used by an extractelt or are
+  // glue/chain nodes.
+  if (!all_of(N->uses(), [&](SDUse &U) {
+        // Skip glue, chain nodes.
+        if (U.getValueType() == MVT::Glue || U.getValueType() == MVT::Other)
+          return true;
+        if (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
+          if (N->getOpcode() != ISD::LOAD)
+            return true;
+          // Since this is an ISD::LOAD, check all extractelts are used. If
+          // any are not used, we don't want to defeat another optimization
+          // that will narrow the load.
+          //
+          // For example:
+          //
+          // L: v2f16,ch = load <p>
+          // e0: f16 = extractelt L:0, 0
+          // e1: f16 = extractelt L:0, 1 <-- unused
+          // store e0
+          //
+          // Can be optimized by DAGCombiner to:
+          //
+          // L: f16,ch = load <p>
+          // store L:0
+          return !U.getUser()->use_empty();
+        }
+
+        // Otherwise, this use prevents us from splitting a value.
+        return false;
+      }))
+    return SDValue();
+
+  auto *LD = cast<MemSDNode>(N);
+  EVT MemVT = LD->getMemoryVT();
+  SDLoc DL(LD);
+
+  // The new opcode after we double the number of outputs.
+  NVPTXISD::NodeType Opcode;
+  SmallVector<SDValue> Operands(LD->ops());
+  unsigned OldNumOutputs; // non-glue, non-chain outputs
+  switch (LD->getOpcode()) {
+  case ISD::LOAD:
+    OldNumOutputs = 1;
+    // Any packed type is legal, so the legalizer will not have lowered
+    // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do
+    // it here.
+    Opcode = NVPTXISD::LoadV2;
+    Operands.push_back(DCI.DAG.getIntPtrConstant(
+        cast<LoadSDNode>(LD)->getExtensionType(), DL));
+    break;
+  case NVPTXISD::LoadParamV2:
+    OldNumOutputs = 2;
+    Opcode = NVPTXISD::LoadParamV4;
+    break;
+  case NVPTXISD::LoadV2:
+    OldNumOutputs = 2;
+    Opcode = NVPTXISD::LoadV4;
+    break;
+  case NVPTXISD::LoadV4:
+  case NVPTXISD::LoadV8:
+    // PTX doesn't support the next doubling of outputs.
+    return SDValue();
+  }
+
+  // The number of non-glue, non-chain outputs in the new load.
+  const unsigned NewNumOutputs = OldNumOutputs * 2;
+  SmallVector<EVT> NewVTs(NewNumOutputs, ElemVT.getVectorElementType());
+  // Add the remaining chain and glue values.
+  NewVTs.append(LD->value_begin() + OldNumOutputs, LD->value_end());
+
+  // Create the new load.
+  SDValue NewLoad =
+      DCI.DAG.getMemIntrinsicNode(Opcode, DL, DCI.DAG.getVTList(NewVTs),
+                                  Operands, MemVT, LD->getMemOperand());
+
+  // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
+  // the outputs the same. These nodes will be optimized away in later
+  // DAGCombiner iterations.
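+  //
+  // e.g. for a folded LoadV2 of v2f16:
+  //   NewLoad: f16,f16,f16,f16,ch = LoadV4 <p>
+  //   Results: { BUILD_VECTOR(NewLoad:0, NewLoad:1),
+  //              BUILD_VECTOR(NewLoad:2, NewLoad:3), NewLoad:ch }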
+  SmallVector<SDValue> Results;
+  for (unsigned I : seq(OldNumOutputs))
+    Results.push_back(DCI.DAG.getBuildVector(
+        ElemVT, DL, {NewLoad.getValue(I * 2), NewLoad.getValue(I * 2 + 1)}));
+  // Add remaining chain and glue nodes.
+  for (unsigned I : seq(NewLoad->getNumValues() - NewNumOutputs))
+    Results.push_back(NewLoad.getValue(NewNumOutputs + I));
+
+  return DCI.DAG.getMergeValues(Results, DL);
+}
+
+/// Fold a packing mov into a store.
+///
+/// ex:
+/// v: v2f16 = BUILD_VECTOR a:f16, b:f16
+/// StoreRetval v
+///
+/// ...is turned into...
+///
+/// StoreRetvalV2 a:f16, b:f16
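+///
+/// The same applies to the wider store nodes, ex. a StoreV2 of two v2f16
+/// BUILD_VECTORs becomes a StoreV4 of their four f16 elements.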
+static SDValue combinePackingMovIntoStore(SDNode *N,
+                                          TargetLowering::DAGCombinerInfo &DCI,
+                                          unsigned Front, unsigned Back) {
+  // We want to run this as late as possible since other optimizations may
+  // eliminate the BUILD_VECTORs.
+  if (!DCI.isAfterLegalizeDAG())
+    return SDValue();
+
+  // Get the type of the operands being stored.
+  EVT ElementVT = N->getOperand(Front).getValueType();
+
+  if (!Isv2x16VT(ElementVT))
+    return SDValue();
+
+  auto *ST = cast<MemSDNode>(N);
+  EVT MemVT = ElementVT.getVectorElementType();
+
+  // The new opcode after we double the number of operands.
+  NVPTXISD::NodeType Opcode;
+  switch (N->getOpcode()) {
+  case ISD::STORE:
+    // Any packed type is legal, so the legalizer will not have lowered
+    // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to
+    // do it here.
+    MemVT = ST->getMemoryVT();
+    Opcode = NVPTXISD::StoreV2;
+    break;
+  case NVPTXISD::StoreParam:
+    Opcode = NVPTXISD::StoreParamV2;
+    break;
+  case NVPTXISD::StoreParamV2:
+    Opcode = NVPTXISD::StoreParamV4;
+    break;
+  case NVPTXISD::StoreRetval:
+    Opcode = NVPTXISD::StoreRetvalV2;
+    break;
+  case NVPTXISD::StoreRetvalV2:
+    Opcode = NVPTXISD::StoreRetvalV4;
+    break;
+  case NVPTXISD::StoreV2:
+    MemVT = ST->getMemoryVT();
+    Opcode = NVPTXISD::StoreV4;
+    break;
+  case NVPTXISD::StoreV4:
+  case NVPTXISD::StoreParamV4:
+  case NVPTXISD::StoreRetvalV4:
+  case NVPTXISD::StoreV8:
+    // PTX doesn't support the next doubling of operands.
+    return SDValue();
+  default:
+    llvm_unreachable("Unhandled store opcode");
+  }
+
+  // Scan the values to be stored. Each must be a single-use BUILD_VECTOR
+  // whose elements we gather as the operands of the new store.
+  SmallVector<SDValue, 4> Operands(N->ops().take_front(Front));
+  for (SDValue BV : N->ops().drop_front(Front).drop_back(Back)) {
+    if (BV.getOpcode() != ISD::BUILD_VECTOR)
+      return SDValue();
+
+    // If the operand has multiple uses, this optimization can increase
+    // register pressure.
+    if (!BV.hasOneUse())
+      return SDValue();
+
+    // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
+    // any signs they may be folded by some other pattern or rule.
+    for (SDValue Op : BV->ops()) {
+      // Peek through bitcasts.
+      if (Op.getOpcode() == ISD::BITCAST)
+        Op = Op.getOperand(0);
+
+      // This may be folded into a PRMT.
+      if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
+          Op->getOperand(0).getValueType() == MVT::i32)
+        return SDValue();
+
+      // This may be folded into cvt.bf16x2.
+      if (Op.getOpcode() == ISD::FP_ROUND)
+        return SDValue();
+    }
+    Operands.append({BV.getOperand(0), BV.getOperand(1)});
+  }
+  Operands.append(N->op_end() - Back, N->op_end());
+
+  // Now we replace the store.
+  return DCI.DAG.getMemIntrinsicNode(Opcode, SDLoc(N), N->getVTList(),
+                                     Operands, MemVT, ST->getMemOperand());
+}
+
+static SDValue PerformStoreCombineHelper(SDNode *N,
+                                         TargetLowering::DAGCombinerInfo &DCI,
+                                         unsigned Front, unsigned Back) {
   if (all_of(N->ops().drop_front(Front).drop_back(Back),
              [](const SDUse &U) { return U.get()->isUndef(); }))
     // Operand 0 is the previous value in the chain. Cannot return EntryToken
     // as the previous value will become unused and eliminated later.
     return N->getOperand(0);
 
-  return SDValue();
+  return combinePackingMovIntoStore(N, DCI, Front, Back);
+}
+
+static SDValue PerformStoreCombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI) {
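+  // Operands are {Chain, Val..., BasePtr, Offset}: skip the chain at the
+  // front and the two address operands at the back.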
+  return combinePackingMovIntoStore(N, DCI, 1, 2);
 }
 
-static SDValue PerformStoreParamCombine(SDNode *N) {
+static SDValue PerformStoreParamCombine(SDNode *N,
+                                        TargetLowering::DAGCombinerInfo &DCI) {
   // Operands from the 3rd to the 2nd last one are the values to be stored.
   // {Chain, ArgID, Offset, Val, Glue}
-  return PerformStoreCombineHelper(N, 3, 1);
+  return PerformStoreCombineHelper(N, DCI, 3, 1);
 }
 
-static SDValue PerformStoreRetvalCombine(SDNode *N) {
+static SDValue PerformStoreRetvalCombine(SDNode *N,
+                                         TargetLowering::DAGCombinerInfo &DCI) {
   // Operands from the 2nd to the last one are the values to be stored
-  return PerformStoreCombineHelper(N, 2, 0);
+  return PerformStoreCombineHelper(N, DCI, 2, 0);
 }
 
 /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
@@ -5697,14 +5910,23 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     return PerformREMCombine(N, DCI, OptLevel);
   case ISD::SETCC:
     return PerformSETCCCombine(N, DCI, STI.getSmVersion());
+  case ISD::LOAD:
+  case NVPTXISD::LoadParamV2:
+  case NVPTXISD::LoadV2:
+  case NVPTXISD::LoadV4:
+    return combineUnpackingMovIntoLoad(N, DCI);
   case NVPTXISD::StoreRetval:
   case NVPTXISD::StoreRetvalV2:
   case NVPTXISD::StoreRetvalV4:
-    return PerformStoreRetvalCombine(N);
+    return PerformStoreRetvalCombine(N, DCI);
   case NVPTXISD::StoreParam:
   case NVPTXISD::StoreParamV2:
   case NVPTXISD::StoreParamV4:
-    return PerformStoreParamCombine(N);
+    return PerformStoreParamCombine(N, DCI);
+  case ISD::STORE:
+  case NVPTXISD::StoreV2:
+  case NVPTXISD::StoreV4:
+    return PerformStoreCombine(N, DCI);
   case ISD::EXTRACT_VECTOR_ELT:
     return PerformEXTRACTCombine(N, DCI);
   case ISD::VSELECT: