Skip to content

Commit 9b713f5

Browse files
[MLIR][NVVM] Add PTX predefined special registers (llvm#112343)
This commit adds support for the following PTX predefined special registers * warpid * nwarpid * smid * nsmid * gridid * lanemask.* * globaltimer * envreg* And added lit tests under nvvmir.mlir
1 parent 6902b39 commit 9b713f5

File tree

2 files changed

+109
-5
lines changed

2 files changed

+109
-5
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,22 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []>
139139
}
140140

141141
//===----------------------------------------------------------------------===//
142-
// Lane index and range
142+
// Lane, Warp, SM, Grid index and range
143143
def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
144144
def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
145+
def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
146+
def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
147+
def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
148+
def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
149+
def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
150+
151+
//===----------------------------------------------------------------------===//
152+
// Lane Mask Comparison Ops
153+
def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
154+
def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
155+
def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
156+
def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
157+
def NVVM_LaneMaskGtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.gt">;
145158

146159
//===----------------------------------------------------------------------===//
147160
// Thread index and range
@@ -189,6 +202,13 @@ def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nct
189202
// Clock registers
190203
def NVVM_ClockOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock">;
191204
def NVVM_Clock64Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock64">;
205+
def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">;
206+
207+
//===----------------------------------------------------------------------===//
208+
// envreg registers
209+
foreach index = !range(0, 32) in {
210+
def NVVM_EnvReg # index # Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
211+
}
192212

193213
//===----------------------------------------------------------------------===//
194214
// NVVM approximate op definitions

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,94 @@ llvm.func @nvvm_special_regs() -> i32 {
6262
%29 = nvvm.read.ptx.sreg.clock : i32
6363
// CHECK: call i64 @llvm.nvvm.read.ptx.sreg.clock64
6464
%30 = nvvm.read.ptx.sreg.clock64 : i64
65-
66-
// CHECK: %31 = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
67-
%31 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32
68-
65+
// CHECK: call i64 @llvm.nvvm.read.ptx.sreg.globaltimer
66+
%31 = nvvm.read.ptx.sreg.globaltimer : i64
67+
// CHECK: %32 = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
68+
%32 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32
69+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.warpid
70+
%33 = nvvm.read.ptx.sreg.warpid : i32
71+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nwarpid
72+
%34 = nvvm.read.ptx.sreg.nwarpid : i32
73+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.smid
74+
%35 = nvvm.read.ptx.sreg.smid : i32
75+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nsmid
76+
%36 = nvvm.read.ptx.sreg.nsmid : i32
77+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.gridid
78+
%37 = nvvm.read.ptx.sreg.gridid : i32
79+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg0
80+
%38 = nvvm.read.ptx.sreg.envreg0 : i32
81+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg1
82+
%39 = nvvm.read.ptx.sreg.envreg1 : i32
83+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg2
84+
%40 = nvvm.read.ptx.sreg.envreg2 : i32
85+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg3
86+
%41 = nvvm.read.ptx.sreg.envreg3 : i32
87+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg4
88+
%42 = nvvm.read.ptx.sreg.envreg4 : i32
89+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg5
90+
%43 = nvvm.read.ptx.sreg.envreg5 : i32
91+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg6
92+
%44 = nvvm.read.ptx.sreg.envreg6 : i32
93+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg7
94+
%45 = nvvm.read.ptx.sreg.envreg7 : i32
95+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg8
96+
%46 = nvvm.read.ptx.sreg.envreg8 : i32
97+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg9
98+
%47 = nvvm.read.ptx.sreg.envreg9 : i32
99+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg10
100+
%48 = nvvm.read.ptx.sreg.envreg10 : i32
101+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg11
102+
%49 = nvvm.read.ptx.sreg.envreg11 : i32
103+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg12
104+
%50 = nvvm.read.ptx.sreg.envreg12 : i32
105+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg13
106+
%51 = nvvm.read.ptx.sreg.envreg13 : i32
107+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg14
108+
%52 = nvvm.read.ptx.sreg.envreg14 : i32
109+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg15
110+
%53 = nvvm.read.ptx.sreg.envreg15 : i32
111+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg16
112+
%54 = nvvm.read.ptx.sreg.envreg16 : i32
113+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg17
114+
%55 = nvvm.read.ptx.sreg.envreg17 : i32
115+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg18
116+
%56 = nvvm.read.ptx.sreg.envreg18 : i32
117+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg19
118+
%57 = nvvm.read.ptx.sreg.envreg19 : i32
119+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg20
120+
%58 = nvvm.read.ptx.sreg.envreg20 : i32
121+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg21
122+
%59 = nvvm.read.ptx.sreg.envreg21 : i32
123+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg22
124+
%60 = nvvm.read.ptx.sreg.envreg22 : i32
125+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg23
126+
%61 = nvvm.read.ptx.sreg.envreg23 : i32
127+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg24
128+
%62 = nvvm.read.ptx.sreg.envreg24 : i32
129+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg25
130+
%63 = nvvm.read.ptx.sreg.envreg25 : i32
131+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg26
132+
%64 = nvvm.read.ptx.sreg.envreg26 : i32
133+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg27
134+
%65 = nvvm.read.ptx.sreg.envreg27 : i32
135+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg28
136+
%66 = nvvm.read.ptx.sreg.envreg28 : i32
137+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg29
138+
%67 = nvvm.read.ptx.sreg.envreg29 : i32
139+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg30
140+
%68 = nvvm.read.ptx.sreg.envreg30 : i32
141+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg31
142+
%69 = nvvm.read.ptx.sreg.envreg31 : i32
143+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.eq
144+
%70 = nvvm.read.ptx.sreg.lanemask.eq : i32
145+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.le
146+
%71 = nvvm.read.ptx.sreg.lanemask.le : i32
147+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.lt
148+
%72 = nvvm.read.ptx.sreg.lanemask.lt : i32
149+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.ge
150+
%73 = nvvm.read.ptx.sreg.lanemask.ge : i32
151+
//CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.gt
152+
%74 = nvvm.read.ptx.sreg.lanemask.gt : i32
69153
llvm.return %1 : i32
70154
}
71155

0 commit comments

Comments
 (0)