Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions clang/test/CodeGenCUDA/bf16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,7 @@ __device__ __bf16 test_call( __bf16 in) {
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z9test_callDF16b_param_0];
// CHECK: st.param.b16 [param0], %[[R]];
// CHECK: .param .align 2 .b8 retval0[2];
// CHECK: call.uni (retval0),
// CHECK-NEXT: _Z13external_funcDF16b,
// CHECK-NEXT: (
// CHECK-NEXT: param0
// CHECK-NEXT );
// CHECK: call.uni (retval0), _Z13external_funcDF16b, (param0);
// CHECK: ld.param.b16 %[[RET:rs[0-9]+]], [retval0];
return external_func(in);
// CHECK: st.param.b16 [func_retval0], %[[RET]]
Expand Down
22 changes: 22 additions & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,25 @@ void NVPTXInstPrinter::printCTAGroup(const MCInst *MI, int OpNum,
}
llvm_unreachable("Invalid cta_group in printCTAGroup");
}

void NVPTXInstPrinter::printCallOperand(const MCInst *MI, int OpNum,
raw_ostream &O, StringRef Modifier) {
const MCOperand &MO = MI->getOperand(OpNum);
assert(MO.isImm() && "Invalid operand");
const auto Imm = MO.getImm();

if (Modifier == "RetList") {
assert((Imm == 1 || Imm == 0) && "Invalid return list");
if (Imm)
O << " (retval0),";
return;
}

if (Modifier == "ParamList") {
assert(Imm >= 0 && "Invalid parameter list");
interleaveComma(llvm::seq(Imm), O,
[&](const auto &I) { O << "param" << I; });
return;
}
Comment on lines +474 to +479
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know if PTX imposes practical limits on the line length?
Some JIT compilations may end up generating functions with a very long list of arguments.
We may want to try folding them, or print them one per line if the total number of arguments is above certain threshold.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know of a practical limit. I know we've had some very long lines with initialization of static arrays and with debug info in the past and as far as I know it hasn't been a problem.

Should I prophylactically add some newlines or do you think it makes sense to wait and see?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep it as is. JIT-produced PTX is rarely read by anyone other than ptxas. For human use current form is fine.

llvm_unreachable("Invalid modifier");
}
2 changes: 2 additions & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
void printTmaReductionMode(const MCInst *MI, int OpNum, raw_ostream &O);
void printCTAGroup(const MCInst *MI, int OpNum, raw_ostream &O);
void printCallOperand(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier = {});
};

}
Expand Down
72 changes: 8 additions & 64 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,9 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
case NVPTXISD::StoreParam:
case NVPTXISD::StoreParamV2:
case NVPTXISD::StoreParamV4:
case NVPTXISD::StoreParamS32:
case NVPTXISD::StoreParamU32:
if (tryStoreParam(N))
return;
break;
case ISD::INTRINSIC_WO_CHAIN:
if (tryIntrinsicNoChain(N))
return;
break;
case ISD::INTRINSIC_W_CHAIN:
if (tryIntrinsicChain(N))
return;
Expand Down Expand Up @@ -904,25 +898,6 @@ NVPTXDAGToDAGISel::insertMemoryInstructionFence(SDLoc DL, SDValue &Chain,
return {InstructionOrdering, Scope};
}

bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) {
unsigned IID = N->getConstantOperandVal(0);
switch (IID) {
default:
return false;
case Intrinsic::nvvm_texsurf_handle_internal:
SelectTexSurfHandle(N);
return true;
}
}

void NVPTXDAGToDAGISel::SelectTexSurfHandle(SDNode *N) {
// Op 0 is the intrinsic ID
SDValue Wrapper = N->getOperand(1);
SDValue GlobalVal = Wrapper.getOperand(0);
ReplaceNode(N, CurDAG->getMachineNode(NVPTX::texsurf_handles, SDLoc(N),
MVT::i64, GlobalVal));
}

void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
SDValue Src = N->getOperand(0);
AddrSpaceCastSDNode *CastN = cast<AddrSpaceCastSDNode>(N);
Expand Down Expand Up @@ -1717,8 +1692,6 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
switch (N->getOpcode()) {
default:
llvm_unreachable("Unexpected opcode");
case NVPTXISD::StoreParamU32:
case NVPTXISD::StoreParamS32:
case NVPTXISD::StoreParam:
NumElts = 1;
break;
Expand Down Expand Up @@ -1796,27 +1769,6 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
}
}
break;
// Special case: if we have a sign-extend/zero-extend node, insert the
// conversion instruction first, and use that as the value operand to
// the selected StoreParam node.
case NVPTXISD::StoreParamU32: {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AlexMaclean This has removed everything but the default case (which causes MSVC build warnings) - remove the outer switch(n->getOpcode())?

Opcode = NVPTX::StoreParamI32_r;
SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
MVT::i32);
SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u32_u16, DL,
MVT::i32, Ops[0], CvtNone);
Ops[0] = SDValue(Cvt, 0);
break;
}
case NVPTXISD::StoreParamS32: {
Opcode = NVPTX::StoreParamI32_r;
SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
MVT::i32);
SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_s32_s16, DL,
MVT::i32, Ops[0], CvtNone);
Ops[0] = SDValue(Cvt, 0);
break;
}
}

