diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 01abf9591e342..0f1cbcc9c2d43 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -53,6 +53,7 @@
 #include "llvm/Support/CodeGen.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/NVPTXAddrSpace.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetOptions.h"
@@ -667,8 +668,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
   setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);
 
-  setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i32, Custom);
-  setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i64, Custom);
+  setOperationAction(ISD::DYNAMIC_STACKALLOC, {MVT::i32, MVT::i64}, Custom);
+  setOperationAction({ISD::STACKRESTORE, ISD::STACKSAVE}, MVT::Other, Custom);
 
   // TRAP can be lowered to PTX trap
   setOperationAction(ISD::TRAP, MVT::Other, Legal);
@@ -961,6 +962,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
+    MAKE_CASE(NVPTXISD::STACKRESTORE)
+    MAKE_CASE(NVPTXISD::STACKSAVE)
     MAKE_CASE(NVPTXISD::SETP_F16X2)
     MAKE_CASE(NVPTXISD::SETP_BF16X2)
     MAKE_CASE(NVPTXISD::Dummy)
@@ -2287,6 +2290,63 @@ SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
   return DAG.getNode(NVPTXISD::DYNAMIC_STACKALLOC, DL, RetTypes, AllocOps);
 }
 
+// Custom lowering for ISD::STACKRESTORE. The pointer operand is cast from the
+// generic to the local address space before the target node is built.
+SDValue NVPTXTargetLowering::LowerSTACKRESTORE(SDValue Op,
+                                               SelectionDAG &DAG) const {
+  SDLoc DL(Op.getNode());
+  SDValue Chain = Op.getOperand(0);
+
+  const bool IsSupported =
+      STI.getPTXVersion() >= 73 && STI.getSmVersion() >= 52;
+  if (!IsSupported) {
+    // Emit a diagnostic and lower to nothing but the incoming chain.
+    DiagnosticInfoUnsupported NoStackRestore(
+        DAG.getMachineFunction().getFunction(),
+        "Support for stackrestore requires PTX ISA version >= 7.3 and target "
+        ">= sm_52.",
+        DL.getDebugLoc());
+    DAG.getContext()->diagnose(NoStackRestore);
+    return Chain;
+  }
+
+  const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
+  SDValue LocalPtr =
+      DAG.getAddrSpaceCast(DL, LocalVT, Op.getOperand(1),
+                           ADDRESS_SPACE_GENERIC, ADDRESS_SPACE_LOCAL);
+  return DAG.getNode(NVPTXISD::STACKRESTORE, DL, MVT::Other,
+                     {Chain, LocalPtr});
+}
+
+// Custom lowering for ISD::STACKSAVE. The target node yields a value typed as
+// a local-address-space pointer, which is cast to generic for the result.
+SDValue NVPTXTargetLowering::LowerSTACKSAVE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  SDLoc DL(Op.getNode());
+  SDValue Chain = Op.getOperand(0);
+  const EVT ResultVT = Op.getValueType();
+
+  const bool IsSupported =
+      STI.getPTXVersion() >= 73 && STI.getSmVersion() >= 52;
+  if (!IsSupported) {
+    // Emit a diagnostic and produce a zero value plus the incoming chain.
+    DiagnosticInfoUnsupported NoStackSave(
+        DAG.getMachineFunction().getFunction(),
+        "Support for stacksave requires PTX ISA version >= 7.3 and target >= "
+        "sm_52.",
+        DL.getDebugLoc());
+    DAG.getContext()->diagnose(NoStackSave);
+    return DAG.getMergeValues({DAG.getConstant(0, DL, ResultVT), Chain}, DL);
+  }
+
+  const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
+  SDValue Saved =
+      DAG.getNode(NVPTXISD::STACKSAVE, DL, {LocalVT, MVT::Other}, Chain);
+  SDValue GenericPtr = DAG.getAddrSpaceCast(
+      DL, ResultVT, Saved, ADDRESS_SPACE_LOCAL, ADDRESS_SPACE_GENERIC);
+  return DAG.getMergeValues({GenericPtr, Saved.getValue(1)}, DL);
+}
+
 // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
 // (see LegalizeDAG.cpp). This is slow and uses local memory.
