[NVPTX] fold movs into loads and stores (#144581)
Fold movs into loads and stores by increasing the number of return values or operands. For example:

```
L: v2f16,ch = Load [p]
e0 = extractelt L, 0
e1 = extractelt L, 1
consume(e0, e1)
```

...becomes...

```
L: f16,f16,ch = LoadV2 [p]
consume(L:0, L:1)
```
Commit e1cd450 · 1 parent 3187d4d · 23 files changed: +2978 −2318 lines

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (259 additions, 37 deletions)
```diff
@@ -238,18 +238,11 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
       return std::nullopt;
     LLVM_FALLTHROUGH;
   case MVT::v2i8:
-  case MVT::v2i16:
   case MVT::v2i32:
   case MVT::v2i64:
-  case MVT::v2f16:
-  case MVT::v2bf16:
   case MVT::v2f32:
   case MVT::v2f64:
-  case MVT::v4i8:
-  case MVT::v4i16:
   case MVT::v4i32:
-  case MVT::v4f16:
-  case MVT::v4bf16:
   case MVT::v4f32:
     // This is a "native" vector type
     return std::pair(NumElts, EltVT);
```
```diff
@@ -262,6 +255,13 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
     if (!CanLowerTo256Bit)
       return std::nullopt;
     LLVM_FALLTHROUGH;
+  case MVT::v2i16:  // <1 x i16x2>
+  case MVT::v2f16:  // <1 x f16x2>
+  case MVT::v2bf16: // <1 x bf16x2>
+  case MVT::v4i8:   // <1 x i8x4>
+  case MVT::v4i16:  // <2 x i16x2>
+  case MVT::v4f16:  // <2 x f16x2>
+  case MVT::v4bf16: // <2 x bf16x2>
   case MVT::v8i8:   // <2 x i8x4>
   case MVT::v8f16:  // <4 x f16x2>
   case MVT::v8bf16: // <4 x bf16x2>
```
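
The `// <N x TxM>` comments on the new case labels describe how each vector type decomposes into packed registers. Below is a minimal standalone model of that mapping (plain C++; the packing-factor table is inferred from those comments, not taken from the LLVM sources):

```cpp
#include <cstdio>
#include <string>
#include <utility>

// Model: how many scalar elements fit in one packed register. Inferred from
// the case-label comments above (i8 packs 4 per register; i16/f16/bf16 pack 2).
static unsigned elementsPerPack(const std::string &Elt) {
  if (Elt == "i8")
    return 4;
  if (Elt == "i16" || Elt == "f16" || Elt == "bf16")
    return 2;
  return 1; // i32, f32, i64, f64: no packing
}

// Returns the {NumElts, EltVT} shape after packing, mirroring "<N x TxM>".
static std::pair<unsigned, std::string> loweringShape(unsigned NumElts,
                                                      const std::string &Elt) {
  const unsigned PerPack = elementsPerPack(Elt);
  if (PerPack == 1 || NumElts % PerPack != 0)
    return {NumElts, Elt}; // stays a vector of scalar elements
  return {NumElts / PerPack, Elt + "x" + std::to_string(PerPack)};
}

int main() {
  auto [N1, T1] = loweringShape(4, "f16"); // v4f16
  std::printf("v4f16 -> <%u x %s>\n", N1, T1.c_str()); // <2 x f16x2>
  auto [N2, T2] = loweringShape(8, "i8"); // v8i8
  std::printf("v8i8  -> <%u x %s>\n", N2, T2.c_str()); // <2 x i8x4>
}
```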
```diff
@@ -845,7 +845,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
                        ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
-                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
+                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
+                       ISD::STORE});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
```
```diff
@@ -3464,19 +3465,16 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
   unsigned I = 0;
   for (const unsigned NumElts : VectorInfo) {
     const EVT EltVT = VTs[I];
-    const EVT LoadVT = [&]() -> EVT {
-      // i1 is loaded/stored as i8.
-      if (EltVT == MVT::i1)
-        return MVT::i8;
-      // getLoad needs a vector type, but it can't handle
-      // vectors which contain v2f16 or v2bf16 elements. So we must load
-      // using i32 here and then bitcast back.
-      if (EltVT.isVector())
-        return MVT::getIntegerVT(EltVT.getFixedSizeInBits());
-      return EltVT;
-    }();
+    // i1 is loaded/stored as i8
+    const EVT LoadVT = EltVT == MVT::i1 ? MVT::i8 : EltVT;
+    // The element may be a packed type (ex. v2f16, v4i8, etc.) holding
+    // multiple elements.
+    const unsigned PackingAmt =
+        LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
+
+    const EVT VecVT = EVT::getVectorVT(
+        F->getContext(), LoadVT.getScalarType(), NumElts * PackingAmt);
 
-    const EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
     SDValue VecAddr = DAG.getObjectPtrOffset(
         dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
 
```
```diff
@@ -3496,8 +3494,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     if (P.getNode())
       P.getNode()->setIROrder(Arg.getArgNo() + 1);
     for (const unsigned J : llvm::seq(NumElts)) {
-      SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
-                                DAG.getIntPtrConstant(J, dl));
+      SDValue Elt = DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
+                                                  : ISD::EXTRACT_VECTOR_ELT,
+                                dl, LoadVT, P,
+                                DAG.getIntPtrConstant(J * PackingAmt, dl));
 
       // Extend or truncate the element if necessary (e.g. an i8 is loaded
       // into an i16 register)
```
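
A worked example of the new indexing, assuming an argument group of NumElts = 2 elements of type v2f16 (the concrete numbers are illustrative, not taken from this commit's tests):

```cpp
#include <cstdio>

int main() {
  // Assumed example: EltVT = v2f16, so LoadVT = v2f16 (not i1), and one
  // VectorInfo group covers NumElts = 2 such elements.
  const unsigned NumElts = 2;
  const unsigned PackingAmt = 2; // LoadVT.getVectorNumElements() for v2f16

  // VecVT = <NumElts * PackingAmt x f16> = v4f16: a single 64-bit load,
  // replacing the old "load as i32 and bitcast back" path.
  std::printf("VecVT = v%uf16\n", NumElts * PackingAmt);

  // Each J-th packed element is recovered with EXTRACT_SUBVECTOR at index
  // J * PackingAmt, yielding a v2f16 directly.
  for (unsigned J = 0; J < NumElts; ++J)
    std::printf("elt %u: EXTRACT_SUBVECTOR(P, %u) -> v2f16\n", J,
                J * PackingAmt);
}
```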
```diff
@@ -3506,15 +3506,11 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
               (ExpactedVT.isScalarInteger() &&
                Elt.getValueType().isScalarInteger())) &&
              "Non-integer argument type size mismatch");
-      if (ExpactedVT.bitsGT(Elt.getValueType())) {
+      if (ExpactedVT.bitsGT(Elt.getValueType()))
         Elt = DAG.getNode(getExtOpcode(ArgIns[I + J].Flags), dl, ExpactedVT,
                           Elt);
-      } else if (ExpactedVT.bitsLT(Elt.getValueType())) {
+      else if (ExpactedVT.bitsLT(Elt.getValueType()))
         Elt = DAG.getNode(ISD::TRUNCATE, dl, ExpactedVT, Elt);
-      } else {
-        // v2f16 was loaded as an i32. Now we must bitcast it back.
-        Elt = DAG.getBitcast(EltVT, Elt);
-      }
       InVals.push_back(Elt);
     }
     I += NumElts;
```
```diff
@@ -5047,26 +5043,243 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
   return SDValue();
 }
 
-static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
-                                         std::size_t Back) {
+/// Fold extractelts into a load by increasing the number of return values.
+///
+/// ex:
+/// L: v2f16,ch = load <p>
+/// a: f16 = extractelt L:0, 0
+/// b: f16 = extractelt L:0, 1
+/// use(a, b)
+///
+/// ...is turned into...
+/// L: f16,f16,ch = LoadV2 <p>
+/// use(L:0, L:1)
+static SDValue
+combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+  // Don't run this optimization before the legalizer
+  if (!DCI.isAfterLegalizeDAG())
+    return SDValue();
+
+  EVT ElemVT = N->getValueType(0);
+  if (!Isv2x16VT(ElemVT))
+    return SDValue();
+
+  // Check whether all outputs are either used by an extractelt or are
+  // glue/chain nodes
+  if (!all_of(N->uses(), [&](SDUse &U) {
+        // Skip glue, chain nodes
+        if (U.getValueType() == MVT::Glue || U.getValueType() == MVT::Other)
+          return true;
+        if (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
+          if (N->getOpcode() != ISD::LOAD)
+            return true;
+          // Since this is an ISD::LOAD, check all extractelts are used. If
+          // any are not used, we don't want to defeat another optimization
+          // that will narrow the load.
+          //
+          // For example:
+          //
+          // L: v2f16,ch = load <p>
+          // e0: f16 = extractelt L:0, 0
+          // e1: f16 = extractelt L:0, 1 <-- unused
+          // store e0
+          //
+          // Can be optimized by DAGCombiner to:
+          //
+          // L: f16,ch = load <p>
+          // store L:0
+          return !U.getUser()->use_empty();
+        }
+
+        // Otherwise, this use prevents us from splitting a value.
+        return false;
+      }))
+    return SDValue();
+
+  auto *LD = cast<MemSDNode>(N);
+  EVT MemVT = LD->getMemoryVT();
+  SDLoc DL(LD);
+
+  // The new opcode after we double the number of outputs
+  NVPTXISD::NodeType Opcode;
+  SmallVector<SDValue> Operands(LD->ops());
+  unsigned OldNumOutputs; // non-glue, non-chain outputs
+  switch (LD->getOpcode()) {
+  case ISD::LOAD:
+    OldNumOutputs = 1;
+    // Any packed type is legal, so the legalizer will not have lowered
+    // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do
+    // it here.
+    Opcode = NVPTXISD::LoadV2;
+    Operands.push_back(DCI.DAG.getIntPtrConstant(
+        cast<LoadSDNode>(LD)->getExtensionType(), DL));
+    break;
+  case NVPTXISD::LoadParamV2:
+    OldNumOutputs = 2;
+    Opcode = NVPTXISD::LoadParamV4;
+    break;
+  case NVPTXISD::LoadV2:
+    OldNumOutputs = 2;
+    Opcode = NVPTXISD::LoadV4;
+    break;
+  case NVPTXISD::LoadV4:
+  case NVPTXISD::LoadV8:
+    // PTX doesn't support the next doubling of outputs
+    return SDValue();
+  }
+
+  // The non-glue, non-chain outputs in the new load
+  const unsigned NewNumOutputs = OldNumOutputs * 2;
+  SmallVector<EVT> NewVTs(NewNumOutputs, ElemVT.getVectorElementType());
+  // Add remaining chain and glue values
+  NewVTs.append(LD->value_begin() + OldNumOutputs, LD->value_end());
+
+  // Create the new load
+  SDValue NewLoad =
+      DCI.DAG.getMemIntrinsicNode(Opcode, DL, DCI.DAG.getVTList(NewVTs),
+                                  Operands, MemVT, LD->getMemOperand());
+
+  // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
+  // the outputs the same. These nodes will be optimized away in later
+  // DAGCombiner iterations.
+  SmallVector<SDValue> Results;
+  for (unsigned I : seq(OldNumOutputs))
+    Results.push_back(DCI.DAG.getBuildVector(
+        ElemVT, DL, {NewLoad.getValue(I * 2), NewLoad.getValue(I * 2 + 1)}));
+  // Add remaining chain and glue nodes
+  for (unsigned I : seq(NewLoad->getNumValues() - NewNumOutputs))
+    Results.push_back(NewLoad.getValue(NewNumOutputs + I));
+
+  return DCI.DAG.getMergeValues(Results, DL);
+}
```
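
Stripped of the SelectionDAG plumbing, the rewrite doubles the non-chain/glue results and then rebuilds the old result shape with BUILD_VECTORs so existing users still fold away. A sketch with plain data structures (illustrative names, not LLVM's API):

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Illustrative stand-in for a node's result-type list, e.g. {"v2f16", "ch"}.
using VTList = std::vector<std::string>;

// Model the core of combineUnpackingMovIntoLoad: each old vector result
// becomes two scalar results; chain/glue results carry over unchanged.
static VTList doubleOutputs(const VTList &Old, unsigned OldNumOutputs,
                            const std::string &ScalarVT) {
  VTList New(OldNumOutputs * 2, ScalarVT);
  New.insert(New.end(), Old.begin() + OldNumOutputs, Old.end());
  return New;
}

int main() {
  VTList Load = {"v2f16", "ch"}; // L: v2f16,ch = load <p>
  VTList LoadV2 = doubleOutputs(Load, /*OldNumOutputs=*/1, "f16");
  for (const auto &VT : LoadV2) // prints: f16 f16 ch
    std::printf("%s ", VT.c_str());
  std::printf("\n");
  // The combine then emits BUILD_VECTOR(L:0, L:1) in place of the old v2f16
  // result; later DAGCombiner iterations fold it into the extractelts.
}
```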
```diff
+
+/// Fold a packing mov into a store.
+///
+/// ex:
+/// v: v2f16 = BUILD_VECTOR a:f16, b:f16
+/// StoreRetval v
+///
+/// ...is turned into...
+///
+/// StoreRetvalV2 a:f16, b:f16
+static SDValue combinePackingMovIntoStore(SDNode *N,
+                                          TargetLowering::DAGCombinerInfo &DCI,
+                                          unsigned Front, unsigned Back) {
+  // We want to run this as late as possible since other optimizations may
+  // eliminate the BUILD_VECTORs.
+  if (!DCI.isAfterLegalizeDAG())
+    return SDValue();
+
+  // Get the type of the operands being stored.
+  EVT ElementVT = N->getOperand(Front).getValueType();
+
+  if (!Isv2x16VT(ElementVT))
+    return SDValue();
+
+  auto *ST = cast<MemSDNode>(N);
+  EVT MemVT = ElementVT.getVectorElementType();
+
+  // The new opcode after we double the number of operands.
+  NVPTXISD::NodeType Opcode;
+  switch (N->getOpcode()) {
+  case ISD::STORE:
+    // Any packed type is legal, so the legalizer will not have lowered
+    // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to
+    // do it here.
+    MemVT = ST->getMemoryVT();
+    Opcode = NVPTXISD::StoreV2;
+    break;
+  case NVPTXISD::StoreParam:
+    Opcode = NVPTXISD::StoreParamV2;
+    break;
+  case NVPTXISD::StoreParamV2:
+    Opcode = NVPTXISD::StoreParamV4;
+    break;
+  case NVPTXISD::StoreRetval:
+    Opcode = NVPTXISD::StoreRetvalV2;
+    break;
+  case NVPTXISD::StoreRetvalV2:
+    Opcode = NVPTXISD::StoreRetvalV4;
+    break;
+  case NVPTXISD::StoreV2:
+    MemVT = ST->getMemoryVT();
+    Opcode = NVPTXISD::StoreV4;
+    break;
+  case NVPTXISD::StoreV4:
+  case NVPTXISD::StoreParamV4:
+  case NVPTXISD::StoreRetvalV4:
+  case NVPTXISD::StoreV8:
+    // PTX doesn't support the next doubling of operands
+    return SDValue();
+  default:
+    llvm_unreachable("Unhandled store opcode");
+  }
+
+  // Scan the operands and if they're all BUILD_VECTORs, we'll have gathered
+  // their elements.
+  SmallVector<SDValue, 4> Operands(N->ops().take_front(Front));
+  for (SDValue BV : N->ops().drop_front(Front).drop_back(Back)) {
+    if (BV.getOpcode() != ISD::BUILD_VECTOR)
+      return SDValue();
+
+    // If the operand has multiple uses, this optimization can increase
+    // register pressure.
+    if (!BV.hasOneUse())
+      return SDValue();
+
+    // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
+    // any signs they may be folded by some other pattern or rule.
+    for (SDValue Op : BV->ops()) {
+      // Peek through bitcasts
+      if (Op.getOpcode() == ISD::BITCAST)
+        Op = Op.getOperand(0);
+
+      // This may be folded into a PRMT.
+      if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
+          Op->getOperand(0).getValueType() == MVT::i32)
+        return SDValue();
+
+      // This may be folded into cvt.bf16x2
+      if (Op.getOpcode() == ISD::FP_ROUND)
+        return SDValue();
+    }
+    Operands.append({BV.getOperand(0), BV.getOperand(1)});
+  }
+  Operands.append(N->op_end() - Back, N->op_end());
+
+  // Now we replace the store
+  return DCI.DAG.getMemIntrinsicNode(Opcode, SDLoc(N), N->getVTList(),
+                                     Operands, MemVT, ST->getMemOperand());
+}
```
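
The store side is the mirror image: keep the first `Front` and last `Back` operands, and splice each BUILD_VECTOR's two scalars in place of the packed operand. A sketch under the same assumptions (plain C++, illustrative names):

```cpp
#include <cstdio>
#include <string>
#include <vector>

using Ops = std::vector<std::string>;

// Model of combinePackingMovIntoStore's operand rewrite: each value operand
// is assumed to be a BUILD_VECTOR of two scalars, written here as "<op>.eltN".
static Ops unpackStoreOperands(const Ops &N, unsigned Front, unsigned Back) {
  Ops Out(N.begin(), N.begin() + Front);          // e.g. {Chain, ArgID, Offset}
  for (auto It = N.begin() + Front; It != N.end() - Back; ++It) {
    Out.push_back(*It + ".elt0");                 // BV.getOperand(0)
    Out.push_back(*It + ".elt1");                 // BV.getOperand(1)
  }
  Out.insert(Out.end(), N.end() - Back, N.end()); // e.g. {Glue}
  return Out;
}

int main() {
  // StoreParam operands are {Chain, ArgID, Offset, Val, Glue}: Front=3, Back=1.
  Ops StoreParam = {"ch", "argid", "off", "v2f16", "glue"};
  for (const auto &Op : unpackStoreOperands(StoreParam, 3, 1))
    std::printf("%s ", Op.c_str()); // ch argid off v2f16.elt0 v2f16.elt1 glue
  std::printf("\n");
}
```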
```diff
+
+static SDValue PerformStoreCombineHelper(SDNode *N,
+                                         TargetLowering::DAGCombinerInfo &DCI,
+                                         unsigned Front, unsigned Back) {
   if (all_of(N->ops().drop_front(Front).drop_back(Back),
              [](const SDUse &U) { return U.get()->isUndef(); }))
     // Operand 0 is the previous value in the chain. Cannot return EntryToken
     // as the previous value will become unused and eliminated later.
     return N->getOperand(0);
 
-  return SDValue();
+  return combinePackingMovIntoStore(N, DCI, Front, Back);
+}
+
+static SDValue PerformStoreCombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI) {
+  return combinePackingMovIntoStore(N, DCI, 1, 2);
 }
 
-static SDValue PerformStoreParamCombine(SDNode *N) {
+static SDValue PerformStoreParamCombine(SDNode *N,
+                                        TargetLowering::DAGCombinerInfo &DCI) {
   // Operands from the 3rd to the 2nd last one are the values to be stored.
   //   {Chain, ArgID, Offset, Val, Glue}
-  return PerformStoreCombineHelper(N, 3, 1);
+  return PerformStoreCombineHelper(N, DCI, 3, 1);
 }
 
-static SDValue PerformStoreRetvalCombine(SDNode *N) {
+static SDValue PerformStoreRetvalCombine(SDNode *N,
+                                         TargetLowering::DAGCombinerInfo &DCI) {
   // Operands from the 2nd to the last one are the values to be stored
-  return PerformStoreCombineHelper(N, 2, 0);
+  return PerformStoreCombineHelper(N, DCI, 2, 0);
 }
 
 /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
```
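
And the pre-existing undef-store elimination that the helper still performs first, modeled the same way (operand names are assumed from the comments above):

```cpp
#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

// Model: if every stored value operand is undef, the store is dead and the
// helper returns operand 0 (the incoming chain) so users re-link past it.
static bool allStoredValuesUndef(const std::vector<std::string> &Ops,
                                 unsigned Front, unsigned Back) {
  return std::all_of(Ops.begin() + Front, Ops.end() - Back,
                     [](const std::string &Op) { return Op == "undef"; });
}

int main() {
  // StoreRetval operands are assumed {Chain, Offset, Val}: Front=2, Back=0.
  std::vector<std::string> StoreRetval = {"ch", "off", "undef"};
  std::printf("dead store: %s\n",
              allStoredValuesUndef(StoreRetval, 2, 0) ? "yes" : "no");
}
```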
```diff
@@ -5697,14 +5910,23 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     return PerformREMCombine(N, DCI, OptLevel);
   case ISD::SETCC:
     return PerformSETCCCombine(N, DCI, STI.getSmVersion());
+  case ISD::LOAD:
+  case NVPTXISD::LoadParamV2:
+  case NVPTXISD::LoadV2:
+  case NVPTXISD::LoadV4:
+    return combineUnpackingMovIntoLoad(N, DCI);
   case NVPTXISD::StoreRetval:
   case NVPTXISD::StoreRetvalV2:
   case NVPTXISD::StoreRetvalV4:
-    return PerformStoreRetvalCombine(N);
+    return PerformStoreRetvalCombine(N, DCI);
   case NVPTXISD::StoreParam:
   case NVPTXISD::StoreParamV2:
   case NVPTXISD::StoreParamV4:
-    return PerformStoreParamCombine(N);
+    return PerformStoreParamCombine(N, DCI);
+  case ISD::STORE:
+  case NVPTXISD::StoreV2:
+  case NVPTXISD::StoreV4:
+    return PerformStoreCombine(N, DCI);
   case ISD::EXTRACT_VECTOR_ELT:
     return PerformEXTRACTCombine(N, DCI);
   case ISD::VSELECT:
```
