diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp index 9f31b72bbceb1..486ec76f60ec3 100644 --- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp @@ -1588,12 +1588,21 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) { if (isKernelFunc) { if (PTy) { - // Special handling for pointer arguments to kernel O << "\t.param .u" << PTySizeInBits << " "; - if (static_cast(TM).getDrvInterface() != + int addrSpace = PTy->getAddressSpace(); + if (static_cast(TM).getDrvInterface() == NVPTX::CUDA) { - int addrSpace = PTy->getAddressSpace(); + // Special handling for pointer arguments to kernel + // CUDA kernels assume that pointers are in global address space + // See: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space + O << ".ptr .global "; + if (I->getParamAlign().valueOrOne() != 1) { + Align ParamAlign = I->getParamAlign().value(); + O << ".align " << ParamAlign.value() << " "; + } + } else { switch (addrSpace) { default: O << ".ptr "; diff --git a/llvm/test/CodeGen/NVPTX/i1-param.ll b/llvm/test/CodeGen/NVPTX/i1-param.ll index 375752b619a58..3673ee7c77a13 100644 --- a/llvm/test/CodeGen/NVPTX/i1-param.ll +++ b/llvm/test/CodeGen/NVPTX/i1-param.ll @@ -8,7 +8,7 @@ target triple = "nvptx-nvidia-cuda" ; CHECK: .entry foo ; CHECK: .param .u8 foo_param_0 -; CHECK: .param .u64 foo_param_1 +; CHECK: .param .u64 .ptr .global foo_param_1 define void @foo(i1 %p, ptr %out) { %val = zext i1 %p to i32 store i32 %val, ptr %out diff --git a/llvm/test/CodeGen/NVPTX/kernel-param-align.ll b/llvm/test/CodeGen/NVPTX/kernel-param-align.ll new file mode 100644 index 0000000000000..3350b4fcda83e --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/kernel-param-align.ll @@ -0,0 +1,36 @@ +; RUN: llc < %s -march=nvptx64 -mcpu=sm_72 2>&1 | FileCheck %s +; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_72 | %ptxas-verify %} + +%struct.Large = type { [16 x double] } + +; CHECK-LABEL: func_align +; CHECK: .param .u64 .ptr .global .align 16 func_align_param_0, +; CHECK: .param .u64 .ptr .global func_align_param_1, +; CHECK: .param .u32 .ptr .global func_align_param_2 +define void @func_align(ptr nocapture readonly align 16 %input, ptr nocapture %out, ptr addrspace(3) %n) { +entry: + %0 = addrspacecast ptr %out to ptr addrspace(1) + %1 = addrspacecast ptr %input to ptr addrspace(1) + %getElem = getelementptr inbounds %struct.Large, ptr addrspace(1) %1, i64 0, i32 0, i64 5 + %tmp2 = load i32, ptr addrspace(1) %getElem, align 8 + store i32 %tmp2, ptr addrspace(1) %0, align 4 + ret void +} + +; CHECK-LABEL: func +; CHECK: .param .u64 .ptr .global func_param_0, +; CHECK: .param .u64 .ptr .global func_param_1, +; CHECK: .param .u32 func_param_2 +define void @func(ptr nocapture readonly %input, ptr nocapture %out, i32 %n) { +entry: + %0 = addrspacecast ptr %out to ptr addrspace(1) + %1 = addrspacecast ptr %input to ptr addrspace(1) + %getElem = getelementptr inbounds %struct.Large, ptr addrspace(1) %1, i64 0, i32 0, i64 5 + %tmp2 = load i32, ptr addrspace(1) %getElem, align 8 + store i32 %tmp2, ptr addrspace(1) %0, align 4 + ret void +} + +!nvvm.annotations = !{!0, !1} +!0 = !{ptr @func_align, !"kernel", i32 1} +!1 = !{ptr @func, !"kernel", i32 1} \ No newline at end of file