SDVTList RetVTs = CurDAG->getVTList(MVT::Other, MVT::Glue);
Expand Down Expand Up @@ -2105,22 +2057,14 @@ static inline bool isAddLike(const SDValue V) {
// selectBaseADDR - Match a dag node which will serve as the base address for an
// ADDR operand pair.
static SDValue selectBaseADDR(SDValue N, SelectionDAG *DAG) {
// Return true if TGA or ES.
if (N.getOpcode() == ISD::TargetGlobalAddress ||
N.getOpcode() == ISD::TargetExternalSymbol)
return N;

if (N.getOpcode() == NVPTXISD::Wrapper)
return N.getOperand(0);

// addrspacecast(Wrapper(arg_symbol) to addrspace(PARAM)) -> arg_symbol
if (AddrSpaceCastSDNode *CastN = dyn_cast<AddrSpaceCastSDNode>(N))
if (CastN->getSrcAddressSpace() == ADDRESS_SPACE_GENERIC &&
CastN->getDestAddressSpace() == ADDRESS_SPACE_PARAM &&
CastN->getOperand(0).getOpcode() == NVPTXISD::Wrapper)
return selectBaseADDR(CastN->getOperand(0).getOperand(0), DAG);

if (auto *FIN = dyn_cast<FrameIndexSDNode>(N))
if (const auto *GA = dyn_cast<GlobalAddressSDNode>(N))
return DAG->getTargetGlobalAddress(GA->getGlobal(), SDLoc(N),
GA->getValueType(0), GA->getOffset(),
GA->getTargetFlags());
if (const auto *ES = dyn_cast<ExternalSymbolSDNode>(N))
return DAG->getTargetExternalSymbol(ES->getSymbol(), ES->getValueType(0),
ES->getTargetFlags());
if (const auto *FIN = dyn_cast<FrameIndexSDNode>(N))
return DAG->getTargetFrameIndex(FIN->getIndex(), FIN->getValueType(0));

return N;
Expand Down
1 change: 0 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
#include "NVPTXGenDAGISel.inc"

void Select(SDNode *N) override;
bool tryIntrinsicNoChain(SDNode *N);
bool tryIntrinsicChain(SDNode *N);
bool tryIntrinsicVoid(SDNode *N);
void SelectTexSurfHandle(SDNode *N);
Expand Down
121 changes: 23 additions & 98 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -702,9 +702,6 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::BR_JT, MVT::Other, Custom);
setOperationAction(ISD::BRIND, MVT::Other, Expand);

setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);

// We want to legalize constant related memmove and memcopy
// intrinsics.
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
Expand Down Expand Up @@ -1055,45 +1052,24 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
case NVPTXISD::FIRST_NUMBER:
break;

MAKE_CASE(NVPTXISD::CALL)
MAKE_CASE(NVPTXISD::RET_GLUE)
MAKE_CASE(NVPTXISD::LOAD_PARAM)
MAKE_CASE(NVPTXISD::Wrapper)
MAKE_CASE(NVPTXISD::DeclareParam)
MAKE_CASE(NVPTXISD::DeclareScalarParam)
MAKE_CASE(NVPTXISD::DeclareRet)
MAKE_CASE(NVPTXISD::DeclareScalarRet)
MAKE_CASE(NVPTXISD::DeclareRetParam)
MAKE_CASE(NVPTXISD::PrintCall)
MAKE_CASE(NVPTXISD::PrintConvergentCall)
MAKE_CASE(NVPTXISD::PrintCallUni)
MAKE_CASE(NVPTXISD::PrintConvergentCallUni)
MAKE_CASE(NVPTXISD::CALL)
MAKE_CASE(NVPTXISD::LoadParam)
MAKE_CASE(NVPTXISD::LoadParamV2)
MAKE_CASE(NVPTXISD::LoadParamV4)
MAKE_CASE(NVPTXISD::StoreParam)
MAKE_CASE(NVPTXISD::StoreParamV2)
MAKE_CASE(NVPTXISD::StoreParamV4)
MAKE_CASE(NVPTXISD::StoreParamS32)
MAKE_CASE(NVPTXISD::StoreParamU32)
MAKE_CASE(NVPTXISD::CallArgBegin)
MAKE_CASE(NVPTXISD::CallArg)
MAKE_CASE(NVPTXISD::LastCallArg)
MAKE_CASE(NVPTXISD::CallArgEnd)
MAKE_CASE(NVPTXISD::CallVoid)
MAKE_CASE(NVPTXISD::CallVal)
MAKE_CASE(NVPTXISD::CallSymbol)
MAKE_CASE(NVPTXISD::Prototype)
MAKE_CASE(NVPTXISD::MoveParam)
MAKE_CASE(NVPTXISD::StoreRetval)
MAKE_CASE(NVPTXISD::StoreRetvalV2)
MAKE_CASE(NVPTXISD::StoreRetvalV4)
MAKE_CASE(NVPTXISD::PseudoUseParam)
MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
MAKE_CASE(NVPTXISD::BUILD_VECTOR)
MAKE_CASE(NVPTXISD::RETURN)
MAKE_CASE(NVPTXISD::CallSeqBegin)
MAKE_CASE(NVPTXISD::CallSeqEnd)
MAKE_CASE(NVPTXISD::CallPrototype)
MAKE_CASE(NVPTXISD::ProxyReg)
MAKE_CASE(NVPTXISD::LoadV2)
Expand All @@ -1115,7 +1091,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::STACKSAVE)
MAKE_CASE(NVPTXISD::SETP_F16X2)
MAKE_CASE(NVPTXISD::SETP_BF16X2)
MAKE_CASE(NVPTXISD::Dummy)
MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
MAKE_CASE(NVPTXISD::BrxEnd)
Expand Down Expand Up @@ -1189,15 +1164,6 @@ SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
}
}

SDValue
NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
SDLoc dl(Op);
const GlobalAddressSDNode *GAN = cast<GlobalAddressSDNode>(Op);
auto PtrVT = getPointerTy(DAG.getDataLayout(), GAN->getAddressSpace());
Op = DAG.getTargetGlobalAddress(GAN->getGlobal(), dl, PtrVT);
return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
}

std::string NVPTXTargetLowering::getPrototype(
const DataLayout &DL, Type *retTy, const ArgListTy &Args,
const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign RetAlign,
Expand Down Expand Up @@ -1601,9 +1567,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
? promoteScalarArgumentSize(TypeSize * 8)
: TypeSize * 8;

Chain = DAG.getNode(
NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
{Chain, GetI32(ArgI), GetI32(PromotedSize), GetI32(0), InGlue});
Chain =
DAG.getNode(NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
{Chain, GetI32(ArgI), GetI32(PromotedSize), InGlue});
}
InGlue = Chain.getValue(1);

Expand Down Expand Up @@ -1740,16 +1706,13 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const unsigned ResultSize = DL.getTypeAllocSizeInBits(RetTy);
if (!shouldPassAsArray(RetTy)) {
const unsigned PromotedResultSize = promoteScalarArgumentSize(ResultSize);
SDValue DeclareRetOps[] = {Chain, GetI32(1), GetI32(PromotedResultSize),
GetI32(0), InGlue};
Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, {MVT::Other, MVT::Glue},
DeclareRetOps);
{Chain, GetI32(PromotedResultSize), InGlue});
InGlue = Chain.getValue(1);
} else {
SDValue DeclareRetOps[] = {Chain, GetI32(RetAlign->value()),
GetI32(ResultSize / 8), GetI32(0), InGlue};
Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl,
{MVT::Other, MVT::Glue}, DeclareRetOps);
Chain = DAG.getNode(
NVPTXISD::DeclareRetParam, dl, {MVT::Other, MVT::Glue},
{Chain, GetI32(RetAlign->value()), GetI32(ResultSize / 8), InGlue});
InGlue = Chain.getValue(1);
}
}
Expand Down Expand Up @@ -1800,25 +1763,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, *CB,
UniqueCallSite);
const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
SDValue ProtoOps[] = {
Chain,
DAG.getTargetExternalSymbol(ProtoStr, MVT::i32),
InGlue,
};
Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, {MVT::Other, MVT::Glue},
ProtoOps);
Chain = DAG.getNode(
NVPTXISD::CallPrototype, dl, {MVT::Other, MVT::Glue},
{Chain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32), InGlue});
InGlue = Chain.getValue(1);
}
// Op to just print "call"
SDValue PrintCallOps[] = {Chain, GetI32(Ins.empty() ? 0 : 1), InGlue};
// We model convergent calls as separate opcodes.
unsigned Opcode =
IsIndirectCall ? NVPTXISD::PrintCall : NVPTXISD::PrintCallUni;
if (CLI.IsConvergent)
Opcode = Opcode == NVPTXISD::PrintCallUni ? NVPTXISD::PrintConvergentCallUni
: NVPTXISD::PrintConvergentCall;
Chain = DAG.getNode(Opcode, dl, {MVT::Other, MVT::Glue}, PrintCallOps);
InGlue = Chain.getValue(1);

