296 changes: 259 additions & 37 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -238,18 +238,11 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
return std::nullopt;
LLVM_FALLTHROUGH;
case MVT::v2i8:
case MVT::v2i16:
case MVT::v2i32:
case MVT::v2i64:
case MVT::v2f16:
case MVT::v2bf16:
case MVT::v2f32:
case MVT::v2f64:
case MVT::v4i8:
case MVT::v4i16:
case MVT::v4i32:
case MVT::v4f16:
case MVT::v4bf16:
case MVT::v4f32:
// This is a "native" vector type
return std::pair(NumElts, EltVT);
@@ -262,6 +255,13 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
if (!CanLowerTo256Bit)
return std::nullopt;
LLVM_FALLTHROUGH;
case MVT::v2i16: // <1 x i16x2>
case MVT::v2f16: // <1 x f16x2>
case MVT::v2bf16: // <1 x bf16x2>
case MVT::v4i8: // <1 x i8x4>
case MVT::v4i16: // <2 x i16x2>
case MVT::v4f16: // <2 x f16x2>
case MVT::v4bf16: // <2 x bf16x2>
case MVT::v8i8: // <2 x i8x4>
case MVT::v8f16: // <4 x f16x2>
case MVT::v8bf16: // <4 x bf16x2>
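// Editorial note (not part of the patch): the trailing comments above give
// the shape each case is lowered to, e.g. v4f16 is handled as two f16x2
// packed registers and v8i8 as two i8x4 packed registers.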
@@ -845,7 +845,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// We have some custom DAG combine patterns for these nodes
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
ISD::STORE});

// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -3464,19 +3465,16 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
unsigned I = 0;
for (const unsigned NumElts : VectorInfo) {
const EVT EltVT = VTs[I];
const EVT LoadVT = [&]() -> EVT {
// i1 is loaded/stored as i8.
if (EltVT == MVT::i1)
return MVT::i8;
// getLoad needs a vector type, but it can't handle
// vectors which contain v2f16 or v2bf16 elements. So we must load
// using i32 here and then bitcast back.
if (EltVT.isVector())
return MVT::getIntegerVT(EltVT.getFixedSizeInBits());
return EltVT;
}();
// i1 is loaded/stored as i8
const EVT LoadVT = EltVT == MVT::i1 ? MVT::i8 : EltVT;
// The element may itself be a packed type (e.g. v2f16, v4i8) holding
// multiple scalar elements.
const unsigned PackingAmt =
LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;

const EVT VecVT = EVT::getVectorVT(
F->getContext(), LoadVT.getScalarType(), NumElts * PackingAmt);
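// Worked example (editorial note, not part of the patch): if EltVT is v2f16
// and NumElts is 4, then LoadVT is v2f16, PackingAmt is 2, and VecVT becomes
// v8f16, i.e. a single vector load covering all eight scalar f16 values.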

const EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
SDValue VecAddr = DAG.getObjectPtrOffset(
dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));

@@ -3496,8 +3494,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (P.getNode())
P.getNode()->setIROrder(Arg.getArgNo() + 1);
for (const unsigned J : llvm::seq(NumElts)) {
SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
DAG.getIntPtrConstant(J, dl));
SDValue Elt = DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
: ISD::EXTRACT_VECTOR_ELT,
dl, LoadVT, P,
DAG.getIntPtrConstant(J * PackingAmt, dl));
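// Editorial note (not part of the patch): when LoadVT is a packed vector
// such as v2f16, each iteration extracts a whole LoadVT subvector from P,
// so the start index advances by PackingAmt scalar lanes per element.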

// Extend or truncate the element if necessary (e.g. an i8 is loaded
// into an i16 register)
@@ -3506,15 +3506,11 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
(ExpactedVT.isScalarInteger() &&
Elt.getValueType().isScalarInteger())) &&
"Non-integer argument type size mismatch");
if (ExpactedVT.bitsGT(Elt.getValueType())) {
if (ExpactedVT.bitsGT(Elt.getValueType()))
Elt = DAG.getNode(getExtOpcode(ArgIns[I + J].Flags), dl, ExpactedVT,
Elt);
} else if (ExpactedVT.bitsLT(Elt.getValueType())) {
else if (ExpactedVT.bitsLT(Elt.getValueType()))
Elt = DAG.getNode(ISD::TRUNCATE, dl, ExpactedVT, Elt);
} else {
// v2f16 was loaded as an i32. Now we must bitcast it back.
Elt = DAG.getBitcast(EltVT, Elt);
}
InVals.push_back(Elt);
}
I += NumElts;
@@ -5047,26 +5043,243 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
return SDValue();
}

static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
std::size_t Back) {
/// Fold extractelts into a load by increasing the number of return values.
///
/// ex:
/// L: v2f16,ch = load <p>
/// a: f16 = extractelt L:0, 0
/// b: f16 = extractelt L:0, 1
/// use(a, b)
///
/// ...is turned into...
/// L: f16,f16,ch = LoadV2 <p>
/// use(L:0, L:1)
static SDValue
combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
// Don't run this optimization before the legalizer
if (!DCI.isAfterLegalizeDAG())
return SDValue();

EVT ElemVT = N->getValueType(0);
if (!Isv2x16VT(ElemVT))
return SDValue();

// Check whether all outputs are either used by an extractelt or are
// glue/chain nodes
if (!all_of(N->uses(), [&](SDUse &U) {
// Skip glue, chain nodes
if (U.getValueType() == MVT::Glue || U.getValueType() == MVT::Other)
return true;
if (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
if (N->getOpcode() != ISD::LOAD)
return true;
// Since this is an ISD::LOAD, check all extractelts are used. If
// any are not used, we don't want to defeat another optimization that
// will narrow the load.
//
// For example:
//
// L: v2f16,ch = load <p>
// e0: f16 = extractelt L:0, 0
// e1: f16 = extractelt L:0, 1 <-- unused
// store e0
//
// Can be optimized by DAGCombiner to:
//
// L: f16,ch = load <p>
// store L:0
return !U.getUser()->use_empty();
}

// Otherwise, this use prevents us from splitting a value.
return false;
}))
return SDValue();

auto *LD = cast<MemSDNode>(N);
EVT MemVT = LD->getMemoryVT();
SDLoc DL(LD);

// The new opcode after we double the number of operands
NVPTXISD::NodeType Opcode;
SmallVector<SDValue> Operands(LD->ops());
unsigned OldNumOutputs; // non-glue, non-chain outputs
switch (LD->getOpcode()) {
case ISD::LOAD:
OldNumOutputs = 1;
// Any packed type is legal, so the legalizer will not have lowered
// ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
// here.
Opcode = NVPTXISD::LoadV2;
Operands.push_back(DCI.DAG.getIntPtrConstant(
cast<LoadSDNode>(LD)->getExtensionType(), DL));
break;
case NVPTXISD::LoadParamV2:
OldNumOutputs = 2;
Opcode = NVPTXISD::LoadParamV4;
break;
case NVPTXISD::LoadV2:
OldNumOutputs = 2;
Opcode = NVPTXISD::LoadV4;
break;
case NVPTXISD::LoadV4:
case NVPTXISD::LoadV8:
// PTX doesn't support the next doubling of outputs
return SDValue();
}

// The non-glue, non-chain outputs in the new load
const unsigned NewNumOutputs = OldNumOutputs * 2;
SmallVector<EVT> NewVTs(NewNumOutputs, ElemVT.getVectorElementType());
// Add remaining chain and glue values
NewVTs.append(LD->value_begin() + OldNumOutputs, LD->value_end());
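// Worked example (editorial note, not part of the patch): for a
// NVPTXISD::LoadV2 whose two results are v2f16, OldNumOutputs is 2,
// NewNumOutputs is 4, and NewVTs becomes {f16, f16, f16, f16, ch} (plus glue
// if the original load carried one).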

// Create the new load
SDValue NewLoad =
DCI.DAG.getMemIntrinsicNode(Opcode, DL, DCI.DAG.getVTList(NewVTs),
Operands, MemVT, LD->getMemOperand());

// Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
// the outputs the same. These nodes will be optimized away in later
// DAGCombiner iterations.
SmallVector<SDValue> Results;
for (unsigned I : seq(OldNumOutputs))
Results.push_back(DCI.DAG.getBuildVector(
ElemVT, DL, {NewLoad.getValue(I * 2), NewLoad.getValue(I * 2 + 1)}));
// Add remaining chain and glue nodes
for (unsigned I : seq(NewLoad->getNumValues() - NewNumOutputs))
Results.push_back(NewLoad.getValue(NewNumOutputs + I));

return DCI.DAG.getMergeValues(Results, DL);
}
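// Editorial illustration (not part of the patch), mirroring the doc-comment
// example above for the param-load case:
//
// L: v2f16,v2f16,ch = LoadParamV2 <p>
// a: f16 = extractelt L:0, 0
// b: f16 = extractelt L:0, 1
// c: f16 = extractelt L:1, 0
// d: f16 = extractelt L:1, 1
//
// ...is turned into...
//
// L: f16,f16,f16,f16,ch = LoadParamV4 <p>
//
// with temporary BUILD_VECTOR/MERGE_VALUES nodes that later DAGCombiner
// iterations fold away.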

/// Fold a packing mov into a store.
///
/// ex:
/// v: v2f16 = BUILD_VECTOR a:f16, b:f16
/// StoreRetval v
///
/// ...is turned into...
///
/// StoreRetvalV2 a:f16, b:f16
static SDValue combinePackingMovIntoStore(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
unsigned Front, unsigned Back) {
// We want to run this as late as possible since other optimizations may
// eliminate the BUILD_VECTORs.
if (!DCI.isAfterLegalizeDAG())
return SDValue();

// Get the type of the operands being stored.
EVT ElementVT = N->getOperand(Front).getValueType();

if (!Isv2x16VT(ElementVT))
return SDValue();

auto *ST = cast<MemSDNode>(N);
EVT MemVT = ElementVT.getVectorElementType();
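// Editorial note (not part of the patch): MemVT defaults to the scalar
// element type here; the ISD::STORE and NVPTXISD::StoreV2 cases below
// override it with the node's actual memory VT.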

// The new opcode after we double the number of operands.
NVPTXISD::NodeType Opcode;
switch (N->getOpcode()) {
case ISD::STORE:
// Any packed type is legal, so the legalizer will not have lowered
// ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do
// it here.
MemVT = ST->getMemoryVT();
Opcode = NVPTXISD::StoreV2;
break;
case NVPTXISD::StoreParam:
Opcode = NVPTXISD::StoreParamV2;
break;
case NVPTXISD::StoreParamV2:
Opcode = NVPTXISD::StoreParamV4;
break;
case NVPTXISD::StoreRetval:
Opcode = NVPTXISD::StoreRetvalV2;
break;
case NVPTXISD::StoreRetvalV2:
Opcode = NVPTXISD::StoreRetvalV4;
break;
case NVPTXISD::StoreV2:
MemVT = ST->getMemoryVT();
Opcode = NVPTXISD::StoreV4;
break;
case NVPTXISD::StoreV4:
case NVPTXISD::StoreParamV4:
case NVPTXISD::StoreRetvalV4:
case NVPTXISD::StoreV8:
// PTX doesn't support the next doubling of operands
return SDValue();
default:
llvm_unreachable("Unhandled store opcode");
}

// Scan the stored operands: if they are all BUILD_VECTORs, gather their
// elements; otherwise bail out.
SmallVector<SDValue, 4> Operands(N->ops().take_front(Front));
for (SDValue BV : N->ops().drop_front(Front).drop_back(Back)) {
if (BV.getOpcode() != ISD::BUILD_VECTOR)
return SDValue();

// If the operand has multiple uses, this optimization can increase register
// pressure.
if (!BV.hasOneUse())
return SDValue();

// DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
// any signs they may be folded by some other pattern or rule.
for (SDValue Op : BV->ops()) {
// Peek through bitcasts
if (Op.getOpcode() == ISD::BITCAST)
Op = Op.getOperand(0);

// This may be folded into a PRMT.
if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
Op->getOperand(0).getValueType() == MVT::i32)
return SDValue();

// This may be folded into cvt.bf16x2
if (Op.getOpcode() == ISD::FP_ROUND)
return SDValue();
}
Operands.append({BV.getOperand(0), BV.getOperand(1)});
}
Operands.append(N->op_end() - Back, N->op_end());

// Now we replace the store
return DCI.DAG.getMemIntrinsicNode(Opcode, SDLoc(N), N->getVTList(), Operands,
MemVT, ST->getMemOperand());
}
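// Editorial illustration (not part of the patch): the same fold applied to
// the NVPTXISD::StoreV2 case:
//
// v0: v2f16 = BUILD_VECTOR a:f16, b:f16
// v1: v2f16 = BUILD_VECTOR c:f16, d:f16
// StoreV2 v0, v1
//
// ...is turned into...
//
// StoreV4 a, b, c, d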

static SDValue PerformStoreCombineHelper(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
unsigned Front, unsigned Back) {
if (all_of(N->ops().drop_front(Front).drop_back(Back),
[](const SDUse &U) { return U.get()->isUndef(); }))
// Operand 0 is the previous value in the chain. Cannot return EntryToken
// as the previous value will become unused and eliminated later.
return N->getOperand(0);

return SDValue();
return combinePackingMovIntoStore(N, DCI, Front, Back);
}

static SDValue PerformStoreCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
return combinePackingMovIntoStore(N, DCI, 1, 2);
}

static SDValue PerformStoreParamCombine(SDNode *N) {
static SDValue PerformStoreParamCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
// Operands from the 3rd to the 2nd last one are the values to be stored.
// {Chain, ArgID, Offset, Val, Glue}
return PerformStoreCombineHelper(N, 3, 1);
return PerformStoreCombineHelper(N, DCI, 3, 1);
}

static SDValue PerformStoreRetvalCombine(SDNode *N) {
static SDValue PerformStoreRetvalCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
// Operands from the 2nd to the last one are the values to be stored
return PerformStoreCombineHelper(N, 2, 0);
return PerformStoreCombineHelper(N, DCI, 2, 0);
}

/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
@@ -5697,14 +5910,23 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformREMCombine(N, DCI, OptLevel);
case ISD::SETCC:
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
case ISD::LOAD:
case NVPTXISD::LoadParamV2:
case NVPTXISD::LoadV2:
case NVPTXISD::LoadV4:
return combineUnpackingMovIntoLoad(N, DCI);
case NVPTXISD::StoreRetval:
case NVPTXISD::StoreRetvalV2:
case NVPTXISD::StoreRetvalV4:
return PerformStoreRetvalCombine(N);
return PerformStoreRetvalCombine(N, DCI);
case NVPTXISD::StoreParam:
case NVPTXISD::StoreParamV2:
case NVPTXISD::StoreParamV4:
return PerformStoreParamCombine(N);
return PerformStoreParamCombine(N, DCI);
case ISD::STORE:
case NVPTXISD::StoreV2:
case NVPTXISD::StoreV4:
return PerformStoreCombine(N, DCI);
case ISD::EXTRACT_VECTOR_ELT:
return PerformEXTRACTCombine(N, DCI);
case ISD::VSELECT: