Skip to content

Commit 78ea5e5

Browse files
committed
Improve kernel arg attrs accuracy
Allow explicitly defined kernel .ptr arg alignments of 1 to be specified. Avoid forcing .global annotations for generic pointers on CUDA. Allow .local pointers, since the PTX specification says they are legal here. Avoid outputting unnecessary extra spaces in PTX between .ptr and the memory space. Improve test-cases by removing unnecessary function bodies, and adding more varied alignments + address spaces.
1 parent aaade68 commit 78ea5e5

File tree

3 files changed

+36
-46
lines changed

3 files changed

+36
-46
lines changed

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,37 +1600,33 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
16001600

16011601
if (isKernelFunc) {
16021602
if (PTy) {
1603-
O << "\t.param .u" << PTySizeInBits << " ";
1603+
O << "\t.param .u" << PTySizeInBits << " .ptr ";
16041604

1605-
int addrSpace = PTy->getAddressSpace();
1606-
const bool IsCUDA =
1607-
static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() ==
1608-
NVPTX::CUDA;
1609-
1610-
O << ".ptr ";
1611-
switch (addrSpace) {
1605+
switch (PTy->getAddressSpace()) {
16121606
default:
1613-
// Special handling for pointer arguments to kernel
1614-
// CUDA kernels assume that pointers are in global address space
1615-
// See:
1616-
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
1617-
if (IsCUDA)
1618-
O << " .global ";
16191607
break;
1620-
case ADDRESS_SPACE_CONST:
1621-
O << " .const ";
1608+
case ADDRESS_SPACE_GLOBAL:
1609+
O << ".global ";
16221610
break;
16231611
case ADDRESS_SPACE_SHARED:
1624-
O << " .shared ";
1612+
O << ".shared ";
16251613
break;
1626-
case ADDRESS_SPACE_GLOBAL:
1627-
O << " .global ";
1614+
case ADDRESS_SPACE_CONST:
1615+
O << ".const ";
1616+
break;
1617+
case ADDRESS_SPACE_LOCAL:
1618+
O << ".local ";
16281619
break;
16291620
}
16301621

1631-
Align ParamAlign = I->getParamAlign().valueOrOne();
1632-
if (ParamAlign != 1 || !IsCUDA)
1633-
O << ".align " << ParamAlign.value() << " ";
1622+
const bool IsCUDA =
1623+
static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() ==
1624+
NVPTX::CUDA;
1625+
1626+
MaybeAlign ParamAlign = I->getParamAlign();
1627+
if (ParamAlign.has_value() || !IsCUDA)
1628+
O << ".align " << ParamAlign.valueOrOne().value() << " ";
1629+
16341630
O << TLI->getParamName(F, paramIndex);
16351631
continue;
16361632
}

llvm/test/CodeGen/NVPTX/i1-param.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ target triple = "nvptx-nvidia-cuda"
88

99
; CHECK: .entry foo
1010
; CHECK: .param .u8 foo_param_0
11-
; CHECK: .param .u64 .ptr .global foo_param_1
11+
; CHECK: .param .u64 .ptr foo_param_1
1212
define void @foo(i1 %p, ptr %out) {
1313
%val = zext i1 %p to i32
1414
store i32 %val, ptr %out

llvm/test/CodeGen/NVPTX/kernel-param-align.ll

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,42 +4,36 @@
44
%struct.Large = type { [16 x double] }
55

66
; CHECK-LABEL: .entry func_align(
7-
; CHECK: .param .u64 .ptr .global .align 16 func_align_param_0
8-
; CHECK: .param .u64 .ptr .global .align 16 func_align_param_1
9-
; CHECK: .param .u64 .ptr .global .align 16 func_align_param_2
10-
; CHECK: .param .u64 .ptr .shared .align 16 func_align_param_3
7+
; CHECK: .param .u64 .ptr .align 1 func_align_param_0
8+
; CHECK: .param .u64 .ptr .align 2 func_align_param_1
9+
; CHECK: .param .u64 .ptr .global .align 4 func_align_param_2
10+
; CHECK: .param .u64 .ptr .shared .align 8 func_align_param_3
1111
; CHECK: .param .u64 .ptr .const .align 16 func_align_param_4
12-
define void @func_align(ptr nocapture readonly align 16 %input,
13-
ptr nocapture align 16 %out,
14-
ptr addrspace(1) align 16 %global,
15-
ptr addrspace(3) align 16 %shared,
16-
ptr addrspace(4) align 16 %const) {
12+
; CHECK: .param .u64 .ptr .local .align 32 func_align_param_5
13+
define void @func_align(ptr nocapture readonly align 1 %input,
14+
ptr nocapture align 2 %out,
15+
ptr addrspace(1) align 4 %global,
16+
ptr addrspace(3) align 8 %shared,
17+
ptr addrspace(4) align 16 %const,
18+
ptr addrspace(5) align 32 %local) {
1719
entry:
18-
%0 = addrspacecast ptr %out to ptr addrspace(1)
19-
%1 = addrspacecast ptr %input to ptr addrspace(1)
20-
%getElem = getelementptr inbounds %struct.Large, ptr addrspace(1) %1, i64 0, i32 0, i64 5
21-
%tmp2 = load i32, ptr addrspace(1) %getElem, align 8
22-
store i32 %tmp2, ptr addrspace(1) %0, align 4
2320
ret void
2421
}
2522

2623
; CHECK-LABEL: .entry func_noalign(
27-
; CHECK: .param .u64 .ptr .global func_noalign_param_0
28-
; CHECK: .param .u64 .ptr .global func_noalign_param_1
24+
; CHECK: .param .u64 .ptr func_noalign_param_0
25+
; CHECK: .param .u64 .ptr func_noalign_param_1
2926
; CHECK: .param .u64 .ptr .global func_noalign_param_2
3027
; CHECK: .param .u64 .ptr .shared func_noalign_param_3
31-
; CHECK: .param .u64 .ptr .const func_noalign_param_4
28+
; CHECK: .param .u64 .ptr .const func_noalign_param_4
29+
; CHECK: .param .u64 .ptr .local func_noalign_param_5
3230
define void @func_noalign(ptr nocapture readonly %input,
3331
ptr nocapture %out,
3432
ptr addrspace(1) %global,
3533
ptr addrspace(3) %shared,
36-
ptr addrspace(4) %const) {
34+
ptr addrspace(4) %const,
35+
ptr addrspace(5) %local) {
3736
entry:
38-
%0 = addrspacecast ptr %out to ptr addrspace(1)
39-
%1 = addrspacecast ptr %input to ptr addrspace(1)
40-
%getElem = getelementptr inbounds %struct.Large, ptr addrspace(1) %1, i64 0, i32 0, i64 5
41-
%tmp2 = load i32, ptr addrspace(1) %getElem, align 8
42-
store i32 %tmp2, ptr addrspace(1) %0, align 4
4337
ret void
4438
}
4539

0 commit comments

Comments
 (0)