if (ConvertToIndirectCall) {
// Copy the function ptr to a ptx register and use the register to call the
Expand All @@ -1832,38 +1781,17 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
}

// Ops to print out the function name
SDValue CallVoidOps[] = { Chain, Callee, InGlue };
Chain =
DAG.getNode(NVPTXISD::CallVoid, dl, {MVT::Other, MVT::Glue}, CallVoidOps);
InGlue = Chain.getValue(1);

// Ops to print out the param list
SDValue CallArgBeginOps[] = { Chain, InGlue };
Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, {MVT::Other, MVT::Glue},
CallArgBeginOps);
const unsigned Proto = IsIndirectCall ? UniqueCallSite : 0;
const unsigned NumArgs =
std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
/// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
/// NumParams, Callee, Proto, InGlue)
Chain = DAG.getNode(NVPTXISD::CALL, dl, {MVT::Other, MVT::Glue},
{Chain, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall),
GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee,
GetI32(Proto), InGlue});
InGlue = Chain.getValue(1);

const unsigned E = std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
for (const unsigned I : llvm::seq(E)) {
const unsigned Opcode =
I == (E - 1) ? NVPTXISD::LastCallArg : NVPTXISD::CallArg;
SDValue CallArgOps[] = {Chain, GetI32(1), GetI32(I), InGlue};
Chain = DAG.getNode(Opcode, dl, {MVT::Other, MVT::Glue}, CallArgOps);
InGlue = Chain.getValue(1);
}
SDValue CallArgEndOps[] = {Chain, GetI32(IsIndirectCall ? 0 : 1), InGlue};
Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, {MVT::Other, MVT::Glue},
CallArgEndOps);
InGlue = Chain.getValue(1);

if (IsIndirectCall) {
SDValue PrototypeOps[] = {Chain, GetI32(UniqueCallSite), InGlue};
Chain = DAG.getNode(NVPTXISD::Prototype, dl, {MVT::Other, MVT::Glue},
PrototypeOps);
InGlue = Chain.getValue(1);
}

SmallVector<SDValue, 16> ProxyRegOps;
// An item of the vector is filled if the element does not need a ProxyReg
// operation on it and should be added to InVals as is. ProxyRegOps and
Expand Down Expand Up @@ -2919,8 +2847,6 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return SDValue();
case ISD::ADDRSPACECAST:
return LowerADDRSPACECAST(Op, DAG);
case ISD::GlobalAddress:
return LowerGlobalAddress(Op, DAG);
case ISD::INTRINSIC_W_CHAIN:
return Op;
case ISD::INTRINSIC_WO_CHAIN:
Expand Down Expand Up @@ -3129,8 +3055,7 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
EVT PtrVT = TLI->getPointerTy(DAG.getDataLayout());

// Store the address of unsized array <function>_vararg[] in the ap object.
SDValue Arg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);
SDValue VAReg = DAG.getNode(NVPTXISD::Wrapper, DL, PtrVT, Arg);
SDValue VAReg = getParamSymbol(DAG, /* vararg */ -1, PtrVT);

const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
return DAG.getStore(Op.getOperand(0), DL, VAReg, Op.getOperand(1),
Expand Down Expand Up @@ -3370,7 +3295,7 @@ SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx,
EVT v) const {
StringRef SavedStr = nvTM->getStrPool().save(
getParamName(&DAG.getMachineFunction().getFunction(), idx));
return DAG.getTargetExternalSymbol(SavedStr.data(), v);
return DAG.getExternalSymbol(SavedStr.data(), v);
}

SDValue NVPTXTargetLowering::LowerFormalArguments(
Expand Down Expand Up @@ -3438,7 +3363,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(

SDValue P;
if (isKernelFunction(*F)) {
P = DAG.getNode(NVPTXISD::Wrapper, dl, ByvalIn.VT, ArgSymbol);
P = ArgSymbol;
P.getNode()->setIROrder(Arg.getArgNo() + 1);
} else {
P = DAG.getNode(NVPTXISD::MoveParam, dl, ByvalIn.VT, ArgSymbol);
Expand Down
Loading