Skip to content

Commit d18101f

Browse files
committed
use normal stores for retval
1 parent 8f33b13 commit d18101f

36 files changed

+462
-708
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 0 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,6 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
151151
if (tryLoadParam(N))
152152
return;
153153
break;
154-
case NVPTXISD::StoreRetval:
155-
case NVPTXISD::StoreRetvalV2:
156-
case NVPTXISD::StoreRetvalV4:
157-
if (tryStoreRetval(N))
158-
return;
159-
break;
160154
case NVPTXISD::StoreParam:
161155
case NVPTXISD::StoreParamV2:
162156
case NVPTXISD::StoreParamV4:
@@ -1504,84 +1498,6 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
15041498
return true;
15051499
}
15061500

1507-
bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
1508-
SDLoc DL(N);
1509-
SDValue Chain = N->getOperand(0);
1510-
SDValue Offset = N->getOperand(1);
1511-
unsigned OffsetVal = Offset->getAsZExtVal();
1512-
MemSDNode *Mem = cast<MemSDNode>(N);
1513-
1514-
// How many elements do we have?
1515-
unsigned NumElts = 1;
1516-
switch (N->getOpcode()) {
1517-
default:
1518-
return false;
1519-
case NVPTXISD::StoreRetval:
1520-
NumElts = 1;
1521-
break;
1522-
case NVPTXISD::StoreRetvalV2:
1523-
NumElts = 2;
1524-
break;
1525-
case NVPTXISD::StoreRetvalV4:
1526-
NumElts = 4;
1527-
break;
1528-
}
1529-
1530-
// Build vector of operands
1531-
SmallVector<SDValue, 6> Ops;
1532-
for (unsigned i = 0; i < NumElts; ++i)
1533-
Ops.push_back(N->getOperand(i + 2));
1534-
Ops.append({CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain});
1535-
1536-
// Determine target opcode
1537-
// If we have an i1, use an 8-bit store. The lowering code in
1538-
// NVPTXISelLowering will have already emitted an upcast.
1539-
std::optional<unsigned> Opcode = 0;
1540-
switch (NumElts) {
1541-
default:
1542-
return false;
1543-
case 1:
1544-
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
1545-
NVPTX::StoreRetvalI8, NVPTX::StoreRetvalI16,
1546-
NVPTX::StoreRetvalI32, NVPTX::StoreRetvalI64);
1547-
if (Opcode == NVPTX::StoreRetvalI8) {
1548-
// Fine tune the opcode depending on the size of the operand.
1549-
// This helps to avoid creating redundant COPY instructions in
1550-
// InstrEmitter::AddRegisterOperand().
1551-
switch (Ops[0].getSimpleValueType().SimpleTy) {
1552-
default:
1553-
break;
1554-
case MVT::i32:
1555-
Opcode = NVPTX::StoreRetvalI8TruncI32;
1556-
break;
1557-
case MVT::i64:
1558-
Opcode = NVPTX::StoreRetvalI8TruncI64;
1559-
break;
1560-
}
1561-
}
1562-
break;
1563-
case 2:
1564-
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
1565-
NVPTX::StoreRetvalV2I8, NVPTX::StoreRetvalV2I16,
1566-
NVPTX::StoreRetvalV2I32, NVPTX::StoreRetvalV2I64);
1567-
break;
1568-
case 4:
1569-
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
1570-
NVPTX::StoreRetvalV4I8, NVPTX::StoreRetvalV4I16,
1571-
NVPTX::StoreRetvalV4I32, {/* no v4i64 */});
1572-
break;
1573-
}
1574-
if (!Opcode)
1575-
return false;
1576-
1577-
SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, MVT::Other, Ops);
1578-
MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
1579-
CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef});
1580-
1581-
ReplaceNode(N, Ret);
1582-
return true;
1583-
}
1584-
15851501
// Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri)
15861502
#define getOpcV2H(ty, opKind0, opKind1) \
15871503
NVPTX::StoreParamV2##ty##_##opKind0##opKind1

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
7979
bool tryStore(SDNode *N);
8080
bool tryStoreVector(SDNode *N);
8181
bool tryLoadParam(SDNode *N);
82-
bool tryStoreRetval(SDNode *N);
8382
bool tryStoreParam(SDNode *N);
8483
bool tryFence(SDNode *N);
8584
void SelectAddrSpaceCast(SDNode *N);

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 51 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,9 +1065,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10651065
MAKE_CASE(NVPTXISD::StoreParamV2)
10661066
MAKE_CASE(NVPTXISD::StoreParamV4)
10671067
MAKE_CASE(NVPTXISD::MoveParam)
1068-
MAKE_CASE(NVPTXISD::StoreRetval)
1069-
MAKE_CASE(NVPTXISD::StoreRetvalV2)
1070-
MAKE_CASE(NVPTXISD::StoreRetvalV4)
10711068
MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
10721069
MAKE_CASE(NVPTXISD::BUILD_VECTOR)
10731070
MAKE_CASE(NVPTXISD::CallPrototype)
@@ -1438,7 +1435,11 @@ static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
14381435
}
14391436

