Skip to content

Commit e69fb42

Browse files
authored
[MLIR][NVPTX] Add intrinsics and Ops to read smem-sizes (#173089)
This patch adds three NVVM intrinsics and their corresponding MLIR Ops representing the PTX special-register read instructions that report various configurations of shared-memory sizes.

Signed-off-by: Durgadoss R <[email protected]>
1 parent 80887c7 commit e69fb42

File tree

7 files changed

+105
-2
lines changed

7 files changed

+105
-2
lines changed

llvm/docs/NVPTXUsage.rst

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,36 @@ map in the following way to CUDA builtins:
264264
``gridDim`` ``@llvm.nvvm.read.ptx.sreg.nctaid.*``
265265
============ =====================================
266266

267+
'``llvm.nvvm.read.ptx.sreg.*_smem_size``'
268+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
269+
270+
Syntax:
271+
"""""""
272+
273+
.. code-block:: llvm
274+
275+
declare i32 @llvm.nvvm.read.ptx.sreg.total_smem_size()
276+
declare i32 @llvm.nvvm.read.ptx.sreg.aggr_smem_size()
277+
declare i32 @llvm.nvvm.read.ptx.sreg.dynamic_smem_size()
278+
279+
Overview:
280+
"""""""""
281+
282+
The '``@llvm.nvvm.read.ptx.sreg.total_smem_size``' intrinsic reads the
283+
PTX special register that holds the total amount of shared memory
284+
allocated per CTA for the kernel at launch.
285+
286+
The reported value includes both statically allocated and dynamically
287+
requested shared memory, but excludes any shared memory reserved for
288+
system use. The size is returned in multiples of the architecture-specific
289+
shared memory allocation unit size. For targets sm_8x and newer,
290+
this allocation unit is 128 bytes.
291+
292+
The '``aggr_smem_size``' variant returns the aggregate shared memory size,
293+
including the portion reserved for system software use (requires sm_90+ and PTX 8.1+).
294+
295+
The '``dynamic_smem_size``' variant returns the amount of dynamic shared
296+
memory allocated per CTA for the kernel at launch time.
267297

268298
Barriers
269299
--------

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2303,8 +2303,8 @@ foreach vec = [TV_I8, TV_I16, TV_I32,
23032303
//
23042304
// Accessing special registers.
23052305
//
2306-
class PTXReadSRegIntrinsicNB_r32<list<IntrinsicProperty> properties = []>
2307-
: PureIntrinsic<[llvm_i32_ty], [], [NoUndef<RetIndex>] # properties>;
2306+
class PTXReadSRegIntrinsicNB_r32<list<IntrinsicProperty> properties = [], string name = "">
2307+
: PureIntrinsic<[llvm_i32_ty], [], [NoUndef<RetIndex>] # properties, name>;
23082308

23092309
class PTXReadSRegIntrinsic_r32<list<IntrinsicProperty> properties = []>
23102310
: PTXReadSRegIntrinsicNB_r32<properties>, NVVMBuiltin;
@@ -2406,6 +2406,13 @@ defm int_nvvm_read_ptx_sreg_cluster_nctaid : PTXReadSRegIntrinsicNB_v4i32<MAX_GR
24062406
def int_nvvm_read_ptx_sreg_cluster_ctarank : PTXReadSRegIntrinsicNB_r32;
24072407
def int_nvvm_read_ptx_sreg_cluster_nctarank : PTXReadSRegIntrinsicNB_r32;
24082408

2409+
def int_nvvm_read_ptx_sreg_total_smem_size :
2410+
PTXReadSRegIntrinsicNB_r32<[], "llvm.nvvm.read.ptx.sreg.total_smem_size">;
2411+
def int_nvvm_read_ptx_sreg_aggr_smem_size :
2412+
PTXReadSRegIntrinsicNB_r32<[], "llvm.nvvm.read.ptx.sreg.aggr_smem_size">;
2413+
def int_nvvm_read_ptx_sreg_dynamic_smem_size :
2414+
PTXReadSRegIntrinsicNB_r32<[], "llvm.nvvm.read.ptx.sreg.dynamic_smem_size">;
2415+
24092416
//
24102417
// SHUFFLE
24112418
//

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4761,6 +4761,14 @@ def INT_PTX_SREG_CLUSTER_NCTARANK:
47614761
int_nvvm_read_ptx_sreg_cluster_nctarank,
47624762
[hasSM<90>, hasPTX<78>]>;
47634763

4764+
def INT_PTX_SREG_TOTAL_SMEM_SIZE :
4765+
PTX_READ_SREG_R32<"total_smem_size", int_nvvm_read_ptx_sreg_total_smem_size>;
4766+
def INT_PTX_SREG_DYNAMIC_SMEM_SIZE :
4767+
PTX_READ_SREG_R32<"dynamic_smem_size", int_nvvm_read_ptx_sreg_dynamic_smem_size>;
4768+
def INT_PTX_SREG_AGGR_SMEM_SIZE :
4769+
PTX_READ_SREG_R32<"aggr_smem_size",
4770+
int_nvvm_read_ptx_sreg_aggr_smem_size,
4771+
[hasSM<90>, hasPTX<81>]>;
47644772

47654773
def SREG_LANEID : PTX_READ_SREG_R32<"laneid", int_nvvm_read_ptx_sreg_laneid>;
47664774
def SREG_WARPID : PTX_READ_SREG_R32<"warpid", int_nvvm_read_ptx_sreg_warpid>;
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
2+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx81| FileCheck --check-prefixes=CHECK %s
3+
; RUN: %if ptxas-sm_90 && ptxas-isa-8.1 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx81| %ptxas-verify -arch=sm_90 %}
4+
5+
define i32 @test_aggr_smem_size() {
6+
; CHECK-LABEL: test_aggr_smem_size(
7+
; CHECK: {
8+
; CHECK-NEXT: .reg .b32 %r<2>;
9+
; CHECK-EMPTY:
10+
; CHECK-NEXT: // %bb.0:
11+
; CHECK-NEXT: mov.u32 %r1, %aggr_smem_size;
12+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
13+
; CHECK-NEXT: ret;
14+
%a = tail call i32 @llvm.nvvm.read.ptx.sreg.aggr_smem_size()
15+
ret i32 %a
16+
}
17+
18+
declare i32 @llvm.nvvm.read.ptx.sreg.aggr_smem_size()

llvm/test/CodeGen/NVPTX/intrinsics.ll

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,32 @@ define i64 @test_steadycounter() {
318318
ret i64 %ret
319319
}
320320

321+
define i32 @test_total_smem_size() {
322+
; CHECK-LABEL: test_total_smem_size(
323+
; CHECK: {
324+
; CHECK-NEXT: .reg .b32 %r<2>;
325+
; CHECK-EMPTY:
326+
; CHECK-NEXT: // %bb.0:
327+
; CHECK-NEXT: mov.u32 %r1, %total_smem_size;
328+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
329+
; CHECK-NEXT: ret;
330+
%a = tail call i32 @llvm.nvvm.read.ptx.sreg.total_smem_size()
331+
ret i32 %a
332+
}
333+
334+
define i32 @test_dynamic_smem_size() {
335+
; CHECK-LABEL: test_dynamic_smem_size(
336+
; CHECK: {
337+
; CHECK-NEXT: .reg .b32 %r<2>;
338+
; CHECK-EMPTY:
339+
; CHECK-NEXT: // %bb.0:
340+
; CHECK-NEXT: mov.u32 %r1, %dynamic_smem_size;
341+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
342+
; CHECK-NEXT: ret;
343+
%a = tail call i32 @llvm.nvvm.read.ptx.sreg.dynamic_smem_size()
344+
ret i32 %a
345+
}
346+
321347
declare float @llvm.fabs.f32(float)
322348
declare double @llvm.fabs.f64(double)
323349
declare float @llvm.nvvm.sqrt.f(float)
@@ -335,3 +361,5 @@ declare void @llvm.nvvm.exit()
335361
declare i64 @llvm.nvvm.read.ptx.sreg.globaltimer()
336362
declare i64 @llvm.readcyclecounter()
337363
declare i64 @llvm.readsteadycounter()
364+
declare i32 @llvm.nvvm.read.ptx.sreg.total_smem_size()
365+
declare i32 @llvm.nvvm.read.ptx.sreg.dynamic_smem_size()

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,12 @@ def NVVM_ClusterDimBlocksZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sre
353353
def NVVM_ClusterId : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>]>;
354354
def NVVM_ClusterDim : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;
355355

356+
//===----------------------------------------------------------------------===//
357+
// Special registers reporting shared-memory sizes
358+
def NVVM_TotalSmemSize : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.total.smem.size">;
359+
def NVVM_DynamicSmemSize : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.dynamic.smem.size">;
360+
def NVVM_AggrSmemSize : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.aggr.smem.size", [NVVMRequiresSM<90>]>;
361+
356362
//===----------------------------------------------------------------------===//
357363
// Clock registers
358364
def NVVM_ClockOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock">;

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,12 @@ llvm.func @nvvm_special_regs() -> i32 {
156156
%76 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 0> : i32
157157
// CHECK: %77 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
158158
%77 = nvvm.read.ptx.sreg.tid.x range <i32, 4294967295, 4294967295> : i32
159+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.total_smem_size()
160+
%78 = nvvm.read.ptx.sreg.total.smem.size : i32
161+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.dynamic_smem_size()
162+
%79 = nvvm.read.ptx.sreg.dynamic.smem.size : i32
163+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.aggr_smem_size()
164+
%80 = nvvm.read.ptx.sreg.aggr.smem.size : i32
159165
llvm.return %1 : i32
160166
}
161167

0 commit comments

Comments
 (0)