Skip to content

Commit 73aefcf

Browse files
committed
Follow the convention of eltType in the naming
1 parent 31f8036 commit 73aefcf

File tree

4 files changed

+18
-18
lines changed

4 files changed

+18
-18
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2013,7 +2013,7 @@ def LdStMatrixEltTypeAttr : EnumAttr<NVVM_Dialect, LdStMatrixEltType, "ld_st_mat
20132013

20142014
def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,
20152015
Arguments<(ins LLVM_AnyPointer: $ptr, Variadic<I32>:$sources, MMALayoutAttr:$layout,
2016-
LdStMatrixShapeAttr:$shape, LdStMatrixEltTypeAttr:$elttype)> {
2016+
LdStMatrixShapeAttr:$shape, LdStMatrixEltTypeAttr:$eltType)> {
20172017
let summary = "cooperative matrix store";
20182018
let description = [{
20192019
Collectively store one or more matrices across all threads in a warp to the
@@ -2023,7 +2023,7 @@ def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,
20232023
}];
20242024
string llvmBuilder = [{
20252025
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
2026-
auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape, $elttype);
2026+
auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape, $eltType);
20272027
createIntrinsicCall(builder, intId, operands, operands[0]->getType());
20282028
}];
20292029
let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -830,11 +830,11 @@ LogicalResult NVVM::StMatrixOp::verify() {
830830

831831
int m = getShape().getM(), n = getShape().getN();
832832
if (m == 8 && n == 8) {
833-
if (getElttype() != NVVM::LdStMatrixEltType::B16) {
833+
if (getEltType() != NVVM::LdStMatrixEltType::B16) {
834834
return emitOpError("expected element type to be B16 for 8x8 matrix");
835835
}
836836
} else if (m == 16 && n == 8) {
837-
if (getElttype() != NVVM::LdStMatrixEltType::B8) {
837+
if (getEltType() != NVVM::LdStMatrixEltType::B8) {
838838
return emitOpError("expected element type to be B8 for 16x8 matrix");
839839
}
840840
if (getLayout() != NVVM::MMALayout::col) {

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,38 +1146,38 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
11461146

11471147
llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
11481148
// expected-error@+1 {{'nvvm.stmatrix' op expected num attribute to be 1, 2 or 4}}
1149-
nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32, i32
1149+
nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32, i32
11501150
llvm.return
11511151
}
11521152

11531153
// -----
11541154

11551155
llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
11561156
// expected-error@+1 {{'nvvm.stmatrix' op expected shape to be 8x8 or 16x8}}
1157-
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32
1157+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32
11581158
llvm.return
11591159
}
11601160

11611161
// -----
11621162

11631163
llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
11641164
// expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B16 for 8x8 matrix}}
1165-
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : !llvm.ptr<3>, i32
1165+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b8>} : !llvm.ptr<3>, i32
11661166
llvm.return
11671167
}
11681168
// -----
11691169

11701170
llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
11711171
// expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B8 for 16x8 matrix}}
1172-
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32
1172+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32
11731173
llvm.return
11741174
}
11751175

11761176
// -----
11771177

11781178
llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
11791179
// expected-error@+1 {{'nvvm.stmatrix' op expected layout to be col for 16x8 matrix}}
1180-
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : !llvm.ptr<3>, i32
1180+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b8>} : !llvm.ptr<3>, i32
11811181
llvm.return
11821182
}
11831183

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -576,23 +576,23 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
576576
// CHECK-LABEL: @st_matrix
577577
llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
578578
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
579-
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32
579+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32
580580
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
581-
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32
581+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32
582582
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
583-
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : !llvm.ptr<3>, i32
583+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b8>} : !llvm.ptr<3>, i32
584584
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
585-
nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32
585+
nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32
586586
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
587-
nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32
587+
nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32
588588
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
589-
nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : !llvm.ptr<3>, i32, i32
589+
nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b8>} : !llvm.ptr<3>, i32, i32
590590
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
591-
nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32, i32, i32
591+
nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32, i32, i32
592592
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
593-
nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32, i32, i32
593+
nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32, i32, i32
594594
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x4.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
595-
nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : !llvm.ptr<3>, i32, i32, i32, i32
595+
nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elttype<b8>} : !llvm.ptr<3>, i32, i32, i32, i32
596596
llvm.return
597597
}
598598

0 commit comments

Comments
 (0)