@@ -238,18 +238,11 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
       return std::nullopt;
     LLVM_FALLTHROUGH;
   case MVT::v2i8:
-  case MVT::v2i16:
   case MVT::v2i32:
   case MVT::v2i64:
-  case MVT::v2f16:
-  case MVT::v2bf16:
   case MVT::v2f32:
   case MVT::v2f64:
-  case MVT::v4i8:
-  case MVT::v4i16:
   case MVT::v4i32:
-  case MVT::v4f16:
-  case MVT::v4bf16:
   case MVT::v4f32:
     // This is a "native" vector type
     return std::pair(NumElts, EltVT);
@@ -262,6 +255,13 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
     if (!CanLowerTo256Bit)
       return std::nullopt;
     LLVM_FALLTHROUGH;
+  case MVT::v2i16:  // <1 x i16x2>
+  case MVT::v2f16:  // <1 x f16x2>
+  case MVT::v2bf16: // <1 x bf16x2>
+  case MVT::v4i8:   // <1 x i8x4>
+  case MVT::v4i16:  // <2 x i16x2>
+  case MVT::v4f16:  // <2 x f16x2>
+  case MVT::v4bf16: // <2 x bf16x2>
   case MVT::v8i8:   // <2 x i8x4>
   case MVT::v8f16:  // <4 x f16x2>
   case MVT::v8bf16: // <4 x bf16x2>
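The moved cases now sit in the group that getVectorLoweringShape describes in terms of packed sub-vectors (f16x2, bf16x2, i16x2, i8x4), per the inline comments. A rough reading of the new shapes, assuming this group keeps returning a {count, packed element type} pair like the pre-existing v8f16 case (the return statement itself is outside this hunk):

  // Illustrative reading of the comments above, not part of the patch:
  //   v2f16  -> 1 x v2f16   i.e. <1 x f16x2>
  //   v4f16  -> 2 x v2f16   i.e. <2 x f16x2>
  //   v4i8   -> 1 x v4i8    i.e. <1 x i8x4>
  //   v8bf16 -> 4 x v2bf16  i.e. <4 x bf16x2> (pre-existing case)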
@@ -845,7 +845,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
                        ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
-                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
+                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
+                       ISD::STORE});

   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -3464,19 +3465,16 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     unsigned I = 0;
     for (const unsigned NumElts : VectorInfo) {
       const EVT EltVT = VTs[I];
-      const EVT LoadVT = [&]() -> EVT {
-        // i1 is loaded/stored as i8.
-        if (EltVT == MVT::i1)
-          return MVT::i8;
-        // getLoad needs a vector type, but it can't handle
-        // vectors which contain v2f16 or v2bf16 elements. So we must load
-        // using i32 here and then bitcast back.
-        if (EltVT.isVector())
-          return MVT::getIntegerVT(EltVT.getFixedSizeInBits());
-        return EltVT;
-      }();
+      // i1 is loaded/stored as i8.
+      const EVT LoadVT = EltVT == MVT::i1 ? MVT::i8 : EltVT;
+      // The element may itself be a packed type (e.g. v2f16, v4i8) holding
+      // multiple scalar elements.
+      const unsigned PackingAmt =
+          LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
+
+      const EVT VecVT = EVT::getVectorVT(
+          F->getContext(), LoadVT.getScalarType(), NumElts * PackingAmt);

-      const EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
       SDValue VecAddr = DAG.getObjectPtrOffset(
           dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
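A quick worked example of the new computation (the values are illustrative, not taken from the patch):

  // Illustrative only: suppose a formal argument is flattened into
  // NumElts = 4 elements of EltVT = v2f16. Then LoadVT = v2f16,
  // PackingAmt = 2, and
  //   VecVT = EVT::getVectorVT(Ctx, f16, 4 * 2) = v8f16
  // so the whole group is fetched as one wide f16 vector, where the old code
  // would have built a v4i32 vector and bitcast each i32 lane back to v2f16.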
@@ -3496,8 +3494,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       if (P.getNode())
         P.getNode()->setIROrder(Arg.getArgNo() + 1);
       for (const unsigned J : llvm::seq(NumElts)) {
-        SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
-                                  DAG.getIntPtrConstant(J, dl));
+        SDValue Elt = DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
+                                                    : ISD::EXTRACT_VECTOR_ELT,
+                                  dl, LoadVT, P,
+                                  DAG.getIntPtrConstant(J * PackingAmt, dl));

         // Extend or truncate the element if necessary (e.g. an i8 is loaded
         // into an i16 register)
@@ -3506,15 +3506,11 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
                 (ExpactedVT.isScalarInteger() &&
                  Elt.getValueType().isScalarInteger())) &&
                "Non-integer argument type size mismatch");
-        if (ExpactedVT.bitsGT(Elt.getValueType())) {
+        if (ExpactedVT.bitsGT(Elt.getValueType()))
           Elt = DAG.getNode(getExtOpcode(ArgIns[I + J].Flags), dl, ExpactedVT,
                             Elt);
-        } else if (ExpactedVT.bitsLT(Elt.getValueType())) {
+        else if (ExpactedVT.bitsLT(Elt.getValueType()))
           Elt = DAG.getNode(ISD::TRUNCATE, dl, ExpactedVT, Elt);
-        } else {
-          // v2f16 was loaded as an i32. Now we must bitcast it back.
-          Elt = DAG.getBitcast(EltVT, Elt);
-        }
         InVals.push_back(Elt);
       }
       I += NumElts;
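Taken together with the previous two hunks, packed elements are now pulled out of the wide load with EXTRACT_SUBVECTOR instead of being round-tripped through i32. A sketch of the resulting nodes (illustrative, using the same assumed values as the example above):

  // Illustrative only: NumElts = 2, EltVT = v2f16, PackingAmt = 2, VecVT = v4f16
  //   P:  v4f16,ch = load <VecAddr>
  //   e0: v2f16 = extract_subvector P:0, 0
  //   e1: v2f16 = extract_subvector P:0, 2
  // No bitcast is needed, which is why the "v2f16 was loaded as an i32"
  // else-branch above is removed.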
@@ -5047,26 +5043,243 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
   return SDValue();
 }

-static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
-                                         std::size_t Back) {
+/// Fold extractelts into a load by increasing the number of return values.
+///
+/// ex:
+/// L: v2f16,ch = load <p>
+/// a: f16 = extractelt L:0, 0
+/// b: f16 = extractelt L:0, 1
+/// use(a, b)
+///
+/// ...is turned into...
+/// L: f16,f16,ch = LoadV2 <p>
+/// use(L:0, L:1)
+static SDValue
+combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+  // Don't run this optimization before the legalizer
+  if (!DCI.isAfterLegalizeDAG())
+    return SDValue();
+
+  EVT ElemVT = N->getValueType(0);
+  if (!Isv2x16VT(ElemVT))
+    return SDValue();
+
+  // Check whether all outputs are either used by an extractelt or are
+  // glue/chain nodes
+  if (!all_of(N->uses(), [&](SDUse &U) {
+        // Skip glue, chain nodes
+        if (U.getValueType() == MVT::Glue || U.getValueType() == MVT::Other)
+          return true;
+        if (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
+          if (N->getOpcode() != ISD::LOAD)
+            return true;
+          // Since this is an ISD::LOAD, check all extractelts are used. If
+          // any are not used, we don't want to defeat another optimization
+          // that will narrow the load.
+          //
+          // For example:
+          //
+          // L: v2f16,ch = load <p>
+          // e0: f16 = extractelt L:0, 0
+          // e1: f16 = extractelt L:0, 1 <-- unused
+          // store e0
+          //
+          // Can be optimized by DAGCombiner to:
+          //
+          // L: f16,ch = load <p>
+          // store L:0
+          return !U.getUser()->use_empty();
+        }
+
+        // Otherwise, this use prevents us from splitting a value.
+        return false;
+      }))
+    return SDValue();
+
+  auto *LD = cast<MemSDNode>(N);
+  EVT MemVT = LD->getMemoryVT();
+  SDLoc DL(LD);
+
+  // The new opcode after we double the number of operands
+  NVPTXISD::NodeType Opcode;
+  SmallVector<SDValue> Operands(LD->ops());
+  unsigned OldNumOutputs; // non-glue, non-chain outputs
+  switch (LD->getOpcode()) {
+  case ISD::LOAD:
+    OldNumOutputs = 1;
+    // Any packed type is legal, so the legalizer will not have lowered
+    // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do
+    // it here.
+    Opcode = NVPTXISD::LoadV2;
+    Operands.push_back(DCI.DAG.getIntPtrConstant(
+        cast<LoadSDNode>(LD)->getExtensionType(), DL));
+    break;
+  case NVPTXISD::LoadParamV2:
+    OldNumOutputs = 2;
+    Opcode = NVPTXISD::LoadParamV4;
+    break;
+  case NVPTXISD::LoadV2:
+    OldNumOutputs = 2;
+    Opcode = NVPTXISD::LoadV4;
+    break;
+  case NVPTXISD::LoadV4:
+  case NVPTXISD::LoadV8:
+    // PTX doesn't support the next doubling of outputs
+    return SDValue();
+  }
+
+  // The non-glue, non-chain outputs in the new load
+  const unsigned NewNumOutputs = OldNumOutputs * 2;
+  SmallVector<EVT> NewVTs(NewNumOutputs, ElemVT.getVectorElementType());
+  // Add remaining chain and glue values
+  NewVTs.append(LD->value_begin() + OldNumOutputs, LD->value_end());
+
+  // Create the new load
+  SDValue NewLoad =
+      DCI.DAG.getMemIntrinsicNode(Opcode, DL, DCI.DAG.getVTList(NewVTs),
+                                  Operands, MemVT, LD->getMemOperand());
+
+  // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
+  // the outputs the same. These nodes will be optimized away in later
+  // DAGCombiner iterations.
+  SmallVector<SDValue> Results;
+  for (unsigned I : seq(OldNumOutputs))
+    Results.push_back(DCI.DAG.getBuildVector(
+        ElemVT, DL, {NewLoad.getValue(I * 2), NewLoad.getValue(I * 2 + 1)}));
+  // Add remaining chain and glue nodes
+  for (unsigned I : seq(NewLoad->getNumValues() - NewNumOutputs))
+    Results.push_back(NewLoad.getValue(NewNumOutputs + I));
+
+  return DCI.DAG.getMergeValues(Results, DL);
+}
+
+/// Fold a packing mov into a store.
+///
+/// ex:
+/// v: v2f16 = BUILD_VECTOR a:f16, b:f16
+/// StoreRetval v
+///
+/// ...is turned into...
+///
+/// StoreRetvalV2 a:f16, b:f16
+static SDValue combinePackingMovIntoStore(SDNode *N,
+                                          TargetLowering::DAGCombinerInfo &DCI,
+                                          unsigned Front, unsigned Back) {
+  // We want to run this as late as possible since other optimizations may
+  // eliminate the BUILD_VECTORs.
+  if (!DCI.isAfterLegalizeDAG())
+    return SDValue();
+
+  // Get the type of the operands being stored.
+  EVT ElementVT = N->getOperand(Front).getValueType();
+
+  if (!Isv2x16VT(ElementVT))
+    return SDValue();
+
+  auto *ST = cast<MemSDNode>(N);
+  EVT MemVT = ElementVT.getVectorElementType();
+
+  // The new opcode after we double the number of operands.
+  NVPTXISD::NodeType Opcode;
+  switch (N->getOpcode()) {
+  case ISD::STORE:
+    // Any packed type is legal, so the legalizer will not have lowered
+    // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to
+    // do it here.
+    MemVT = ST->getMemoryVT();
+    Opcode = NVPTXISD::StoreV2;
+    break;
+  case NVPTXISD::StoreParam:
+    Opcode = NVPTXISD::StoreParamV2;
+    break;
+  case NVPTXISD::StoreParamV2:
+    Opcode = NVPTXISD::StoreParamV4;
+    break;
+  case NVPTXISD::StoreRetval:
+    Opcode = NVPTXISD::StoreRetvalV2;
+    break;
+  case NVPTXISD::StoreRetvalV2:
+    Opcode = NVPTXISD::StoreRetvalV4;
+    break;
+  case NVPTXISD::StoreV2:
+    MemVT = ST->getMemoryVT();
+    Opcode = NVPTXISD::StoreV4;
+    break;
+  case NVPTXISD::StoreV4:
+  case NVPTXISD::StoreParamV4:
+  case NVPTXISD::StoreRetvalV4:
+  case NVPTXISD::StoreV8:
+    // PTX doesn't support the next doubling of operands
+    return SDValue();
+  default:
+    llvm_unreachable("Unhandled store opcode");
+  }
+
+  // Scan the operands and if they're all BUILD_VECTORs, we'll have gathered
+  // their elements.
+  SmallVector<SDValue, 4> Operands(N->ops().take_front(Front));
+  for (SDValue BV : N->ops().drop_front(Front).drop_back(Back)) {
+    if (BV.getOpcode() != ISD::BUILD_VECTOR)
+      return SDValue();
+
+    // If the operand has multiple uses, this optimization can increase
+    // register pressure.
+    if (!BV.hasOneUse())
+      return SDValue();
+
+    // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
+    // any signs they may be folded by some other pattern or rule.
+    for (SDValue Op : BV->ops()) {
+      // Peek through bitcasts
+      if (Op.getOpcode() == ISD::BITCAST)
+        Op = Op.getOperand(0);
+
+      // This may be folded into a PRMT.
+      if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
+          Op->getOperand(0).getValueType() == MVT::i32)
+        return SDValue();
+
+      // This may be folded into cvt.bf16x2
+      if (Op.getOpcode() == ISD::FP_ROUND)
+        return SDValue();
+    }
+    Operands.append({BV.getOperand(0), BV.getOperand(1)});
+  }
+  Operands.append(N->op_end() - Back, N->op_end());
+
+  // Now we replace the store
+  return DCI.DAG.getMemIntrinsicNode(Opcode, SDLoc(N), N->getVTList(),
+                                     Operands, MemVT, ST->getMemOperand());
+}
+
+static SDValue PerformStoreCombineHelper(SDNode *N,
+                                         TargetLowering::DAGCombinerInfo &DCI,
+                                         unsigned Front, unsigned Back) {
   if (all_of(N->ops().drop_front(Front).drop_back(Back),
              [](const SDUse &U) { return U.get()->isUndef(); }))
     // Operand 0 is the previous value in the chain. Cannot return EntryToken
     // as the previous value will become unused and eliminated later.
     return N->getOperand(0);

-  return SDValue();
+  return combinePackingMovIntoStore(N, DCI, Front, Back);
+}
+
+static SDValue PerformStoreCombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI) {
+  return combinePackingMovIntoStore(N, DCI, 1, 2);
 }

-static SDValue PerformStoreParamCombine(SDNode *N) {
+static SDValue PerformStoreParamCombine(SDNode *N,
+                                        TargetLowering::DAGCombinerInfo &DCI) {
   // Operands from the 3rd to the 2nd last one are the values to be stored.
   // {Chain, ArgID, Offset, Val, Glue}
-  return PerformStoreCombineHelper(N, 3, 1);
+  return PerformStoreCombineHelper(N, DCI, 3, 1);
 }

-static SDValue PerformStoreRetvalCombine(SDNode *N) {
+static SDValue PerformStoreRetvalCombine(SDNode *N,
+                                         TargetLowering::DAGCombinerInfo &DCI) {
   // Operands from the 2nd to the last one are the values to be stored
-  return PerformStoreCombineHelper(N, 2, 0);
+  return PerformStoreCombineHelper(N, DCI, 2, 0);
 }

 /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
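For the cases other than plain ISD::LOAD, the same doubling applies to loads that are already split. A rough sketch in the notation of the combineUnpackingMovIntoLoad comment above (illustrative values, not output taken from the patch):

  // Illustrative only: starting from an already-split load
  //   L: v2f16,v2f16,ch = LoadV2 <p>
  //   a: f16 = extractelt L:0, 0   (and similarly for the other lanes)
  // the combine rebuilds it as
  //   L: f16,f16,f16,f16,ch = LoadV4 <p>
  // and re-packs each pair with a BUILD_VECTOR, so existing users keep seeing
  // v2f16 values until later DAGCombiner iterations fold the extracts away.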
@@ -5697,14 +5910,23 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     return PerformREMCombine(N, DCI, OptLevel);
   case ISD::SETCC:
     return PerformSETCCCombine(N, DCI, STI.getSmVersion());
+  case ISD::LOAD:
+  case NVPTXISD::LoadParamV2:
+  case NVPTXISD::LoadV2:
+  case NVPTXISD::LoadV4:
+    return combineUnpackingMovIntoLoad(N, DCI);
   case NVPTXISD::StoreRetval:
   case NVPTXISD::StoreRetvalV2:
   case NVPTXISD::StoreRetvalV4:
-    return PerformStoreRetvalCombine(N);
+    return PerformStoreRetvalCombine(N, DCI);
   case NVPTXISD::StoreParam:
   case NVPTXISD::StoreParamV2:
   case NVPTXISD::StoreParamV4:
-    return PerformStoreParamCombine(N);
+    return PerformStoreParamCombine(N, DCI);
+  case ISD::STORE:
+  case NVPTXISD::StoreV2:
+  case NVPTXISD::StoreV4:
+    return PerformStoreCombine(N, DCI);
   case ISD::EXTRACT_VECTOR_ELT:
     return PerformEXTRACTCombine(N, DCI);
   case ISD::VSELECT: