Skip to content
Closed
16 changes: 13 additions & 3 deletions llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1588,12 +1588,22 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {

if (isKernelFunc) {
if (PTy) {
// Special handling for pointer arguments to kernel
O << "\t.param .u" << PTySizeInBits << " ";

if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
int addrSpace = PTy->getAddressSpace();
if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() ==
NVPTX::CUDA) {
int addrSpace = PTy->getAddressSpace();
// Special handling for pointer arguments to kernel
// CUDA kernels assume that pointers are in global address space
// See:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
assert(addrSpace == 0 && "Invalid address space");
O << ".ptr .global ";
if (I->getParamAlign().valueOrOne() != 1) {
Align ParamAlign = I->getParamAlign().value();
O << ".align " << ParamAlign.value() << " ";
}
} else {
switch (addrSpace) {
default:
O << ".ptr ";
Expand Down
36 changes: 36 additions & 0 deletions llvm/test/CodeGen/NVPTX/kernel-param-align.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
; RUN: llc < %s -march=nvptx64 -mcpu=sm_72 2>&1 | FileCheck %s
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_72 | %ptxas-verify %}

; Aggregate of 16 doubles; used below only as the source element type of the
; getelementptr instructions that index into the kernels' input buffers.
%struct.Large = type { [16 x double] }

; Kernel whose first pointer parameter carries an explicit `align 16`.
; The alignment must show up as `.align 16` on the PTX .param declaration,
; while the second pointer parameter (no alignment attribute) gets none.
; Both pointers are in the generic address space (0), which is what the
; CUDA kernel-parameter path expects before tagging them `.ptr .global`.
;
; The label is anchored on `.entry <name>(` so that it cannot match the
; parameter names (func_align_param_N) or the other kernel in this file.
; CHECK-LABEL: .entry func_align(
; CHECK: .param .u64 .ptr .global .align 16 func_align_param_0,
; CHECK: .param .u64 .ptr .global func_align_param_1,
; CHECK: .param .u32 func_align_param_2
define void @func_align(ptr nocapture readonly align 16 %input, ptr nocapture %out, i32 %n) {
entry:
  ; Cast the generic kernel pointers to the global address space before use.
  %0 = addrspacecast ptr %out to ptr addrspace(1)
  %1 = addrspacecast ptr %input to ptr addrspace(1)
  %getElem = getelementptr inbounds %struct.Large, ptr addrspace(1) %1, i64 0, i32 0, i64 5
  %tmp2 = load i32, ptr addrspace(1) %getElem, align 8
  store i32 %tmp2, ptr addrspace(1) %0, align 4
  ret void
}

; Kernel whose pointer parameters carry no explicit alignment: each is
; expected to be declared as `.param .u64 .ptr .global <name>` with no
; `.align` qualifier. The PTX order is type first, then the `.ptr`
; attribute and its state space.
;
; The label is anchored on `.entry func(` because a bare `func` is a
; substring of `func_align` and of every `func_align_param_N`, which would
; let the label match inside the previous kernel's output.
; CHECK-LABEL: .entry func(
; CHECK: .param .u64 .ptr .global func_param_0,
; CHECK: .param .u64 .ptr .global func_param_1,
; CHECK: .param .u32 func_param_2
define void @func(ptr nocapture readonly %input, ptr nocapture %out, i32 %n) {
entry:
  ; Cast the generic kernel pointers to the global address space before use.
  %0 = addrspacecast ptr %out to ptr addrspace(1)
  %1 = addrspacecast ptr %input to ptr addrspace(1)
  %getElem = getelementptr inbounds %struct.Large, ptr addrspace(1) %1, i64 0, i32 0, i64 5
  %tmp2 = load i32, ptr addrspace(1) %getElem, align 8
  store i32 %tmp2, ptr addrspace(1) %0, align 4
  ret void
}

; NVVM annotations tagging @func_align and @func with "kernel" = 1.
; NOTE(review): presumably this is what routes both functions through the
; kernel-parameter path of the NVPTX asm printer — confirm against the backend.
!nvvm.annotations = !{!0, !1}
!0 = !{ptr @func_align, !"kernel", i32 1}
!1 = !{ptr @func, !"kernel", i32 1}