14401437
static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
1441-
return Flags.isSExt() ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
1438+
if (Flags.isSExt())
1439+
return ISD::SIGN_EXTEND;
1440+
if (Flags.isZExt())
1441+
return ISD::ZERO_EXTEND;
1442+
return ISD::ANY_EXTEND;
14421443
}
14431444

14441445
SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
@@ -3385,9 +3386,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
33853386
const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
33863387
unsigned I = 0;
33873388
for (const unsigned NumElts : VectorInfo) {
3388-
const EVT EltVT = VTs[I];
33893389
// i1 is loaded/stored as i8
3390-
const EVT LoadVT = EltVT == MVT::i1 ? MVT::i8 : EltVT;
3390+
const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
33913391
// If the element is a packed type (ex. v2f16, v4i8, etc) holding
33923392
// multiple elements.
33933393
const unsigned PackingAmt =
@@ -3408,17 +3408,16 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
34083408
if (P.getNode())
34093409
P.getNode()->setIROrder(Arg.getArgNo() + 1);
34103410
for (const unsigned J : llvm::seq(NumElts)) {
3411-
SDValue Elt = DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
3412-
: ISD::EXTRACT_VECTOR_ELT,
3413-
dl, LoadVT, P,
3414-
DAG.getVectorIdxConstant(J * PackingAmt, dl));
3411+
SDValue Elt = DAG.getNode(
3412+
LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
3413+
: ISD::EXTRACT_VECTOR_ELT,
3414+
dl, LoadVT, P, DAG.getVectorIdxConstant(J * PackingAmt, dl));
34153415

34163416
// Extend or truncate the element if necessary (e.g. an i8 is loaded
34173417
// into an i16 register)
34183418
const EVT ExpactedVT = ArgIns[I + J].VT;
34193419
assert((Elt.getValueType() == ExpactedVT ||
3420-
(ExpactedVT.isInteger() &&
3421-
Elt.getValueType().isInteger())) &&
3420+
(ExpactedVT.isInteger() && Elt.getValueType().isInteger())) &&
34223421
"Non-integer argument type size mismatch");
34233422
if (ExpactedVT.bitsGT(Elt.getValueType()))
34243423
Elt = DAG.getNode(getExtOpcode(ArgIns[I + J].Flags), dl, ExpactedVT,
@@ -3438,33 +3437,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
34383437
return Chain;
34393438
}
34403439

3441-
// Use byte-store when the param address of the return value is unaligned.
3442-
// This may happen when the return value is a field of a packed structure.
3443-
static SDValue LowerUnalignedStoreRet(SelectionDAG &DAG, SDValue Chain,
3444-
uint64_t Offset, EVT ElementType,
3445-
SDValue RetVal, const SDLoc &dl) {
3446-
// Bit logic only works on integer types
3447-
if (adjustElementType(ElementType))
3448-
RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
3449-
3450-
// Store each byte
3451-
for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
3452-
// Shift the byte to the last byte position
3453-
SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, RetVal,
3454-
DAG.getConstant(i * 8, dl, MVT::i32));
3455-
SDValue StoreOperands[] = {Chain, DAG.getConstant(Offset + i, dl, MVT::i32),
3456-
ShiftVal};
3457-
// Trunc store only the last byte by using
3458-
// st.param.b8
3459-
// The register type can be larger than b8.
3460-
Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
3461-
DAG.getVTList(MVT::Other), StoreOperands,
3462-
MVT::i8, MachinePointerInfo(), std::nullopt,
3463-
MachineMemOperand::MOStore);
3464-
}
3465-
return Chain;
3466-
}
3467-
34683440
SDValue
34693441
NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
34703442
bool isVarArg,
@@ -3486,10 +3458,6 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
34863458
ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
34873459
assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
34883460

