Skip to content

Commit b08b219

Browse files
authored
[MLIR][NVVM] Add "blocksareclusters" kernel attribute support (#154519)
This change adds "nvvm.blocksareclusters" kernel attribute support in NVVM Dialect/MLIR.
1 parent be179d0 commit b08b219

File tree

5 files changed

+52
-2
lines changed

5 files changed

+52
-2
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,15 @@ def NVVM_Dialect : Dialect {
8383
/// are grid constants.
8484
static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
8585

86+
/// Get the name of the attribute used to annotate the `.blocksareclusters`
87+
/// PTX directive for kernel functions.
88+
/// This attribute implies that the grid launch configuration for the
89+
/// corresponding kernel function is specifying the number of clusters
90+
/// instead of the number of thread blocks. This attribute is only
91+
/// allowed for kernel functions and requires nvvm.reqntid and
92+
/// nvvm.cluster_dim attributes.
93+
static StringRef getBlocksAreClustersAttrName() { return "nvvm.blocksareclusters"; }
94+
8695
/// Verify an attribute from this dialect on the argument at 'argIndex' for
8796
/// the region at 'regionIndex' on the given operation. Returns failure if
8897
/// the verification failed, success otherwise. This hook may optionally be

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1925,19 +1925,31 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
19251925
attrName == NVVMDialect::getReqntidAttrName() ||
19261926
attrName == NVVMDialect::getClusterDimAttrName()) {
19271927
auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
1928-
if (!values || values.empty() || values.size() > 3)
1928+
if (!values || values.empty() || values.size() > 3) {
19291929
return op->emitError()
19301930
<< "'" << attrName
19311931
<< "' attribute must be integer array with maximum 3 index";
1932+
}
19321933
}
19331934
// If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
19341935
// attribute
19351936
if (attrName == NVVMDialect::getMinctasmAttrName() ||
19361937
attrName == NVVMDialect::getMaxnregAttrName() ||
19371938
attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
1938-
if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
1939+
if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
19391940
return op->emitError()
19401941
<< "'" << attrName << "' attribute must be integer constant";
1942+
}
1943+
}
1944+
// blocksareclusters must be used along with reqntid and cluster_dim
1945+
if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
1946+
if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
1947+
!op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
1948+
return op->emitError()
1949+
<< "'" << attrName << "' attribute must be used along with "
1950+
<< "'" << NVVMDialect::getReqntidAttrName() << "' and "
1951+
<< "'" << NVVMDialect::getClusterDimAttrName() << "'";
1952+
}
19411953
}
19421954

19431955
return success();

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,11 @@ class NVVMDialectLLVMIRTranslationInterface
468468
} else if (attribute.getName() ==
469469
NVVM::NVVMDialect::getKernelFuncAttrName()) {
470470
llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel);
471+
} else if (attribute.getName() ==
472+
NVVM::NVVMDialect::getBlocksAreClustersAttrName()) {
473+
llvmFunc->addFnAttr("nvvm.blocksareclusters");
471474
}
475+
472476
return success();
473477
}
474478

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,22 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 3, 4
5656

5757
// -----
5858

59+
// expected-error @below {{'"nvvm.blocksareclusters"' attribute must be used along with 'nvvm.reqntid' and 'nvvm.cluster_dim'}}
60+
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.blocksareclusters,
61+
nvvm.cluster_dim = array<i32: 3, 5, 7>} {
62+
llvm.return
63+
}
64+
65+
// -----
66+
67+
// expected-error @below {{'"nvvm.blocksareclusters"' attribute must be used along with 'nvvm.reqntid' and 'nvvm.cluster_dim'}}
68+
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.blocksareclusters,
69+
nvvm.reqntid = array<i32: 1, 23, 32>} {
70+
llvm.return
71+
}
72+
73+
// -----
74+
5975
llvm.func @nvvm_fence_proxy_acquire(%addr : !llvm.ptr, %size : i32) {
6076
// expected-error @below {{'nvvm.fence.proxy.acquire' op uni-directional proxies only support generic for from_proxy attribute}}
6177
nvvm.fence.proxy.acquire #nvvm.mem_scope<cta> %addr, %size from_proxy=#nvvm.proxy_kind<tensormap> to_proxy=#nvvm.proxy_kind<generic>

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,16 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 2
692692

693693
// CHECK: define ptx_kernel void @kernel_func() #[[ATTR0:[0-9]+]]
694694
// CHECK: attributes #[[ATTR0]] = { "nvvm.maxnreg"="32" "nvvm.maxntid"="1,23,32" "nvvm.minctasm"="16" }
695+
// -----
695696

697+
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.blocksareclusters,
698+
nvvm.reqntid = array<i32: 1, 23, 32>,
699+
nvvm.cluster_dim = array<i32: 3, 5, 7>} {
700+
llvm.return
701+
}
702+
703+
// CHECK: define ptx_kernel void @kernel_func() #[[ATTR0:[0-9]+]]
704+
// CHECK: attributes #[[ATTR0]] = { "nvvm.blocksareclusters" "nvvm.cluster_dim"="3,5,7" "nvvm.reqntid"="1,23,32" }
696705
// -----
697706
// CHECK: define ptx_kernel void @kernel_func
698707
// CHECK: !nvvm.annotations =

0 commit comments

Comments
 (0)