Skip to content

Commit 06bcc34

Browse files
authored
[NVPTX] Auto-upgrade nvvm.grid_constant to param attribute (#155489)
Upgrade the !"grid_constant" entry in !nvvm.annotations to a "nvvm.grid_constant" parameter attribute. This attribute is much simpler for front-ends to apply, and both faster and simpler to query.
1 parent 5bca8f2 commit 06bcc34

File tree

9 files changed

+91
-240
lines changed

9 files changed

+91
-240
lines changed

clang/lib/CodeGen/Targets/NVPTX.cpp

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,6 @@ class NVPTXTargetCodeGenInfo : public TargetCodeGenInfo {
8787
static void addNVVMMetadata(llvm::GlobalValue *GV, StringRef Name,
8888
int Operand);
8989

90-
static void
91-
addGridConstantNVVMMetadata(llvm::GlobalValue *GV,
92-
const SmallVectorImpl<int> &GridConstantArgs);
93-
9490
private:
9591
static void emitBuiltinSurfTexDeviceCopy(CodeGenFunction &CGF, LValue Dst,
9692
LValue Src) {
@@ -265,27 +261,24 @@ void NVPTXTargetCodeGenInfo::setTargetAttributes(
265261
// By default, all functions are device functions
266262
if (FD->hasAttr<DeviceKernelAttr>() || FD->hasAttr<CUDAGlobalAttr>()) {
267263
// OpenCL/CUDA kernel functions are marked with the PTX kernel calling convention
268-
// Create !{<func-ref>, metadata !"kernel", i32 1} node
269264
// And kernel functions are not subject to inlining
270265
F->addFnAttr(llvm::Attribute::NoInline);
271266
if (FD->hasAttr<CUDAGlobalAttr>()) {
272-
SmallVector<int, 10> GCI;
267+
F->setCallingConv(llvm::CallingConv::PTX_Kernel);
268+
273269
for (auto IV : llvm::enumerate(FD->parameters()))
274270
if (IV.value()->hasAttr<CUDAGridConstantAttr>())
275-
// For some reason arg indices are 1-based in NVVM
276-
GCI.push_back(IV.index() + 1);
277-
// Create !{<func-ref>, metadata !"kernel", i32 1} node
278-
F->setCallingConv(llvm::CallingConv::PTX_Kernel);
279-
addGridConstantNVVMMetadata(F, GCI);
271+
F->addParamAttr(
272+
IV.index(),
273+
llvm::Attribute::get(F->getContext(), "nvvm.grid_constant"));
280274
}
281275
if (CUDALaunchBoundsAttr *Attr = FD->getAttr<CUDALaunchBoundsAttr>())
282276
M.handleCUDALaunchBoundsAttr(F, Attr);
283277
}
284278
}
285279
// Device kernels use the PTX kernel calling convention when compiling for NVPTX.
286-
if (FD->hasAttr<DeviceKernelAttr>()) {
280+
if (FD->hasAttr<DeviceKernelAttr>())
287281
F->setCallingConv(llvm::CallingConv::PTX_Kernel);
288-
}
289282
}
290283

291284
void NVPTXTargetCodeGenInfo::addNVVMMetadata(llvm::GlobalValue *GV,
@@ -305,29 +298,6 @@ void NVPTXTargetCodeGenInfo::addNVVMMetadata(llvm::GlobalValue *GV,
305298
MD->addOperand(llvm::MDNode::get(Ctx, MDVals));
306299
}
307300

308-
void NVPTXTargetCodeGenInfo::addGridConstantNVVMMetadata(
309-
llvm::GlobalValue *GV, const SmallVectorImpl<int> &GridConstantArgs) {
310-
311-
llvm::Module *M = GV->getParent();
312-
llvm::LLVMContext &Ctx = M->getContext();
313-
314-
// Get "nvvm.annotations" metadata node
315-
llvm::NamedMDNode *MD = M->getOrInsertNamedMetadata("nvvm.annotations");
316-
317-
SmallVector<llvm::Metadata *, 5> MDVals = {llvm::ConstantAsMetadata::get(GV)};
318-
if (!GridConstantArgs.empty()) {
319-
SmallVector<llvm::Metadata *, 10> GCM;
320-
for (int I : GridConstantArgs)
321-
GCM.push_back(llvm::ConstantAsMetadata::get(
322-
llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), I)));
323-
MDVals.append({llvm::MDString::get(Ctx, "grid_constant"),
324-
llvm::MDNode::get(Ctx, GCM)});
325-
}
326-
327-
// Append metadata to nvvm.annotations
328-
MD->addOperand(llvm::MDNode::get(Ctx, MDVals));
329-
}
330-
331301
bool NVPTXTargetCodeGenInfo::shouldEmitStaticExternCAliases() const {
332302
return false;
333303
}

clang/test/CodeGenCUDA/grid-constant.cu

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,9 @@ void foo() {
1919
tkernel_const<S><<<1,1>>>({});
2020
tkernel<const S><<<1,1>>>(1, {});
2121
}
22-
//.
23-
//.
24-
// CHECK: [[META0:![0-9]+]] = !{ptr @_Z6kernel1Sii, !"grid_constant", [[META1:![0-9]+]]}
25-
// CHECK: [[META1]] = !{i32 1, i32 3}
26-
// CHECK: [[META2:![0-9]+]] = !{ptr @_Z13tkernel_constIK1SEvT_, !"grid_constant", [[META3:![0-9]+]]}
27-
// CHECK: [[META3]] = !{i32 1}
28-
// CHECK: [[META4:![0-9]+]] = !{ptr @_Z13tkernel_constI1SEvT_, !"grid_constant", [[META3]]}
29-
// CHECK: [[META5:![0-9]+]] = !{ptr @_Z7tkernelIK1SEviT_, !"grid_constant", [[META6:![0-9]+]]}
30-
// CHECK: [[META6]] = !{i32 2}
31-
//.
22+
23+
// CHECK: define dso_local ptx_kernel void @_Z6kernel1Sii(ptr noundef byval(%struct.S) align 1 "nvvm.grid_constant" %gc_arg1, i32 noundef %arg2, i32 noundef "nvvm.grid_constant" %gc_arg3)
24+
// CHECK: define ptx_kernel void @_Z13tkernel_constIK1SEvT_(ptr noundef byval(%struct.S) align 1 "nvvm.grid_constant" %arg)
25+
// CHECK: define ptx_kernel void @_Z13tkernel_constI1SEvT_(ptr noundef byval(%struct.S) align 1 "nvvm.grid_constant" %arg)
26+
// CHECK: define ptx_kernel void @_Z7tkernelIK1SEviT_(i32 noundef %dummy, ptr noundef byval(%struct.S) align 1 "nvvm.grid_constant" %arg)
27+

llvm/docs/NVPTXUsage.rst

Lines changed: 19 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,19 @@ not.
5757
5858
When compiled, the PTX kernel functions are callable by host-side code.
5959

60+
61+
Parameter Attributes
62+
--------------------
63+
64+
``"nvvm.grid_constant"``
65+
This attribute may be attached to a ``byval`` parameter of a kernel function
66+
to indicate that the parameter should be lowered as a direct reference to
67+
the grid-constant memory of the parameter, as opposed to a copy of the
68+
parameter in local memory. Writing to a grid-constant parameter is
69+
undefined behavior. Unlike a normal ``byval`` parameter, the address of a
70+
grid-constant parameter is not unique to a given function invocation but
71+
instead is shared by all threads in the grid.
72+
6073
.. _nvptx_fnattrs:
6174

6275
Function Attributes
@@ -2289,9 +2302,9 @@ The Kernel
22892302
; Intrinsic to read X component of thread ID
22902303
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() readnone nounwind
22912304
2292-
define void @kernel(ptr addrspace(1) %A,
2293-
ptr addrspace(1) %B,
2294-
ptr addrspace(1) %C) {
2305+
define ptx_kernel void @kernel(ptr addrspace(1) %A,
2306+
ptr addrspace(1) %B,
2307+
ptr addrspace(1) %C) {
22952308
entry:
22962309
; What is my ID?
22972310
%id = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() readnone nounwind
@@ -2314,9 +2327,6 @@ The Kernel
23142327
ret void
23152328
}
23162329
2317-
!nvvm.annotations = !{!0}
2318-
!0 = !{ptr @kernel, !"kernel", i32 1}
2319-
23202330
23212331
We can use the LLVM ``llc`` tool to directly run the NVPTX code generator:
23222332

@@ -2442,34 +2452,6 @@ and non-generic address spaces.
24422452
See :ref:`address_spaces` and :ref:`nvptx_intrinsics` for more information.
24432453

24442454

2445-
Kernel Metadata
2446-
^^^^^^^^^^^^^^^
2447-
2448-
In PTX, a function can be either a `kernel` function (callable from the host
2449-
program), or a `device` function (callable only from GPU code). You can think
2450-
of `kernel` functions as entry-points in the GPU program. To mark an LLVM IR
2451-
function as a `kernel` function, we make use of special LLVM metadata. The
2452-
NVPTX back-end will look for a named metadata node called
2453-
``nvvm.annotations``. This named metadata must contain a list of metadata that
2454-
describe the IR. For our purposes, we need to declare a metadata node that
2455-
assigns the "kernel" attribute to the LLVM IR function that should be emitted
2456-
as a PTX `kernel` function. These metadata nodes take the form:
2457-
2458-
.. code-block:: text
2459-
2460-
!{<function ref>, metadata !"kernel", i32 1}
2461-
2462-
For the previous example, we have:
2463-
2464-
.. code-block:: llvm
2465-
2466-
!nvvm.annotations = !{!0}
2467-
!0 = !{ptr @kernel, !"kernel", i32 1}
2468-
2469-
Here, we have a single metadata declaration in ``nvvm.annotations``. This
2470-
metadata annotates our ``@kernel`` function with the ``kernel`` attribute.
2471-
2472-
24732455
Running the Kernel
24742456
------------------
24752457

@@ -2669,9 +2651,9 @@ Libdevice provides an ``__nv_powf`` function that we will use.
26692651
; libdevice function
26702652
declare float @__nv_powf(float, float)
26712653
2672-
define void @kernel(ptr addrspace(1) %A,
2673-
ptr addrspace(1) %B,
2674-
ptr addrspace(1) %C) {
2654+
define ptx_kernel void @kernel(ptr addrspace(1) %A,
2655+
ptr addrspace(1) %B,
2656+
ptr addrspace(1) %C) {
26752657
entry:
26762658
; What is my ID?
26772659
%id = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() readnone nounwind
@@ -2694,9 +2676,6 @@ Libdevice provides an ``__nv_powf`` function that we will use.
26942676
ret void
26952677
}
26962678
2697-
!nvvm.annotations = !{!0}
2698-
!0 = !{ptr @kernel, !"kernel", i32 1}
2699-
27002679
27012680
To compile this kernel, we perform the following steps:
27022681

llvm/lib/IR/AutoUpgrade.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5381,6 +5381,16 @@ bool static upgradeSingleNVVMAnnotation(GlobalValue *GV, StringRef K,
53815381
upgradeNVVMFnVectorAttr("nvvm.cluster_dim", K[0], GV, V);
53825382
return true;
53835383
}
5384+
if (K == "grid_constant") {
5385+
const auto Attr = Attribute::get(GV->getContext(), "nvvm.grid_constant");
5386+
for (const auto &Op : cast<MDNode>(V)->operands()) {
5387+
// For some reason, the index is 1-based in the metadata. Good thing we're
5388+
// able to auto-upgrade it!
5389+
const auto Index = mdconst::extract<ConstantInt>(Op)->getZExtValue() - 1;
5390+
cast<Function>(GV)->addParamAttr(Index, Attr);
5391+
}
5392+
return true;
5393+
}
53845394

53855395
return false;
53865396
}

llvm/lib/Target/NVPTX/NVPTXUtilities.cpp

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,6 @@ void clearAnnotationCache(const Module *Mod) {
5555
AC.Cache.erase(Mod);
5656
}
5757

58-
static void readIntVecFromMDNode(const MDNode *MetadataNode,
59-
std::vector<unsigned> &Vec) {
60-
for (unsigned i = 0, e = MetadataNode->getNumOperands(); i != e; ++i) {
61-
ConstantInt *Val =
62-
mdconst::extract<ConstantInt>(MetadataNode->getOperand(i));
63-
Vec.push_back(Val->getZExtValue());
64-
}
65-
}
66-
6758
static void cacheAnnotationFromMD(const MDNode *MetadataNode,
6859
key_val_pair_t &retval) {
6960
auto &AC = getAnnotationCache();
@@ -83,19 +74,8 @@ static void cacheAnnotationFromMD(const MDNode *MetadataNode,
8374
if (ConstantInt *Val = mdconst::dyn_extract<ConstantInt>(
8475
MetadataNode->getOperand(i + 1))) {
8576
retval[Key].push_back(Val->getZExtValue());
86-
} else if (MDNode *VecMd =
87-
dyn_cast<MDNode>(MetadataNode->getOperand(i + 1))) {
88-
// note: only "grid_constant" annotations support vector MDNodes.
89-
// assert: there can only exist one unique key value pair of
90-
// the form (string key, MDNode node). Operands of such a node
91-
// shall always be unsigned ints.
92-
auto [It, Inserted] = retval.try_emplace(Key);
93-
if (Inserted) {
94-
readIntVecFromMDNode(VecMd, It->second);
95-
continue;
96-
}
9777
} else {
98-
llvm_unreachable("Value operand not a constant int or an mdnode");
78+
llvm_unreachable("Value operand not a constant int");
9979
}
10080
}
10181
}
@@ -179,16 +159,13 @@ static bool globalHasNVVMAnnotation(const Value &V, const std::string &Prop) {
179159
}
180160

181161
static bool argHasNVVMAnnotation(const Value &Val,
182-
const std::string &Annotation,
183-
const bool StartArgIndexAtOne = false) {
162+
const std::string &Annotation) {
184163
if (const Argument *Arg = dyn_cast<Argument>(&Val)) {
185164
const Function *Func = Arg->getParent();
186165
std::vector<unsigned> Annot;
187166
if (findAllNVVMAnnotation(Func, Annotation, Annot)) {
188-
const unsigned BaseOffset = StartArgIndexAtOne ? 1 : 0;
189-
if (is_contained(Annot, BaseOffset + Arg->getArgNo())) {
167+
if (is_contained(Annot, Arg->getArgNo()))
190168
return true;
191-
}
192169
}
193170
}
194171
return false;
@@ -250,8 +227,7 @@ bool isParamGridConstant(const Argument &Arg) {
250227
}
251228

252229
// Grid-constant parameters are marked with the "nvvm.grid_constant" attribute.
253-
if (argHasNVVMAnnotation(Arg, "grid_constant",
254-
/*StartArgIndexAtOne*/ true))
230+
if (Arg.hasAttribute("nvvm.grid_constant"))
255231
return true;
256232

257233
return false;

0 commit comments

Comments
 (0)