Skip to content

Commit 608202f

Browse files
Vandana2896 authored and LewisCrawford committed
Update .global and .align
1 parent 1814093 commit 608202f

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1609,8 +1609,16 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
16091609
int addrSpace = PTy->getAddressSpace();
16101610
if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() ==
16111611
NVPTX::CUDA) {
1612+
// Special handling for pointer arguments to kernel
1613+
// CUDA kernels assume that pointers are in global address space
1614+
// See:
1615+
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
16121616
assert(addrSpace == 0 && "Invalid address space");
16131617
O << ".ptr .global ";
1618+
if (I->getParamAlign().valueOrOne() != 1) {
1619+
Align ParamAlign = I->getParamAlign().value();
1620+
O << ".align " << ParamAlign.value() << " ";
1621+
}
16141622
} else {
16151623
switch (addrSpace) {
16161624
default:
@@ -1626,9 +1634,9 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
16261634
O << ".ptr .global ";
16271635
break;
16281636
}
1637+
Align ParamAlign = I->getParamAlign().valueOrOne();
1638+
O << ".align " << ParamAlign.value() << " ";
16291639
}
1630-
Align ParamAlign = I->getParamAlign().valueOrOne();
1631-
O << ".align " << ParamAlign.value() << " ";
16321640
O << TLI->getParamName(F, paramIndex);
16331641
continue;
16341642
}

llvm/test/CodeGen/NVPTX/kernel-param-align.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
%struct.Large = type { [16 x double] }
55

66
; CHECK: .param .u64 .ptr .global .align 16 func_align_param_0,
7-
; CHECK: .param .u64 func_align_param_1,
7+
; CHECK: .param .u64 .ptr .global func_align_param_1,
88
; CHECK: .param .u32 func_align_param_2
99
define void @func_align(ptr nocapture readonly align 16 %input, ptr nocapture %out, i32 %n) {
1010
entry:
@@ -16,8 +16,8 @@ entry:
1616
ret void
1717
}
1818

19-
; CHECK: .param .u64 func_param_0,
20-
; CHECK: .param .u64 func_param_1,
19+
; CHECK: .param .u64 .ptr .global func_param_0,
20+
; CHECK: .param .u64 .ptr .global func_param_1,
2121
; CHECK: .param .u32 func_param_2
2222
define void @func(ptr nocapture readonly %input, ptr nocapture %out, i32 %n) {
2323
entry:

0 commit comments

Comments
 (0)