Skip to content

Commit b8bea83

Browse files
committed
[mlir][spirv] Refactor vendor op definitions
Use dedicated vendor op classes/categories. This is so that we can later change the mnemonics of all vendor ops by changing the base class: `SPV_VendorOp`. Issue: llvm/llvm-project#56863
1 parent 6a378b3 commit b8bea83

File tree

9 files changed

+88
-76
lines changed

9 files changed

+88
-76
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def SPV_AtomicExchangeOp : SPV_Op<"AtomicExchange", []> {
262262

263263
// -----
264264

265-
def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> {
265+
def SPV_EXTAtomicFAddOp : SPV_ExtVendorOp<"AtomicFAdd", []> {
266266
let summary = "TBD";
267267

268268
let description = [{
@@ -279,7 +279,7 @@ def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> {
279279

280280
3) store the New Value back through Pointer.
281281

282-
The instructions result is the Original Value.
282+
The instruction's result is the Original Value.
283283

284284
Result Type must be a floating-point type scalar.
285285

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
// -----
1717

18-
def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
18+
def SPV_NVCooperativeMatrixLengthOp : SPV_NvVendorOp<"CooperativeMatrixLength",
1919
[NoSideEffect]> {
2020
let summary = "See extension SPV_NV_cooperative_matrix";
2121

@@ -60,7 +60,7 @@ def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
6060

6161
// -----
6262

63-
def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
63+
def SPV_NVCooperativeMatrixLoadOp : SPV_NvVendorOp<"CooperativeMatrixLoad", []> {
6464
let summary = "See extension SPV_NV_cooperative_matrix";
6565

6666
let description = [{
@@ -136,7 +136,7 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
136136

137137
// -----
138138

139-
def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
139+
def SPV_NVCooperativeMatrixMulAddOp : SPV_NvVendorOp<"CooperativeMatrixMulAdd",
140140
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
141141
let summary = "See extension SPV_NV_cooperative_matrix";
142142

@@ -210,7 +210,7 @@ def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
210210

211211
// -----
212212

213-
def SPV_CooperativeMatrixStoreNVOp : SPV_Op<"CooperativeMatrixStoreNV", []> {
213+
def SPV_NVCooperativeMatrixStoreOp : SPV_NvVendorOp<"CooperativeMatrixStore", []> {
214214
let summary = "See extension SPV_NV_cooperative_matrix";
215215

216216
let description = [{

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
9292

9393
// -----
9494

95-
def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
95+
def SPV_KHRSubgroupBallotOp : SPV_KhrVendorOp<"SubgroupBallot", []> {
9696
let summary = "See extension SPV_KHR_shader_ballot";
9797

9898
let description = [{
@@ -146,7 +146,7 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
146146

147147
// -----
148148

149-
def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
149+
def SPV_INTELSubgroupBlockReadOp : SPV_IntelVendorOp<"SubgroupBlockRead", []> {
150150
let summary = "See extension SPV_INTEL_subgroups";
151151

152152
let description = [{
@@ -197,7 +197,7 @@ def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
197197

198198
// -----
199199

200-
def SPV_SubgroupBlockWriteINTELOp : SPV_Op<"SubgroupBlockWriteINTEL", []> {
200+
def SPV_INTELSubgroupBlockWriteOp : SPV_IntelVendorOp<"SubgroupBlockWrite", []> {
201201
let summary = "See extension SPV_INTEL_subgroups";
202202

203203
let description = [{

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515

1616
// -----
1717

18-
def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTEL",
18+
def SPV_INTELJointMatrixWorkItemLengthOp : SPV_IntelVendorOp<"JointMatrixWorkItemLength",
1919
[NoSideEffect]> {
2020
let summary = "See extension SPV_INTEL_joint_matrix";
2121

2222
let description = [{
23-
Return number of components owned by the current work-item in
23+
Return number of components owned by the current work-item in
2424
a joint matrix.
2525

2626
Result Type must be an 32-bit unsigned integer type scalar.
@@ -60,34 +60,34 @@ def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTE
6060

6161
// -----
6262

63-
def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
63+
def SPV_INTELJointMatrixLoadOp : SPV_IntelVendorOp<"JointMatrixLoad", []> {
6464
let summary = "See extension SPV_INTEL_joint_matrix";
6565

6666
let description = [{
6767
Load a matrix through a pointer.
6868

6969
Result Type is the type of the loaded matrix. It must be OpTypeJointMatrixINTEL.
7070

71-
Pointer is the pointer to load through. It specifies start of memory region where
71+
Pointer is the pointer to load through. It specifies start of memory region where
7272
elements of the matrix are stored and arranged according to Layout.
7373

74-
Stride is the number of elements in memory between beginnings of successive rows,
74+
Stride is the number of elements in memory between beginnings of successive rows,
7575
columns (or words) in the result. It must be a scalar integer type.
7676

77-
Layout indicates how the values loaded from memory are arranged. It must be the
77+
Layout indicates how the values loaded from memory are arranged. It must be the
7878
result of a constant instruction.
7979

80-
Scope is syncronization scope for operation on the matrix. It must be the result
80+
Scope is syncronization scope for operation on the matrix. It must be the result
8181
of a constant instruction with scalar integer type.
8282

83-
If present, any Memory Operands must begin with a memory operand literal. If not
83+
If present, any Memory Operands must begin with a memory operand literal. If not
8484
present, it is the same as specifying the memory operand None.
8585

8686
#### Example:
8787
```mlir
88-
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride
89-
{memory_access = #spv.memory_access<Volatile>} :
90-
(!spv.ptr<i32, CrossWorkgroup>, i32) ->
88+
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride
89+
{memory_access = #spv.memory_access<Volatile>} :
90+
(!spv.ptr<i32, CrossWorkgroup>, i32) ->
9191
!spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
9292
```
9393
}];
@@ -119,39 +119,39 @@ def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
119119

120120
// -----
121121

122-
def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
122+
def SPV_INTELJointMatrixMadOp : SPV_IntelVendorOp<"JointMatrixMad",
123123
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
124124
let summary = "See extension SPV_INTEL_joint_matrix";
125125

126126
let description = [{
127-
Multiply matrix A by matrix B and add matrix C to the result
128-
of the multiplication: A*B+C. Here A is a M x K matrix, B is
127+
Multiply matrix A by matrix B and add matrix C to the result
128+
of the multiplication: A*B+C. Here A is a M x K matrix, B is
129129
a K x N matrix and C is a M x N matrix.
130130

131-
Behavior is undefined if sizes of operands do not meet the
132-
conditions above. All operands and the Result Type must be
131+
Behavior is undefined if sizes of operands do not meet the
132+
conditions above. All operands and the Result Type must be
133133
OpTypeJointMatrixINTEL.
134134

135-
A must be a OpTypeJointMatrixINTEL whose Component Type is a
136-
signed numerical type, Row Count equals to M and Column Count
135+
A must be a OpTypeJointMatrixINTEL whose Component Type is a
136+
signed numerical type, Row Count equals to M and Column Count
137137
equals to K
138138

139-
B must be a OpTypeJointMatrixINTEL whose Component Type is a
140-
signed numerical type, Row Count equals to K and Column Count
139+
B must be a OpTypeJointMatrixINTEL whose Component Type is a
140+
signed numerical type, Row Count equals to K and Column Count
141141
equals to N
142142

143-
C and Result Type must be a OpTypeJointMatrixINTEL with Row
143+
C and Result Type must be a OpTypeJointMatrixINTEL with Row
144144
Count equals to M and Column Count equals to N
145145

146-
Scope is syncronization scope for operation on the matrix.
147-
It must be the result of a constant instruction with scalar
146+
Scope is syncronization scope for operation on the matrix.
147+
It must be the result of a constant instruction with scalar
148148
integer type.
149149

150150
#### Example:
151151
```mlir
152-
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c :
153-
!spv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
154-
!spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
152+
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c :
153+
!spv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
154+
!spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
155155
-> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
156156
```
157157

@@ -182,38 +182,38 @@ def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
182182

183183
// -----
184184

185-
def SPV_JointMatrixStoreINTELOp : SPV_Op<"JointMatrixStoreINTEL", []> {
185+
def SPV_INTELJointMatrixStoreOp : SPV_IntelVendorOp<"JointMatrixStore", []> {
186186
let summary = "See extension SPV_INTEL_joint_matrix";
187187

188188
let description = [{
189189
Store a matrix through a pointer.
190190

191-
Pointer is the pointer to store through. It specifies
192-
start of memory region where elements of the matrix must
191+
Pointer is the pointer to store through. It specifies
192+
start of memory region where elements of the matrix must
193193
be stored and arranged according to Layout.
194194

195-
Object is the matrix to store. It must be
195+
Object is the matrix to store. It must be
196196
OpTypeJointMatrixINTEL.
197197

198-
Stride is the number of elements in memory between beginnings
199-
of successive rows, columns (or words) of the Object. It must
198+
Stride is the number of elements in memory between beginnings
199+
of successive rows, columns (or words) of the Object. It must
200200
be a scalar integer type.
201201

202-
Layout indicates how the values stored to memory are arranged.
202+
Layout indicates how the values stored to memory are arranged.
203203
It must be the result of a constant instruction.
204204

205-
Scope is syncronization scope for operation on the matrix.
206-
It must be the result of a constant instruction with scalar
205+
Scope is syncronization scope for operation on the matrix.
206+
It must be the result of a constant instruction with scalar
207207
integer type.
208208

209-
If present, any Memory Operands must begin with a memory operand
210-
literal. If not present, it is the same as specifying the memory
209+
If present, any Memory Operands must begin with a memory operand
210+
literal. If not present, it is the same as specifying the memory
211211
operand None.
212212

213213
#### Example:
214214
```mlir
215-
spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride
216-
{memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>,
215+
spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride
216+
{memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>,
217217
!spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
218218
```
219219

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
1818

1919
// -----
2020

21-
def SPV_AssumeTrueKHROp : SPV_Op<"AssumeTrueKHR", []> {
21+
def SPV_KHRAssumeTrueOp : SPV_KhrVendorOp<"AssumeTrue", []> {
2222
let summary = "TBD";
2323

2424
let description = [{

0 commit comments

Comments
 (0)