Skip to content

Commit 9e2a6d2

Browse files
committed
Add verifier checks
1 parent 3a1cf10 commit 9e2a6d2

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,22 @@ LogicalResult NVVM::StMatrixOp::verify() {
828828
if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
829829
return emitOpError("expected num attribute to be 1, 2 or 4");
830830

831+
int m = getShape().getM(), n = getShape().getN();
832+
if (m == 8 && n == 8) {
833+
if (getElttype() != NVVM::LdStMatrixEltType::B16) {
834+
return emitOpError("expected element type to be B16 for 8x8 matrix");
835+
}
836+
} else if (m == 16 && n == 8) {
837+
if (getElttype() != NVVM::LdStMatrixEltType::B8) {
838+
return emitOpError("expected element type to be B8 for 16x8 matrix");
839+
}
840+
if (getLayout() != NVVM::MMALayout::col) {
841+
return emitOpError("expected layout to be col for 16x8 matrix");
842+
}
843+
} else {
844+
return emitOpError("expected shape to be 8x8 or 16x8");
845+
}
846+
831847
return success();
832848
}
833849

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,8 +1144,44 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
11441144
llvm.return
11451145
}
11461146

1147+
llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
1148+
// expected-error@+1 {{'nvvm.stmatrix' op expected num attribute to be 1, 2 or 4}}
1149+
nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32, i32
1150+
llvm.return
1151+
}
1152+
11471153
// -----
11481154

1155+
llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
1156+
// expected-error@+1 {{'nvvm.stmatrix' op expected shape to be 8x8 or 16x8}}
1157+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32
1158+
llvm.return
1159+
}
1160+
1161+
// -----
1162+
1163+
llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
1164+
// expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B16 for 8x8 matrix}}
1165+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : !llvm.ptr<3>, i32
1166+
llvm.return
1167+
}
1168+
// -----
1169+
1170+
llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
1171+
// expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B8 for 16x8 matrix}}
1172+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32
1173+
llvm.return
1174+
}
1175+
1176+
// -----
1177+
1178+
llvm.func @wmmast_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
1179+
// expected-error@+1 {{'nvvm.stmatrix' op expected layout to be col for 16x8 matrix}}
1180+
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : !llvm.ptr<3>, i32
1181+
llvm.return
1182+
}
1183+
1184+
// -----
11491185
llvm.func @caller() {
11501186
// expected-error @below {{expected function call to produce a value}}
11511187
llvm.call @callee() : () -> ()

0 commit comments

Comments
 (0)