Skip to content

Commit cfb4f29

Browse files
authored
[mlir][tosa] Relax constraint on matmul verifier requiring equal operand types (#155799)
Removes the verifier constraint that required matmul inputs a and b to have the same element type, which allows matmul with different operand types such as fp8e5m2 x fp8e4m3. Operand type combinations that do not strictly adhere to the TOSA specification will still be caught in the validation pass.
1 parent c5ded52 commit cfb4f29

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1843,12 +1843,6 @@ LogicalResult MatMulOp::verify() {
18431843
return emitOpError("expect quantized operands to have same widths, got ")
18441844
<< aQuantWidth << " and " << bQuantWidth;
18451845
}
1846-
} else {
1847-
// non-quantized element types
1848-
if (aElementType != bElementType) {
1849-
return emitOpError("expect same element type for inputs a and b, got ")
1850-
<< aElementType << " and " << bElementType;
1851-
}
18521846
}
18531847

18541848
// check a_zp and b_zp

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,15 @@ func.func @test_matmul_f8E5M2(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x
976976
return %0 : tensor<1x14x28xf16>
977977
}
978978

979+
// -----
980+
// CHECK-LABEL: test_matmul_f8E5M2_f8E4M3
981+
func.func @test_matmul_f8E5M2_f8E4M3(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16> {
982+
%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
983+
%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
984+
%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E5M2>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E5M2>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf16>
985+
return %0 : tensor<1x14x28xf16>
986+
}
987+
979988
// -----
980989
// CHECK-LABEL: max_pool2d_f8E5M2
981990
func.func @test_max_pool2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2> {

0 commit comments

Comments
 (0)