Skip to content

Commit ce10229

Browse files
lhutton1mahesh-attarde
authored andcommitted
[mlir][tosa] Relax constraint on matmul verifier requiring equal operand types (llvm#155799)
Removes the verifier constraint allowing support for matmul with different operand types such as fp8e5m2xfp8e4m3. Support for specific operand types strictly adhering to the TOSA specification will still be caught in the validation pass.
1 parent b6b6750 commit ce10229

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1843,12 +1843,6 @@ LogicalResult MatMulOp::verify() {
18431843
return emitOpError("expect quantized operands to have same widths, got ")
18441844
<< aQuantWidth << " and " << bQuantWidth;
18451845
}
1846-
} else {
1847-
// non-quantized element types
1848-
if (aElementType != bElementType) {
1849-
return emitOpError("expect same element type for inputs a and b, got ")
1850-
<< aElementType << " and " << bElementType;
1851-
}
18521846
}
18531847

18541848
// check a_zp and b_zp

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,15 @@ func.func @test_matmul_f8E5M2(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x
976976
return %0 : tensor<1x14x28xf16>
977977
}
978978

979+
// -----
980+
// CHECK-LABEL: test_matmul_f8E5M2_f8E4M3
981+
func.func @test_matmul_f8E5M2_f8E4M3(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16> {
982+
%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
983+
%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
984+
%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E5M2>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E5M2>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf16>
985+
return %0 : tensor<1x14x28xf16>
986+
}
987+
979988
// -----
980989
// CHECK-LABEL: max_pool2d_f8E5M2
981990
func.func @test_max_pool2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2> {

0 commit comments

Comments
 (0)