// We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
@@ -2871,6 +2922,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVectorArith(Op, DAG);
   case ISD::DYNAMIC_STACKALLOC:
     return LowerDYNAMIC_STACKALLOC(Op, DAG);
+  case ISD::STACKRESTORE: // marked Custom in the constructor
+    return LowerSTACKRESTORE(Op, DAG);
+  case ISD::STACKSAVE: // marked Custom in the constructor
+    return LowerSTACKSAVE(Op, DAG);
   case ISD::CopyToReg:
     return LowerCopyToReg_128(Op, DAG);
   default:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 824a659671967..ead9ca4a311ae 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -63,6 +63,8 @@ enum NodeType : unsigned {
   PRMT,
   FCOPYSIGN,
   DYNAMIC_STACKALLOC,
+  STACKRESTORE, // node built by LowerSTACKRESTORE
+  STACKSAVE,    // node built by LowerSTACKSAVE
   BrxStart,
   BrxItem,
   BrxEnd,
@@ -526,6 +528,8 @@ class NVPTXTargetLowering : public TargetLowering {
                     SmallVectorImpl &InVals) const override;
 
   SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerSTACKRESTORE(SDValue Op, SelectionDAG &DAG) const;
 
   std::string getPrototype(const DataLayout &DL, Type *, const ArgListTy &,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 1ca3aefb0b093..2658ca3271637 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3860,6 +3860,44 @@ foreach a_type = ["s", "u"] in {
   }
 }
 
+// Stack Manipulation
+// SDNodes and 32-/64-bit instructions matching NVPTXISD::STACKRESTORE and
+// NVPTXISD::STACKSAVE; each instruction is gated on hasPTX<73> and hasSM<52>.
+
+def SDTStackRestore : SDTypeProfile<0, 1, [SDTCisInt<0>]>; // no result, 1 int op
+
+def stackrestore : // chained, side-effecting node taking the pointer operand
+  SDNode<"NVPTXISD::STACKRESTORE", SDTStackRestore,
+         [SDNPHasChain, SDNPSideEffect]>;
+
+def stacksave : // chained, side-effecting node with no value operands
+  SDNode<"NVPTXISD::STACKSAVE", SDTIntLeaf,
+         [SDNPHasChain, SDNPSideEffect]>;
+
+def STACKRESTORE_32 : // 32-bit form; pointer comes in an Int32Regs register
+  NVPTXInst<(outs), (ins Int32Regs:$ptr),
+            "stackrestore.u32 \t$ptr;",
+            [(stackrestore (i32 Int32Regs:$ptr))]>,
+            Requires<[hasPTX<73>, hasSM<52>]>;
+
+def STACKSAVE_32 : // 32-bit form; result is written to an Int32Regs register
+  NVPTXInst<(outs Int32Regs:$dst), (ins),
+            "stacksave.u32 \t$dst;",
+            [(set Int32Regs:$dst, (i32 stacksave))]>,
+            Requires<[hasPTX<73>, hasSM<52>]>;
+
+def STACKRESTORE_64 : // 64-bit form; pointer comes in an Int64Regs register
+  NVPTXInst<(outs), (ins Int64Regs:$ptr),
+            "stackrestore.u64 \t$ptr;",
+            [(stackrestore (i64 Int64Regs:$ptr))]>,
+            Requires<[hasPTX<73>, hasSM<52>]>;
+
+def STACKSAVE_64 : // 64-bit form; result is written to an Int64Regs register
+  NVPTXInst<(outs Int64Regs:$dst), (ins),
+            "stacksave.u64 \t$dst;",
+            [(set Int64Regs:$dst, (i64 stacksave))]>,
+            Requires<[hasPTX<73>, hasSM<52>]>;
+
 include "NVPTXIntrinsics.td"
 
 //-----------------------------------
diff --git a/llvm/test/CodeGen/NVPTX/stacksaverestore.ll b/llvm/test/CodeGen/NVPTX/stacksaverestore.ll
new file mode 100644
index 0000000000000..f5a057fcb483c
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/stacksaverestore.ll
@@ -0,0 +1,83 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx -mcpu=sm_60 -mattr=+ptx73 | FileCheck %s --check-prefix=CHECK-32
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_60 -mattr=+ptx73 | FileCheck %s --check-prefix=CHECK-64
+; RUN: llc < %s -march=nvptx64 -nvptx-short-ptr -mcpu=sm_60 -mattr=+ptx73 | FileCheck %s --check-prefix=CHECK-MIXED
+; RUN: %if ptxas && ptxas-12.0 %{ llc < %s -march=nvptx64 -mcpu=sm_60 -mattr=+ptx73 | %ptxas-verify %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+define ptr @test_save() {
+; CHECK-32-LABEL: test_save(
+; CHECK-32: {
+; CHECK-32-NEXT: .reg .b32 %r<3>;
+; CHECK-32-EMPTY:
+; CHECK-32-NEXT: // %bb.0:
+; CHECK-32-NEXT: stacksave.u32 %r1;
+; CHECK-32-NEXT: cvta.local.u32 %r2, %r1;
+; CHECK-32-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-32-NEXT: ret;
+;
+; CHECK-64-LABEL: test_save(
+; CHECK-64: {
+; CHECK-64-NEXT: .reg .b64 %rd<3>;
+; CHECK-64-EMPTY:
+; CHECK-64-NEXT: // %bb.0:
+; CHECK-64-NEXT: stacksave.u64 %rd1;
+; CHECK-64-NEXT: cvta.local.u64 %rd2, %rd1;
+; CHECK-64-NEXT: st.param.b64 [func_retval0], %rd2;
+; CHECK-64-NEXT: ret;
+;
+; CHECK-MIXED-LABEL: test_save(
+; CHECK-MIXED: {
+; CHECK-MIXED-NEXT: .reg .b32 %r<2>;
+; CHECK-MIXED-NEXT: .reg .b64 %rd<3>;
+; CHECK-MIXED-EMPTY:
+; CHECK-MIXED-NEXT: // %bb.0:
+; CHECK-MIXED-NEXT: stacksave.u32 %r1;
+; CHECK-MIXED-NEXT: cvt.u64.u32 %rd1, %r1;
+; CHECK-MIXED-NEXT: cvta.local.u64 %rd2, %rd1;
+; CHECK-MIXED-NEXT: st.param.b64 [func_retval0], %rd2;
+; CHECK-MIXED-NEXT: ret;
+  %1 = call ptr @llvm.stacksave()
+  ret ptr %1
+}
+
+
+define void @test_restore(ptr %p) {
+; CHECK-32-LABEL: test_restore(
+; CHECK-32: {
+; CHECK-32-NEXT: .reg .b32 %r<3>;
+; CHECK-32-EMPTY:
+; CHECK-32-NEXT: // %bb.0:
+; CHECK-32-NEXT: ld.param.u32 %r1, [test_restore_param_0];
+; CHECK-32-NEXT: cvta.to.local.u32 %r2, %r1;
+; CHECK-32-NEXT: stackrestore.u32 %r2;
+; CHECK-32-NEXT: ret;
+;
+; CHECK-64-LABEL: test_restore(
+; CHECK-64: {
+; CHECK-64-NEXT: .reg .b64 %rd<3>;
+; CHECK-64-EMPTY:
+; CHECK-64-NEXT: // %bb.0:
+; CHECK-64-NEXT: ld.param.u64 %rd1, [test_restore_param_0];
+; CHECK-64-NEXT: cvta.to.local.u64 %rd2, %rd1;
+; CHECK-64-NEXT: stackrestore.u64 %rd2;
+; CHECK-64-NEXT: ret;
+;
+; CHECK-MIXED-LABEL: test_restore(
+; CHECK-MIXED: {
+; CHECK-MIXED-NEXT: .reg .b32 %r<2>;
+; CHECK-MIXED-NEXT: .reg .b64 %rd<3>;
+; CHECK-MIXED-EMPTY:
+; CHECK-MIXED-NEXT: // %bb.0:
+; CHECK-MIXED-NEXT: ld.param.u64 %rd1, [test_restore_param_0];
+; CHECK-MIXED-NEXT: cvta.to.local.u64 %rd2, %rd1;
+; CHECK-MIXED-NEXT: cvt.u32.u64 %r1, %rd2;
+; CHECK-MIXED-NEXT: stackrestore.u32 %r1;
+; CHECK-MIXED-NEXT: ret;
+  call void @llvm.stackrestore(ptr %p)
+  ret void
+}
+
+declare ptr @llvm.stacksave()
+declare void @llvm.stackrestore(ptr)