@@ -238,18 +238,11 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
238238      return  std::nullopt ;
239239    LLVM_FALLTHROUGH;
240240  case  MVT::v2i8:
241-   case  MVT::v2i16:
242241  case  MVT::v2i32:
243242  case  MVT::v2i64:
244-   case  MVT::v2f16:
245-   case  MVT::v2bf16:
246243  case  MVT::v2f32:
247244  case  MVT::v2f64:
248-   case  MVT::v4i8:
249-   case  MVT::v4i16:
250245  case  MVT::v4i32:
251-   case  MVT::v4f16:
252-   case  MVT::v4bf16:
253246  case  MVT::v4f32:
254247    //  This is a "native" vector type
255248    return  std::pair (NumElts, EltVT);
@@ -262,6 +255,13 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
262255    if  (!CanLowerTo256Bit)
263256      return  std::nullopt ;
264257    LLVM_FALLTHROUGH;
258+   case  MVT::v2i16:  //  <1 x i16x2>
259+   case  MVT::v2f16:  //  <1 x f16x2>
260+   case  MVT::v2bf16: //  <1 x bf16x2>
261+   case  MVT::v4i8:   //  <1 x i8x4>
262+   case  MVT::v4i16:  //  <2 x i16x2>
263+   case  MVT::v4f16:  //  <2 x f16x2>
264+   case  MVT::v4bf16: //  <2 x bf16x2>
265265  case  MVT::v8i8:   //  <2 x i8x4>
266266  case  MVT::v8f16:  //  <4 x f16x2>
267267  case  MVT::v8bf16: //  <4 x bf16x2>
@@ -845,7 +845,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
845845  //  We have some custom DAG combine patterns for these nodes
846846  setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
847847                       ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
848-                        ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
848+                        ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
849+                        ISD::STORE});
849850
850851  //  setcc for f16x2 and bf16x2 needs special handling to prevent
851852  //  legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -3464,19 +3465,16 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
34643465      unsigned  I = 0 ;
34653466      for  (const  unsigned  NumElts : VectorInfo) {
34663467        const  EVT EltVT = VTs[I];
3467-         const  EVT LoadVT = [&]() -> EVT {
3468-           //  i1 is loaded/stored as i8.
3469-           if  (EltVT == MVT::i1)
3470-             return  MVT::i8 ;
3471-           //  getLoad needs a vector type, but it can't handle
3472-           //  vectors which contain v2f16 or v2bf16 elements. So we must load
3473-           //  using i32 here and then bitcast back.
3474-           if  (EltVT.isVector ())
3475-             return  MVT::getIntegerVT (EltVT.getFixedSizeInBits ());
3476-           return  EltVT;
3477-         }();
3468+         //  i1 is loaded/stored as i8
3469+         const  EVT LoadVT = EltVT == MVT::i1 ? MVT::i8  : EltVT;
3470+         //  If the element is a packed type (ex. v2f16, v4i8, etc) holding
3471+         //  multiple elements.
3472+         const  unsigned  PackingAmt =
3473+             LoadVT.isVector () ? LoadVT.getVectorNumElements () : 1 ;
3474+ 
3475+         const  EVT VecVT = EVT::getVectorVT (
3476+             F->getContext (), LoadVT.getScalarType (), NumElts * PackingAmt);
34783477
3479-         const  EVT VecVT = EVT::getVectorVT (F->getContext (), LoadVT, NumElts);
34803478        SDValue VecAddr = DAG.getObjectPtrOffset (
34813479            dl, ArgSymbol, TypeSize::getFixed (Offsets[I]));
34823480
@@ -3496,8 +3494,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
34963494        if  (P.getNode ())
34973495          P.getNode ()->setIROrder (Arg.getArgNo () + 1 );
34983496        for  (const  unsigned  J : llvm::seq (NumElts)) {
3499-           SDValue Elt = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
3500-                                     DAG.getIntPtrConstant (J, dl));
3497+           SDValue Elt = DAG.getNode (LoadVT.isVector () ? ISD::EXTRACT_SUBVECTOR
3498+                                                       : ISD::EXTRACT_VECTOR_ELT,
3499+                                     dl, LoadVT, P,
3500+                                     DAG.getIntPtrConstant (J * PackingAmt, dl));
35013501
35023502          //  Extend or truncate the element if necessary (e.g. an i8 is loaded
35033503          //  into an i16 register)
@@ -3506,15 +3506,11 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
35063506                  (ExpactedVT.isScalarInteger () &&
35073507                   Elt.getValueType ().isScalarInteger ())) &&
35083508                 " Non-integer argument type size mismatch"  );
3509-           if  (ExpactedVT.bitsGT (Elt.getValueType ())) { 
3509+           if  (ExpactedVT.bitsGT (Elt.getValueType ()))
35103510            Elt = DAG.getNode (getExtOpcode (ArgIns[I + J].Flags ), dl, ExpactedVT,
35113511                              Elt);
3512-           }  else  if  (ExpactedVT.bitsLT (Elt.getValueType ())) { 
3512+           else  if  (ExpactedVT.bitsLT (Elt.getValueType ()))
35133513            Elt = DAG.getNode (ISD::TRUNCATE, dl, ExpactedVT, Elt);
3514-           } else  {
3515-             //  v2f16 was loaded as an i32. Now we must bitcast it back.
3516-             Elt = DAG.getBitcast (EltVT, Elt);
3517-           }
35183514          InVals.push_back (Elt);
35193515        }
35203516        I += NumElts;
@@ -5047,26 +5043,243 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
50475043  return  SDValue ();
50485044}
50495045
5050- static  SDValue PerformStoreCombineHelper (SDNode *N, std::size_t  Front,
5051-                                          std::size_t  Back) {
5046+ // / Fold extractelts into a load by increasing the number of return values.
5047+ // /
5048+ // / ex:
5049+ // / L: v2f16,ch = load <p>
5050+ // / a: f16 = extractelt L:0, 0
5051+ // / b: f16 = extractelt L:0, 1
5052+ // / use(a, b)
5053+ // /
5054+ // / ...is turned into...
5055+ // / L: f16,f16,ch = LoadV2 <p>
5056+ // / use(L:0, L:1)
5057+ static  SDValue
5058+ combineUnpackingMovIntoLoad (SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
5059+   //  Don't run this optimization before the legalizer
5060+   if  (!DCI.isAfterLegalizeDAG ())
5061+     return  SDValue ();
5062+ 
5063+   EVT ElemVT = N->getValueType (0 );
5064+   if  (!Isv2x16VT (ElemVT))
5065+     return  SDValue ();
5066+ 
5067+   //  Check whether all outputs are either used by an extractelt or are
5068+   //  glue/chain nodes
5069+   if  (!all_of (N->uses (), [&](SDUse &U) {
5070+         //  Skip glue, chain nodes
5071+         if  (U.getValueType () == MVT::Glue || U.getValueType () == MVT::Other)
5072+           return  true ;
5073+         if  (U.getUser ()->getOpcode () == ISD::EXTRACT_VECTOR_ELT) {
5074+           if  (N->getOpcode () != ISD::LOAD)
5075+             return  true ;
5076+           //  Since this is an ISD::LOAD, check all extractelts are used. If
5077+           //  any are not used, we don't want to defeat another optimization that
5078+           //  will narrow the load.
5079+           // 
5080+           //  For example:
5081+           // 
5082+           //  L: v2f16,ch = load <p>
5083+           //  e0: f16 = extractelt L:0, 0
5084+           //  e1: f16 = extractelt L:0, 1        <-- unused
5085+           //  store e0
5086+           // 
5087+           //  Can be optimized by DAGCombiner to:
5088+           // 
5089+           //  L: f16,ch = load <p>
5090+           //  store L:0
5091+           return  !U.getUser ()->use_empty ();
5092+         }
5093+ 
5094+         //  Otherwise, this use prevents us from splitting a value.
5095+         return  false ;
5096+       }))
5097+     return  SDValue ();
5098+ 
5099+   auto  *LD = cast<MemSDNode>(N);
5100+   EVT MemVT = LD->getMemoryVT ();
5101+   SDLoc DL (LD);
5102+ 
5103+   //  the new opcode after we double the number of operands
5104+   NVPTXISD::NodeType Opcode;
5105+   SmallVector<SDValue> Operands (LD->ops ());
5106+   unsigned  OldNumOutputs; //  non-glue, non-chain outputs
5107+   switch  (LD->getOpcode ()) {
5108+   case  ISD::LOAD:
5109+     OldNumOutputs = 1 ;
5110+     //  Any packed type is legal, so the legalizer will not have lowered
5111+     //  ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
5112+     //  here.
5113+     Opcode = NVPTXISD::LoadV2;
5114+     Operands.push_back (DCI.DAG .getIntPtrConstant (
5115+         cast<LoadSDNode>(LD)->getExtensionType (), DL));
5116+     break ;
5117+   case  NVPTXISD::LoadParamV2:
5118+     OldNumOutputs = 2 ;
5119+     Opcode = NVPTXISD::LoadParamV4;
5120+     break ;
5121+   case  NVPTXISD::LoadV2:
5122+     OldNumOutputs = 2 ;
5123+     Opcode = NVPTXISD::LoadV4;
5124+     break ;
5125+   case  NVPTXISD::LoadV4:
5126+   case  NVPTXISD::LoadV8:
5127+     //  PTX doesn't support the next doubling of outputs
5128+     return  SDValue ();
5129+   }
5130+ 
5131+   //  the non-glue, non-chain outputs in the new load
5132+   const  unsigned  NewNumOutputs = OldNumOutputs * 2 ;
5133+   SmallVector<EVT> NewVTs (NewNumOutputs, ElemVT.getVectorElementType ());
5134+   //  add remaining chain and glue values
5135+   NewVTs.append (LD->value_begin () + OldNumOutputs, LD->value_end ());
5136+ 
5137+   //  Create the new load
5138+   SDValue NewLoad =
5139+       DCI.DAG .getMemIntrinsicNode (Opcode, DL, DCI.DAG .getVTList (NewVTs),
5140+                                   Operands, MemVT, LD->getMemOperand ());
5141+ 
5142+   //  Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
5143+   //  the outputs the same. These nodes will be optimized away in later
5144+   //  DAGCombiner iterations.
5145+   SmallVector<SDValue> Results;
5146+   for  (unsigned  I : seq (OldNumOutputs))
5147+     Results.push_back (DCI.DAG .getBuildVector (
5148+         ElemVT, DL, {NewLoad.getValue (I * 2 ), NewLoad.getValue (I * 2  + 1 )}));
5149+   //  Add remaining chain and glue nodes
5150+   for  (unsigned  I : seq (NewLoad->getNumValues () - NewNumOutputs))
5151+     Results.push_back (NewLoad.getValue (NewNumOutputs + I));
5152+ 
5153+   return  DCI.DAG .getMergeValues (Results, DL);
5154+ }
5155+ 
/// Fold a packing mov into a store.
///
/// ex:
/// v: v2f16 = BUILD_VECTOR a:f16, b:f16
/// StoreRetval v
///
/// ...is turned into...
///
/// StoreRetvalV2 a:f16, b:f16
///
/// \p Front operands lead the stored values (e.g. chain / ArgID / Offset) and
/// \p Back operands trail them (e.g. glue); only the operands in between are
/// treated as stored values. Returns an empty SDValue when the fold does not
/// apply.
static SDValue combinePackingMovIntoStore(SDNode *N,
                                          TargetLowering::DAGCombinerInfo &DCI,
                                          unsigned Front, unsigned Back) {
  // We want to run this as late as possible since other optimizations may
  // eliminate the BUILD_VECTORs.
  if (!DCI.isAfterLegalizeDAG())
    return SDValue();

  // Get the type of the operands being stored.
  EVT ElementVT = N->getOperand(Front).getValueType();

  // Only packed 2 x 16-bit values (v2f16/v2bf16/v2i16) are unpacked here.
  if (!Isv2x16VT(ElementVT))
    return SDValue();

  auto *ST = cast<MemSDNode>(N);
  // Default memory type: the scalar element of the packed value. For the
  // ISD::STORE / NVPTXISD::StoreV2 cases below this is overwritten with the
  // node's actual memory VT.
  EVT MemVT = ElementVT.getVectorElementType();

  // The new opcode after we double the number of operands.
  NVPTXISD::NodeType Opcode;
  switch (N->getOpcode()) {
  case ISD::STORE:
    // Any packed type is legal, so the legalizer will not have lowered
    // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do
    // it here.
    MemVT = ST->getMemoryVT();
    Opcode = NVPTXISD::StoreV2;
    break;
  case NVPTXISD::StoreParam:
    Opcode = NVPTXISD::StoreParamV2;
    break;
  case NVPTXISD::StoreParamV2:
    Opcode = NVPTXISD::StoreParamV4;
    break;
  case NVPTXISD::StoreRetval:
    Opcode = NVPTXISD::StoreRetvalV2;
    break;
  case NVPTXISD::StoreRetvalV2:
    Opcode = NVPTXISD::StoreRetvalV4;
    break;
  case NVPTXISD::StoreV2:
    MemVT = ST->getMemoryVT();
    Opcode = NVPTXISD::StoreV4;
    break;
  case NVPTXISD::StoreV4:
  case NVPTXISD::StoreParamV4:
  case NVPTXISD::StoreRetvalV4:
  case NVPTXISD::StoreV8:
    // PTX doesn't support the next doubling of operands
    return SDValue();
  default:
    llvm_unreachable("Unhandled store opcode");
  }

  // Scan the operands and if they're all BUILD_VECTORs, we'll have gathered
  // their elements.
  SmallVector<SDValue, 4> Operands(N->ops().take_front(Front));
  for (SDValue BV : N->ops().drop_front(Front).drop_back(Back)) {
    if (BV.getOpcode() != ISD::BUILD_VECTOR)
      return SDValue();

    // If the operand has multiple uses, this optimization can increase register
    // pressure.
    if (!BV.hasOneUse())
      return SDValue();

    // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
    // any signs they may be folded by some other pattern or rule.
    for (SDValue Op : BV->ops()) {
      // Peek through bitcasts
      if (Op.getOpcode() == ISD::BITCAST)
        Op = Op.getOperand(0);

      // This may be folded into a PRMT.
      if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
          Op->getOperand(0).getValueType() == MVT::i32)
        return SDValue();

      // This may be folded into cvt.bf16x2
      if (Op.getOpcode() == ISD::FP_ROUND)
        return SDValue();
    }
    // Replace the packed value by its two scalar elements.
    Operands.append({BV.getOperand(0), BV.getOperand(1)});
  }
  // Re-attach the trailing operands (e.g. glue) unchanged.
  Operands.append(N->op_end() - Back, N->op_end());

  // Now we replace the store
  return DCI.DAG.getMemIntrinsicNode(Opcode, SDLoc(N), N->getVTList(), Operands,
                                     MemVT, ST->getMemOperand());
}
5254+ 
5255+ static  SDValue PerformStoreCombineHelper (SDNode *N,
5256+                                          TargetLowering::DAGCombinerInfo &DCI,
5257+                                          unsigned  Front, unsigned  Back) {
50525258  if  (all_of (N->ops ().drop_front (Front).drop_back (Back),
50535259             [](const  SDUse &U) { return  U.get ()->isUndef (); }))
50545260    //  Operand 0 is the previous value in the chain. Cannot return EntryToken
50555261    //  as the previous value will become unused and eliminated later.
50565262    return  N->getOperand (0 );
50575263
5058-   return  SDValue ();
5264+   return  combinePackingMovIntoStore (N, DCI, Front, Back);
5265+ }
5266+ 
5267+ static  SDValue PerformStoreCombine (SDNode *N,
5268+                                    TargetLowering::DAGCombinerInfo &DCI) {
5269+   return  combinePackingMovIntoStore (N, DCI, 1 , 2 );
50595270}
50605271
5061- static  SDValue PerformStoreParamCombine (SDNode *N) {
5272+ static  SDValue PerformStoreParamCombine (SDNode *N,
5273+                                         TargetLowering::DAGCombinerInfo &DCI) {
50625274  //  Operands from the 3rd to the 2nd last one are the values to be stored.
50635275  //    {Chain, ArgID, Offset, Val, Glue}
5064-   return  PerformStoreCombineHelper (N, 3 , 1 );
5276+   return  PerformStoreCombineHelper (N, DCI,  3 , 1 );
50655277}
50665278
5067- static  SDValue PerformStoreRetvalCombine (SDNode *N) {
5279+ static  SDValue PerformStoreRetvalCombine (SDNode *N,
5280+                                          TargetLowering::DAGCombinerInfo &DCI) {
50685281  //  Operands from the 2nd to the last one are the values to be stored
5069-   return  PerformStoreCombineHelper (N, 2 , 0 );
5282+   return  PerformStoreCombineHelper (N, DCI,  2 , 0 );
50705283}
50715284
50725285// / PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
@@ -5697,14 +5910,23 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
56975910      return  PerformREMCombine (N, DCI, OptLevel);
56985911    case  ISD::SETCC:
56995912      return  PerformSETCCCombine (N, DCI, STI.getSmVersion ());
5913+     case  ISD::LOAD:
5914+     case  NVPTXISD::LoadParamV2:
5915+     case  NVPTXISD::LoadV2:
5916+     case  NVPTXISD::LoadV4:
5917+       return  combineUnpackingMovIntoLoad (N, DCI);
57005918    case  NVPTXISD::StoreRetval:
57015919    case  NVPTXISD::StoreRetvalV2:
57025920    case  NVPTXISD::StoreRetvalV4:
5703-       return  PerformStoreRetvalCombine (N);
5921+       return  PerformStoreRetvalCombine (N, DCI );
57045922    case  NVPTXISD::StoreParam:
57055923    case  NVPTXISD::StoreParamV2:
57065924    case  NVPTXISD::StoreParamV4:
5707-       return  PerformStoreParamCombine (N);
5925+       return  PerformStoreParamCombine (N, DCI);
5926+     case  ISD::STORE:
5927+     case  NVPTXISD::StoreV2:
5928+     case  NVPTXISD::StoreV4:
5929+       return  PerformStoreCombine (N, DCI);
57085930    case  ISD::EXTRACT_VECTOR_ELT:
57095931      return  PerformEXTRACTCombine (N, DCI);
57105932    case  ISD::VSELECT:
0 commit comments