Skip to content

Commit 9d0903b

Browse files
committed
[mlir][tosa] Relax constraint on matmul verifier requiring equal operand types
Removes the verifier constraint allowing support for matmul with different operand types such as fp8e5m2xfp8e4m3. Support for specific operand types strictly adhering to the TOSA specification will still be caught in the validation pass. Change-Id: I1453ded48326ea0460fa6caf52651c02b7d8c055
1 parent f262af4 commit 9d0903b

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,12 +1605,6 @@ LogicalResult MatMulOp::verify() {
16051605
return emitOpError("expect quantized operands to have same widths, got ")
16061606
<< aQuantWidth << " and " << bQuantWidth;
16071607
}
1608-
} else {
1609-
// non-quantized element types
1610-
if (aElementType != bElementType) {
1611-
return emitOpError("expect same element type for inputs a and b, got ")
1612-
<< aElementType << " and " << bElementType;
1613-
}
16141608
}
16151609

16161610
// check a_zp and b_zp

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,15 @@ func.func @test_matmul_f8E5M2(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x
934934
return %0 : tensor<1x14x28xf16>
935935
}
936936

937+
// -----
938+
// CHECK-LABEL: test_matmul_f8E5M2_f8E4M3
939+
func.func @test_matmul_f8E5M2_f8E4M3(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16> {
940+
%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
941+
%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
942+
%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E5M2>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E5M2>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf16>
943+
return %0 : tensor<1x14x28xf16>
944+
}
945+
937946
// -----
938947
// CHECK-LABEL: max_pool2d_f8E5M2
939948
func.func @test_max_pool2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2> {

0 commit comments

Comments
 (0)