Skip to content

Commit 09e00a9

Browse files
committed
[NVPTX][MLIR] Updated MmaOp verification and negative MMA tests. PR156040.
1 parent 3e6c7f8 commit 09e00a9

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,9 @@ LogicalResult MmaOp::verify() {
621621
expectedB.emplace_back(unitB, multiplicandFragType);
622622
allowedShapes.push_back({16, 8, kFactor});
623623
allowedShapes.push_back({16, 8, kFactor * 2});
624+
625+
if (resultPtxType() != accumPtxType())
626+
return emitOpError("ctype does not match dtype");
624627
}
625628

626629
// In the M=8 case, there is only 1 possible case per data type.

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,7 @@ llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
749749
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
750750
%c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> {
751751
// C and D should have the same type according to PTX ISA
752+
// expected-error@+1 {{'nvvm.mma.sync' op ctype does not match dtype}}
752753
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
753754
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
754755
shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
@@ -763,6 +764,7 @@ llvm.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
763764
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
764765
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
765766
// C and D should have the same type according to PTX ISA
767+
// expected-error@+1 {{'nvvm.mma.sync' op ctype does not match dtype}}
766768
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
767769
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
768770
shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>

0 commit comments

Comments
 (0)