3489-
for (const unsigned I : llvm::seq(VTs.size()))
3490-
if (const auto PromotedVT = PromoteScalarIntegerPTX(VTs[I]))
3491-
VTs[I] = *PromotedVT;
3492-
34933461
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
34943462
// 32-bits are sign extended or zero extended, depending on whether
34953463
// they are signed or unsigned types.
@@ -3501,12 +3469,20 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
35013469
assert(!PromoteScalarIntegerPTX(RetVal.getValueType()) &&
35023470
"OutVal type should always be legal");
35033471

3504-
if (ExtendIntegerRetVal) {
3505-
RetVal = DAG.getNode(getExtOpcode(Outs[I].Flags), dl, MVT::i32, RetVal);
3506-
} else if (RetVal.getValueSizeInBits() < 16) {
3507-
// Use 16-bit registers for small load-stores as it's the
3508-
// smallest general purpose register size supported by NVPTX.
3509-
RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
3472+
EVT VTI = VTs[I];
3473+
if (const auto PromotedVT = PromoteScalarIntegerPTX(VTI))
3474+
VTI = *PromotedVT;
3475+
3476+
const EVT StoreVT =
3477+
ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
3478+
3479+
assert((RetVal.getValueType() == StoreVT ||
3480+
(StoreVT.isInteger() && RetVal.getValueType().isInteger())) &&
3481+
"Non-integer argument type size mismatch");
3482+
if (StoreVT.bitsGT(RetVal.getValueType())) {
3483+
RetVal = DAG.getNode(getExtOpcode(Outs[I].Flags), dl, StoreVT, RetVal);
3484+
} else if (StoreVT.bitsLT(RetVal.getValueType())) {
3485+
RetVal = DAG.getNode(ISD::TRUNCATE, dl, StoreVT, RetVal);
35103486
}
35113487
return RetVal;
35123488
};
@@ -3515,45 +3491,36 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
35153491
const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
35163492
unsigned I = 0;
35173493
for (const unsigned NumElts : VectorInfo) {
3518-
const Align CurrentAlign = commonAlignment(RetAlign, Offsets[I]);
3519-
if (NumElts == 1 && RetTy->isAggregateType() &&
3520-
CurrentAlign < DAG.getEVTAlign(VTs[I])) {
3521-
Chain = LowerUnalignedStoreRet(DAG, Chain, Offsets[I], VTs[I],
3522-
GetRetVal(I), dl);
3523-
3524-
// The call to LowerUnalignedStoreRet inserted the necessary SDAG nodes
3525-
// into the graph, so just move on to the next element.
3526-
I++;
3527-
continue;
3528-
}
3529-
3530-
SmallVector<SDValue, 6> StoreOperands{
3531-
Chain, DAG.getConstant(Offsets[I], dl, MVT::i32)};
3494+
const MaybeAlign CurrentAlign = ExtendIntegerRetVal
3495+
? MaybeAlign(std::nullopt)
3496+
: commonAlignment(RetAlign, Offsets[I]);
35323497

3533-
for (const unsigned J : llvm::seq(NumElts))
3534-
StoreOperands.push_back(GetRetVal(I + J));
3498+
SDValue Val;
3499+
if (NumElts == 1) {
3500+
Val = GetRetVal(I);
3501+
} else {
3502+
SmallVector<SDValue, 6> StoreVals;
3503+
for (const unsigned J : llvm::seq(NumElts)) {
3504+
SDValue ValJ = GetRetVal(I + J);
3505+
if (ValJ.getValueType().isVector())
3506+
DAG.ExtractVectorElements(ValJ, StoreVals);
3507+
else
3508+
StoreVals.push_back(ValJ);
3509+
}
35353510

3536-
NVPTXISD::NodeType Op;
3537-
switch (NumElts) {
3538-
case 1:
3539-
Op = NVPTXISD::StoreRetval;
3540-
break;
3541-
case 2:
3542-
Op = NVPTXISD::StoreRetvalV2;
3543-
break;
3544-
case 4:
3545-
Op = NVPTXISD::StoreRetvalV4;
3546-
break;
3547-
default:
3548-
llvm_unreachable("Invalid vector info.");
3511+
EVT VT = EVT::getVectorVT(F.getContext(), StoreVals[0].getValueType(),
3512+
StoreVals.size());
3513+
Val = DAG.getBuildVector(VT, dl, StoreVals);
35493514
}
35503515

3551-
// Adjust type of load/store op if we've extended the scalar
3552-
// return value.
3553-
EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[I];
3554-
Chain = DAG.getMemIntrinsicNode(
3555-
Op, dl, DAG.getVTList(MVT::Other), StoreOperands, TheStoreType,
3556-
MachinePointerInfo(), CurrentAlign, MachineMemOperand::MOStore);
3516+
SDValue RetSymbol =
3517+
DAG.getNode(NVPTXISD::Wrapper, dl, MVT::i32,
3518+
DAG.getTargetExternalSymbol("func_retval0", MVT::i32));
3519+
SDValue Ptr =
3520+
DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
3521+
3522+
Chain = DAG.getStore(Chain, dl, Val, Ptr,
3523+
MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
35573524

35583525
I += NumElts;
35593526
}
@@ -5109,19 +5076,12 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
51095076
case NVPTXISD::StoreParamV2:
51105077
Opcode = NVPTXISD::StoreParamV4;
51115078
break;
5112-
case NVPTXISD::StoreRetval:
5113-
Opcode = NVPTXISD::StoreRetvalV2;
5114-
break;
5115-
case NVPTXISD::StoreRetvalV2:
5116-
Opcode = NVPTXISD::StoreRetvalV4;
5117-
break;
51185079
case NVPTXISD::StoreV2:
51195080
MemVT = ST->getMemoryVT();
51205081
Opcode = NVPTXISD::StoreV4;
51215082
break;
51225083
case NVPTXISD::StoreV4:
51235084
case NVPTXISD::StoreParamV4:
5124-
case NVPTXISD::StoreRetvalV4:
51255085
case NVPTXISD::StoreV8:
51265086
// PTX doesn't support the next doubling of operands
51275087
return SDValue();
@@ -5190,12 +5150,6 @@ static SDValue PerformStoreParamCombine(SDNode *N,
51905150
return PerformStoreCombineHelper(N, DCI, 3, 1);
51915151
}
51925152

5193-
static SDValue PerformStoreRetvalCombine(SDNode *N,
5194-
TargetLowering::DAGCombinerInfo &DCI) {
5195-
// Operands from the 2nd to the last one are the values to be stored
5196-
return PerformStoreCombineHelper(N, DCI, 2, 0);
5197-
}
5198-
51995153
/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
52005154
///
52015155
static SDValue PerformADDCombine(SDNode *N,
@@ -5829,10 +5783,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
58295783
case NVPTXISD::LoadV2:
58305784
case NVPTXISD::LoadV4:
58315785
return combineUnpackingMovIntoLoad(N, DCI);
5832-
case NVPTXISD::StoreRetval:
5833-
case NVPTXISD::StoreRetvalV2:
5834-
case NVPTXISD::StoreRetvalV4:
5835-
return PerformStoreRetvalCombine(N, DCI);
58365786
case NVPTXISD::StoreParam:
58375787
case NVPTXISD::StoreParamV2:
58385788
case NVPTXISD::StoreParamV4:

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,7 @@ enum NodeType : unsigned {
8686
StoreParam,
8787
StoreParamV2,
8888
StoreParamV4,
89-
StoreRetval,
90-
StoreRetvalV2,
91-
StoreRetvalV4,
92-
LAST_MEMORY_OPCODE = StoreRetvalV4,
89+
LAST_MEMORY_OPCODE = StoreParamV4,
9390
};
9491
}
9592

0 commit comments

Comments
 (0)