diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index a9b458acd87f2..d3fd4c3d1d3e1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -35,9 +35,11 @@ profileComplianceMap = {
       {fp16T, fp16T, fp32T, fp32T},
       {fp32T, fp32T, fp32T, fp32T}}}}},
     {"tosa.matmul",
-     {{{Profile::pro_int}, {{i8T, i8T, i32T}}},
+     {{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}},
      {{Profile::pro_fp},
-      {{fp16T, fp16T, fp16T}, {fp16T, fp16T, fp32T}, {fp32T, fp32T, fp32T}}}}},
+      {{fp16T, fp16T, fp16T, fp16T, fp16T},
+       {fp16T, fp16T, fp16T, fp16T, fp32T},
+       {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
     {"tosa.max_pool2d",
      {{{Profile::pro_int}, {{i8T, i8T}}},
       {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -273,10 +275,10 @@ extensionComplianceMap = {
      {{Extension::int16}, {{i16T, i8T, i48T, i48T}}},
      {{Extension::bf16}, {{bf16T, bf16T, fp32T, fp32T}}}}},
     {"tosa.matmul",
-     {{{Extension::int16}, {{i16T, i16T, i48T}}},
-      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T}}},
-      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T}}},
-      {{Extension::bf16}, {{bf16T, bf16T, fp32T}}}}},
+     {{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}},
     {"tosa.max_pool2d",
      {{{Extension::int16}, {{i16T, i16T}}},
       {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5340ce52d73fc..525aa4806c657 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -311,8 +311,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   let arguments = (ins
     Tosa_Tensor3D:$a,
     Tosa_Tensor3D:$b,
-    OptionalAttr<I32Attr>:$a_zp,
-    OptionalAttr<I32Attr>:$b_zp
+    Tosa_ScalarIntOrFloatTensor:$a_zp,
+    Tosa_ScalarIntOrFloatTensor:$b_zp
   );

   let results = (outs
@@ -324,6 +324,13 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
     Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
   ];

+  let extraClassDeclaration = [{
+    FailureOr<int64_t> getAZeroPoint();
+    FailureOr<int64_t> getBZeroPoint();
+    LogicalResult verifyAZeroPoint(int64_t zp);
+    LogicalResult verifyBZeroPoint(int64_t zp);
+  }];
+
   let builders = [Tosa_MatMulOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
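Note: with the TosaOps.td change above, the zero points move from optional i32 attributes to explicit rank-1 tensor operands, so they carry the inputs' element type and can be produced by tosa.const. A minimal before/after sketch of the textual IR (shapes and values are illustrative, mirroring the test updates further down):

    // Before: zero points carried as attributes.
    %0 = tosa.matmul %a, %b {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>

    // After: zero points passed as scalar tensor operands.
    %a_zp = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
    %b_zp = "tosa.const"() <{values = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
    %1 = tosa.matmul %a, %b, %a_zp, %b_zp : (tensor<1x5x3xi8>, tensor<1x3x6xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x6xi32>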
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 2a2589e19d0ac..13c62b2d3e91c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -270,8 +270,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       return rewriter.notifyMatchFailure(
           op, "weight zero point cannot be statically determined");

-    int64_t inputZpVal = *maybeIZp;
-    int64_t weightZpVal = *maybeWZp;
+    const int64_t inputZpVal = *maybeIZp;
+    const int64_t weightZpVal = *maybeWZp;

     if (op.verifyInputZeroPoint(inputZpVal).failed())
       return rewriter.notifyMatchFailure(
@@ -466,8 +466,8 @@ class DepthwiseConvConverter
       return rewriter.notifyMatchFailure(
           op, "weight zero point cannot be statically determined");

-    int64_t inputZpVal = *maybeIZp;
-    int64_t weightZpVal = *maybeWZp;
+    const int64_t inputZpVal = *maybeIZp;
+    const int64_t weightZpVal = *maybeWZp;

     if (op.verifyInputZeroPoint(inputZpVal).failed())
       return rewriter.notifyMatchFailure(
@@ -621,15 +621,38 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{emptyTensor})
                           .result();
-    if (!op.getAZp() && !op.getBZp()) {
+
+    FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
+    FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
+    if (failed(maybeAZp))
+      return rewriter.notifyMatchFailure(
+          op, "input a zero point cannot be statically determined");
+    if (failed(maybeBZp))
+      return rewriter.notifyMatchFailure(
+          op, "input b zero point cannot be statically determined");
+
+    const int64_t aZpVal = *maybeAZp;
+    const int64_t bZpVal = *maybeBZp;
+
+    if (op.verifyAZeroPoint(aZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "input a zero point must be zero for non-int8 integer types");
+
+    if (op.verifyBZeroPoint(bZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "input b zero point must be zero for non-int8 integer types");
+
+    if (aZpVal == 0 && bZpVal == 0) {
       rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
           op, TypeRange{op.getType()},
           ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
       return success();
     }

-    auto aZp = rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr());
-    auto bZp = rewriter.create<arith::ConstantOp>(loc, op.getBZpAttr());
+    auto aZp = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI32IntegerAttr(aZpVal));
+    auto bZp = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI32IntegerAttr(bZpVal));
     rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
         op, TypeRange{op.getType()},
         ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -834,8 +857,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
       return rewriter.notifyMatchFailure(
           op, "output zero point could not be statically determined");

-    int64_t inputZpVal = *maybeIZp;
-    int64_t outputZpVal = *maybeOZp;
+    const int64_t inputZpVal = *maybeIZp;
+    const int64_t outputZpVal = *maybeOZp;

     // Apply padding as necessary.
     llvm::SmallVector<int64_t> pad;
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
index ffbb707344b8c..6dcb7c845b21f 100644
--- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -55,6 +55,8 @@ struct MatMulOpSharding
     SmallVector<AffineMap> maps;
     maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
     maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
+    maps.push_back(AffineMap::get(0, 0, {}, ctx));
+    maps.push_back(AffineMap::get(0, 0, {}, ctx));
     maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
     return maps;
   }
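Note: the MatMulConverter now branches on the folded zero-point values rather than on attribute presence. When both zero points statically fold to zero, the lowering still produces a plain linalg.batch_matmul; otherwise the values are materialized as i32 arith.constant ops feeding linalg.quantized_batch_matmul. A rough sketch of the all-zero case, matching the CHECK lines in the lowering tests below (value names are illustrative):

    %c0 = arith.constant 0.000000e+00 : f32
    %init = tensor.empty() : tensor<1x5x6xf32>
    %filled = linalg.fill ins(%c0 : f32) outs(%init : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
    %out = linalg.batch_matmul ins(%a, %b : tensor<1x5x3xf32>, tensor<1x3x6xf32>)
                               outs(%filled : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>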
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 854196250bb0c..f8299e45b4f63 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -636,23 +636,13 @@ buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                        OperationState &result, Type outputType,
                                        Value a, Value b) {
-  result.addOperands({a, b});
-  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
+  auto zps = createZPsAsConst(builder, a, b);
+  result.addOperands({a, b, zps.first, zps.second});

-  if (quantAttr) {
-    result.addAttribute("a_zp", builder.getI32IntegerAttr(
-                                    static_cast<int32_t>(quantAttr.getAZp())));
-    result.addAttribute("b_zp", builder.getI32IntegerAttr(
-                                    static_cast<int32_t>(quantAttr.getBZp())));
-
-    auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
-    assert(inputType && "Input must be a shaped tensor type!");
-
-    auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
-        inputType.getElementType());
-    assert(inputQType && "Tensor must have quantized datatype!");
-
-    unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
+  Type finalOutputType{outputType};
+  if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
+    auto eType = getStorageElementTypeOrSelf(a.getType());
+    auto inputBits = eType.getIntOrFloatBitWidth();

     auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
     assert(outputShapedType && "Output must be a shaped type");
@@ -662,11 +652,10 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
       accElementType = builder.getIntegerType(48);
     else
       accElementType = builder.getI32Type();
-    auto accType = outputShapedType.clone(accElementType);
-    result.addTypes(accType);
-  } else {
-    result.addTypes(outputType);
+
+    finalOutputType = outputShapedType.clone(accElementType);
   }
+  result.addTypes(finalOutputType);
 }

 /// Both the tosa.avg_pool2d and unary ops use the same
@@ -1147,16 +1136,39 @@ LogicalResult MatMulOp::verify() {
       return emitOpError("expect quantized operands to have same widths, got ")
              << aQuantWidth << " and " << bQuantWidth;
     }
+  } else {
+    // non-quantized element types
+    if (aElementType != bElementType) {
+      return emitOpError("expect same element type for inputs a and b, got ")
+             << aElementType << " and " << bElementType;
+    }
+  }

-    return success();
+  // check a_zp and b_zp
+  auto aEType = getStorageElementTypeOrSelf(aType);
+  auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
+  if (aEType != aZpEType) {
+    return emitOpError("expect input a and a_zp have the same "
+                       "element type, got ")
+           << aEType << " and " << aZpEType;
   }

-  // non-quantized element types
-  if (aElementType != bElementType) {
-    return emitOpError("expect same element type for inputs a and b, got ")
-           << aElementType << " and " << bElementType;
+  auto bEType = getStorageElementTypeOrSelf(bType);
+  auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
+  if (bEType != bZpEType) {
+    return emitOpError("expect input b and b_zp have the same "
+                       "element type, got ")
+           << bEType << " and " << bZpEType;
   }

+  FailureOr<int64_t> maybeAZp = getAZeroPoint();
+  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
+    return failure();
+
+  FailureOr<int64_t> maybeBZp = getBZeroPoint();
+  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
+    return failure();
+
   return success();
 }
@@ -1721,6 +1733,8 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
 ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
 ZERO_POINT_HELPER(AvgPool2dOp, Input)
 ZERO_POINT_HELPER(AvgPool2dOp, Output)
+ZERO_POINT_HELPER(MatMulOp, A)
+ZERO_POINT_HELPER(MatMulOp, B)
 #undef ZERO_POINT_HELPER

 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
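Note: the rewritten builder keeps the accumulator-widening rule, now keyed off the storage element type instead of asserting a quantized type: 16-bit inputs yield an i48 accumulator element type, narrower inputs i32, matching the tosa.matmul compliance entries above. An illustrative i16 case (hypothetical shapes, assuming the int16 extension is enabled):

    %r = tosa.matmul %a, %b, %a_zp, %b_zp : (tensor<1x4x8xi16>, tensor<1x8x2xi16>, tensor<1xi16>, tensor<1xi16>) -> tensor<1x4x2xi48>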
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 345616c9563b5..983062ffd7912 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -178,6 +178,15 @@ void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
   addValue(op.getOutput());
 }

+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
+  addValue(op.getA());
+  addValue(op.getB());
+  addValue(op.getAZp());
+  addValue(op.getBZp());
+  addValue(op.getOutput());
+}
+
 LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
 // This helper function only populates the info for the customised operands.
 #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)                                   \
@@ -218,6 +227,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(Resize)
   POPULATE_PROFILE_INFO_CUSTOM(Select)
   POPULATE_PROFILE_INFO_CUSTOM(Rescale)
+  POPULATE_PROFILE_INFO_CUSTOM(MatMul)

   // Type Invariant Extension, a capability extension that is independent
   // of the data type, meaning any compatible type can be used. No type
@@ -235,7 +245,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_COMMON(Cast)
   POPULATE_PROFILE_INFO_COMMON(Const)
   POPULATE_PROFILE_INFO_COMMON(ArgMax)
-  POPULATE_PROFILE_INFO_COMMON(MatMul)
   POPULATE_PROFILE_INFO_COMMON(Sub)
   POPULATE_PROFILE_INFO_COMMON(Maximum)
   POPULATE_PROFILE_INFO_COMMON(Minimum)
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 5bb4a3bddb51b..341f773c79a5e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -8,7 +8,9 @@ func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : f32) outs([[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x6xf32>) outs([[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x3xf32>, tensor<1x3x6xf32>) -> tensor<1x5x6xf32>
+  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xf32>, tensor<1x3x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x6xf32>
   return %0 : tensor<1x5x6xf32>
 }

@@ -23,7 +25,9 @@ func.func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) ->
   // CHECK: [[ONE:%.+]] = arith.constant 1
   // CHECK: [[TWO:%.+]] = arith.constant 2
   // CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32>
-  %0 = tosa.matmul %arg0, %arg1 {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
+  %a_zp = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %b_zp = "tosa.const"() <{values = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xi8>, tensor<1x3x6xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x6xi32>
   return %0 : tensor<1x5x6xi32>
 }

@@ -37,7 +41,9 @@ func.func @matmul_dyn_batch(%arg0: tensor<?x5x3xf32>, %arg1: tensor<?x3x6xf32>)
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
   // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0_0]] : f32) outs(%[[INIT]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x5x3xf32>, tensor<?x3x6xf32>) outs(%[[FILLED]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<?x5x3xf32>, tensor<?x3x6xf32>) -> tensor<?x5x6xf32>
+  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<?x5x3xf32>, tensor<?x3x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x5x6xf32>
   return %0 : tensor<?x5x6xf32>
 }

@@ -51,7 +57,9 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
   // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x?xf32>) outs(%[[FILLED]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x3xf32>, tensor<1x3x?xf32>) -> tensor<1x5x?xf32>
+  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xf32>, tensor<1x3x?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x?xf32>
   return %0 : tensor<1x5x?xf32>
 }

@@ -63,7 +71,9 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x
   // CHECK: %[[INIT:.+]] = tensor.empty()
   // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x?xf32>, tensor<1x?x6xf32>) outs(%[[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x?xf32>, tensor<1x?x6xf32>) -> tensor<1x5x6xf32>
+  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x?xf32>, tensor<1x?x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x6xf32>
   return %0 : tensor<1x5x6xf32>
 }

@@ -77,7 +87,9 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x1x1xf32>
   // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x1x8xf32>, tensor<1x8x1xf32>) outs(%[[FILLED]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
+  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x1x8xf32>, tensor<1x8x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1x1xf32>
   return %0 : tensor<?x1x1xf32>
 }
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 83136f613b020..14c67e670e921 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -98,14 +98,16 @@ func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor
 }

 // CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
-func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
+func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
   // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
   // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
   // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0]] : !mesh.sharding
   // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
-  // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
+  // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
+  %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
   // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
   // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
   %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
@@ -115,14 +117,16 @@ func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: ten
 }

 // CHECK-LABEL: func.func @matmul_on_def_shard_m_and_k
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
-func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
+func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
   // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
   // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
   // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
   // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
-  // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
+  // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
+  %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
   // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
   // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
   %s1 = mesh.sharding @mesh_2d split_axes = [[], [1]] partial = sum [0] : !mesh.sharding
@@ -132,16 +136,18 @@ func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<
 }

 // CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
-func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
+func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
   // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
   // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
   %s0 = mesh.sharding @mesh_2d split_axes = [[], [1], [0]] : !mesh.sharding
   %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<2x16x8xf32>
   // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
   // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
-  // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
-  %1 = tosa.matmul %0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
+  // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
+  %1 = tosa.matmul %0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
   // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
   // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
   // CHECK-NEXT: return %[[V3]]
@@ -149,8 +155,8 @@ func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<
 }

 // CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
-func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
+func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
   // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
   // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
   %s0 = mesh.sharding @mesh_2d split_axes = [[], [1], [0]] : !mesh.sharding
@@ -159,8 +165,10 @@ func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg
   // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
   %s1 = mesh.sharding @mesh_2d split_axes = [[], [0]] : !mesh.sharding
   %1 = mesh.shard %arg1 to %s1 annotate_for_users : tensor<2x8x32xf32>
-  // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
-  %2 = tosa.matmul %0, %1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+  // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
+  // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
+  %2 = tosa.matmul %0, %1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
   // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] partial = sum [0] : !mesh.sharding
   // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
   // CHECK-NEXT: return %[[V3]]
@@ -199,14 +207,16 @@ func.func @resolve_conflicting_annotations(
 // https://arxiv.org/abs/2211.05102 Figure 2(a)

 // CHECK-LABEL: func.func @mlp_1d_weight_stationary
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
-func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>, %[[ARG3:.*]]: tensor<1xf32>
+func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
   %s0 = mesh.sharding @mesh_1d split_axes = [[], [], [0]] : !mesh.sharding
   %0 = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32>
   // CHECK-DAG: %[[S1:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0]] : !mesh.sharding
   // CHECK-DAG: %[[S2:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0]] : !mesh.sharding
+  // CHECK-DAG: %[[S3:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-DAG: %[[ZP:.*]] = mesh.shard %[[ARG3]] to %[[S3]] annotate_for_users : tensor<1xf32>
   // CHECK: %[[V0:.*]] = tosa.matmul
-  %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
+  %1 = tosa.matmul %0, %arg1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32>
   // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S2]] : tensor<2x4x32xf32>
   // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S2]] annotate_for_users : tensor<2x4x32xf32>
   // CHECK-DAG: %[[V3:.*]] = tosa.sigmoid %[[V2]]
@@ -215,8 +225,8 @@ func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x
   // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S2]] annotate_for_users : tensor<2x4x32xf32>
   // CHECK-DAG: %[[S6:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [0]] : !mesh.sharding
   // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[ARG2]] to %[[S6]] annotate_for_users : tensor<2x32x8xf32>
-  // CHECK-DAG: %[[V7:.*]] = tosa.matmul %[[V5]], %[[V6]]
-  %3 = tosa.matmul %2, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
+  // CHECK-DAG: %[[V7:.*]] = tosa.matmul %[[V5]], %[[V6]], %[[ZP]], %[[ZP]]
+  %3 = tosa.matmul %2, %arg2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32>
   %s4 = mesh.sharding @mesh_1d split_axes = [[], [], []] partial = sum [0] : !mesh.sharding
   %4 = mesh.shard %3 to %s4 : tensor<2x4x8xf32>
   // CHECK: %[[S8:.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], []] partial = sum [0] : !mesh.sharding
@@ -230,8 +240,8 @@ func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x
 // https://arxiv.org/abs/2211.05102 Figure 2(b)

 // CHECK-LABEL: func.func @mlp_2d_weight_stationary
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
-func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>, %[[ARG3:.*]]: tensor<1xf32>
+func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
   // CHECK-DAG: %[[S0:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding
   // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] : tensor<2x4x8xf32>
   %s0 = mesh.sharding @mesh_3d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
@@ -240,8 +250,10 @@ func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x
   // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S1]] annotate_for_users : tensor<2x4x8xf32>
   // CHECK-DAG: %[[S2:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [0], [1, 2]] : !mesh.sharding
   // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[ARG1]] to %[[S2]] annotate_for_users : tensor<2x8x32xf32>
-  // CHECK-DAG: %[[V3:.*]] = tosa.matmul %[[V1]], %[[V2]]
-  %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
+  // CHECK-DAG: %[[S3:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-DAG: %[[ZP:.*]] = mesh.shard %[[ARG3]] to %[[S3]] annotate_for_users : tensor<1xf32>
+  // CHECK-DAG: %[[V3:.*]] = tosa.matmul %[[V1]], %[[V2]], %[[ZP]], %[[ZP]]
+  %1 = tosa.matmul %0, %arg1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32>
   // CHECK-DAG: %[[S4:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [1, 2]] partial = sum [0] : !mesh.sharding
   // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to %[[S4]] : tensor<2x4x32xf32>
   %s2 = mesh.sharding @mesh_3d split_axes = [[], [], [1, 2]] partial = sum [0] : !mesh.sharding
@@ -254,8 +266,8 @@ func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x
   // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S5]] annotate_for_users : tensor<2x4x32xf32>
   // CHECK-DAG: %[[S9:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [1, 2], [0]] : !mesh.sharding
   // CHECK-NEXT: %[[V9:.*]] = mesh.shard %[[ARG2]] to %[[S9]] annotate_for_users : tensor<2x32x8xf32>
-  // CHECK-DAG: %[[V10:.*]] = tosa.matmul %[[V8]], %[[V9]]
-  %4 = tosa.matmul %3, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
+  // CHECK-DAG: %[[V10:.*]] = tosa.matmul %[[V8]], %[[V9]], %[[ZP]], %[[ZP]]
+  %4 = tosa.matmul %3, %arg2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32>
   // CHECK-DAG: %[[S11:.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0]] partial = sum [1, 2] : !mesh.sharding
   // CHECK-NEXT: %[[V11:.*]] = mesh.shard %[[V10]] to %[[S11]] : tensor<2x4x8xf32>
   %s5 = mesh.sharding @mesh_3d split_axes = [[], [], [0]] partial = sum[1, 2] : !mesh.sharding
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 1952ad79392c7..b786264d84106 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -69,10 +69,10 @@ func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (te
 // -----

 // CHECK-LABEL: matmul
-func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %a_zp: tensor<1xf32>, %b_zp: tensor<1xf32>) -> tensor<1x14x28xf32> {
   // CHECK: profiles: [ [pro_int, pro_fp] ]
   // CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ]
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
   return %0 : tensor<1x14x28xf32>
 }
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 05700ca3765e4..f536444f6379e 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -287,7 +287,7 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso
 // -----

 func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> 
+  %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
   %pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
   // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}}
   %1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<13x21x3xf32>
@@ -1612,3 +1612,43 @@ func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3
   %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
+
+// -----
+// CHECK-LABEL: test_matmul_a_zp_same_element_type
+func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+// expected-error@+1 {{'tosa.matmul' op expect input a and a_zp have the same element type, got 'f32' and 'f16'}}
+%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf16>, tensor<1xf32>) -> tensor<1x14x28xf32>
+  return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_b_zp_same_element_type
+func.func @test_matmul_b_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+// expected-error@+1 {{'tosa.matmul' op expect input b and b_zp have the same element type, got 'f32' and 'f16'}}
+%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf16>) -> tensor<1x14x28xf32>
+  return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_a_zp_non_zero
+func.func @test_matmul_a_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+%azp0 = "tosa.const"() <{values = dense<1.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+// expected-error@+1 {{'tosa.matmul' op a zero point must be zero for non-int8 integer types}}
+%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
+  return %0 : tensor<1x14x28xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul_b_zp_non_zero
+func.func @test_matmul_b_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+%bzp0 = "tosa.const"() <{values = dense<-1.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+// expected-error@+1 {{'tosa.matmul' op b zero point must be zero for non-int8 integer types}}
+%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
+  return %0 : tensor<1x14x28xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index bc13b614e3f9d..6d8237635d0ec 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1110,8 +1110,9 @@ func.func @test_rfft2d_tensor_size_invalid(%arg0: tensor<536870912x8x16xf32>) ->
 // -----

 func.func @test_matmul_tensor_size_invalid(%arg0: tensor<23178x20000x19xf32>, %arg1: tensor<23178x19x28xf32>) -> tensor<23178x20000x28xf32> {
+  %zero = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
   // expected-error@+1 {{'tosa.matmul' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<23178x20000x19xf32>, tensor<23178x19x28xf32>) -> tensor<23178x20000x28xf32>
+  %0 = tosa.matmul %arg0, %arg1, %zero, %zero : (tensor<23178x20000x19xf32>, tensor<23178x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<23178x20000x28xf32>
   return %0 : tensor<23178x20000x28xf32>
 }
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e1ac7d5f51d0e..920d66b00d544 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -145,7 +145,9 @@ func.func @test_fft2d_with_local_bound(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1
 // -----
 // CHECK-LABEL: test_matmul
 func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32>
+%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
   return %0 : tensor<1x14x28xf32>
 }
diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
index 342c57b0dd85c..d0e97e46f1f6a 100644
--- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
@@ -26,9 +26,9 @@ func.func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>, %arg1: tensor<1xf32>, %ar
 }

 // -----
-func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %arg2: tensor<1xf32>) -> tensor<1x14x28xf32> {
   // expected-error@+1 {{'tosa.matmul' op illegal: requires [pro_fp] but not enabled in target}}
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32>
+  %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2: (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
   return %0 : tensor<1x14x28xf32>
 }
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 3dd0344e3647d..28c7abdeaf7f7 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -19,9 +19,9 @@ func.func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>, %arg1: tensor<1xf32>, %ar
 }

 // -----
-func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
+func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>, %arg2: tensor<1xf32>) -> tensor<1x14x28xf32> {
   // expected-error@+1 {{'tosa.matmul' op illegal: requires [pro_fp] but not enabled in target}}
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32>
+  %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32>
   return %0 : tensor<1x14x28xf32>
 }
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 55c5c3f6bdfb6..deede4b0afadc 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -279,8 +279,10 @@ func.func @test_dynamic_argmax(%arg0 : tensor<2x?xi32>) -> () {

 // CHECK-LABEL: @test_static_matmul
 func.func @test_static_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor<2x4x5xi32>) -> () {
-  // CHECK: tosa.matmul %arg0, %arg1 : (tensor<2x3x4xi32>, tensor<2x4x5xi32>) -> tensor<2x3x5xi32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<2x3x4xi32>, tensor<2x4x5xi32>) -> tensor<?x?x?xi32>
+  // CHECK: tosa.matmul %arg0, %arg1, %0, %1 : (tensor<2x3x4xi32>, tensor<2x4x5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x3x5xi32>
+  %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %2 = tosa.matmul %arg0, %arg1, %0, %1 : (tensor<2x3x4xi32>, tensor<2x4x5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?x?x?xi32>
   return
 }

@@ -289,8 +291,10 @@ func.func @test_static_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor<2x4x5xi3

 // CHECK-LABEL: @test_dynamic_lhs_matmul
 func.func @test_dynamic_lhs_matmul(%arg0 : tensor<?x?x?xi32>, %arg1 : tensor<2x4x5xi32>) -> () {
-  // CHECK: tosa.matmul %arg0, %arg1 : (tensor<?x?x?xi32>, tensor<2x4x5xi32>) -> tensor<2x?x5xi32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<?x?x?xi32>, tensor<2x4x5xi32>) -> tensor<?x?x?xi32>
+  // CHECK: tosa.matmul %arg0, %arg1, %0, %1 : (tensor<?x?x?xi32>, tensor<2x4x5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x?x5xi32>
+  %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %2 = tosa.matmul %arg0, %arg1, %0, %1 : (tensor<?x?x?xi32>, tensor<2x4x5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?x?x?xi32>
   return
 }

@@ -299,8 +303,10 @@ func.func @test_dynamic_lhs_matmul(%arg0 : tensor<?x?x?xi32>, %arg1 : tensor<2x4

 // CHECK-LABEL: @test_dynamic_rhs_matmul
 func.func @test_dynamic_rhs_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor<?x?x?xi32>) -> () {
-  // CHECK: tosa.matmul %arg0, %arg1 : (tensor<2x3x4xi32>, tensor<?x?x?xi32>) -> tensor<2x3x?xi32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<2x3x4xi32>, tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
+  // CHECK: tosa.matmul %arg0, %arg1, %0, %1 : (tensor<2x3x4xi32>, tensor<?x?x?xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x3x?xi32>
+  %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %2 = tosa.matmul %arg0, %arg1, %0, %1 : (tensor<2x3x4xi32>, tensor<?x?x?xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?x?x?xi32>
   return
 }

@@ -309,8 +315,10 @@ func.func @test_dynamic_rhs_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor<?x?

 // CHECK-LABEL: @test_dynamic_mixed_matmul
 func.func @test_dynamic_mixed_matmul(%arg0 : tensor<?x3x?xi32>, %arg1 : tensor<?x?x5xi32>) -> () {
-  // CHECK: tosa.matmul %arg0, %arg1 : (tensor<?x3x?xi32>, tensor<?x?x5xi32>) -> tensor<?x3x5xi32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<?x3x?xi32>, tensor<?x?x5xi32>) -> tensor<?x?x?xi32>
+  // CHECK: tosa.matmul %arg0, %arg1, %0, %1 : (tensor<?x3x?xi32>, tensor<?x?x5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?x3x5xi32>
+  %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %2 = tosa.matmul %arg0, %arg1, %0, %1 : (tensor<?x3x?xi32>, tensor<?x?x5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?x?x?xi32>
   return
 }

@@ -1405,11 +1413,13 @@ func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index)

 // CHECK-LABEL: test_non_tosa_consumer_still_propagates
 func.func @test_non_tosa_consumer_still_propagates(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) -> tensor<?x?xf32> {
-  // CHECK: tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<1x1x1xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x?x?xf32>
-  %1 = arith.constant dense<[1, 1]> : tensor<2xindex>
-  %2 = tensor.reshape %0(%1) : (tensor<?x?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-  return %2 : tensor<?x?xf32>
+  // CHECK: tosa.matmul %arg0, %arg1, %0, %1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x1xf32>
+  %0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %1 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %2 = tosa.matmul %arg0, %arg1, %0, %1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?xf32>
+  %3 = arith.constant dense<[1, 1]> : tensor<2xindex>
+  %4 = tensor.reshape %2(%3) : (tensor<?x?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+  return %4 : tensor<?x?xf32>
 }

 // -----