Skip to content

Commit a026655

Browse files
lhutton1 authored and Tai78641 committed
[TOSA] Switch zero point of avgpool2d to input variable type
This commit changes the avg_pool2d zero points (input_zp and output_zp) from optional attributes to input operands, to align with the TOSA 1.0 spec.

Change-Id: Ieee6ba824327913bc8462cbcb7a74c0b6dd53d21
Signed-off-by: Luke Hutton <[email protected]>
1 parent f409340 commit a026655

File tree

15 files changed

+232
-118
lines changed

15 files changed

+232
-118
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ profileComplianceMap = {
55
{{{Profile::pro_int}, {{i8T, i32T}}},
66
{{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}},
77
{"tosa.avg_pool2d",
8-
{{{Profile::pro_int}, {{i8T, i32T, i8T}}},
8+
{{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}},
99
{{Profile::pro_fp},
10-
{{fp16T, fp16T, fp16T}, {fp16T, fp32T, fp16T}, {fp32T, fp32T, fp32T}}}}},
10+
{{fp16T, fp16T, fp16T, fp16T, fp16T}, {fp16T, fp16T, fp16T, fp32T, fp16T}, {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
1111
{"tosa.conv2d",
1212
{{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
1313
{{Profile::pro_fp},
@@ -243,10 +243,10 @@ extensionComplianceMap = {
243243
{{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}},
244244
{{Extension::bf16}, {{bf16T, i32T}}}}},
245245
{"tosa.avg_pool2d",
246-
{{{Extension::int16}, {{i16T, i32T, i16T}}},
247-
{{Extension::fp8e4m3}, {{fp8e4m3T, fp16T, fp8e4m3T}}},
248-
{{Extension::fp8e5m2}, {{fp8e5m2T, fp16T, fp8e5m2T}}},
249-
{{Extension::bf16}, {{bf16T, fp32T, bf16T}}}}},
246+
{{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}},
247+
{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}},
248+
{{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}},
249+
{{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
250250
{"tosa.conv2d",
251251
{{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
252252
{{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,12 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
7979

8080
let arguments = (ins
8181
Tosa_Tensor4D:$input,
82+
Tosa_ScalarTensor:$input_zp,
83+
Tosa_ScalarTensor:$output_zp,
8284
Tosa_IntArrayAttr2:$kernel,
8385
Tosa_IntArrayAttr2:$stride,
8486
Tosa_IntArrayAttr4:$pad,
85-
TypeAttrOf<Tosa_AccType>:$acc_type,
86-
OptionalAttr<I32Attr>:$input_zp,
87-
OptionalAttr<I32Attr>:$output_zp
87+
TypeAttrOf<Tosa_AccType>:$acc_type
8888
);
8989

9090
let results = (outs
@@ -97,6 +97,14 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
9797
];
9898

9999
let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];
100+
101+
let extraClassDeclaration = [{
102+
LogicalResult getInputZeroPoint(int64_t &zp);
103+
LogicalResult getOutputZeroPoint(int64_t &zp);
104+
LogicalResult verifyInputZeroPoint(int64_t zp);
105+
LogicalResult verifyOutputZeroPoint(int64_t zp);
106+
}];
107+
100108
let hasVerifier = 1;
101109
}
102110

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,15 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
804804
return failure();
805805
SmallVector<Value> dynamicDims = *dynamicDimsOr;
806806

807+
int64_t inputZpVal;
808+
int64_t outputZpVal;
809+
if (op.getInputZeroPoint(inputZpVal).failed() ||
810+
op.getOutputZeroPoint(outputZpVal).failed()) {
811+
(void)rewriter.notifyMatchFailure(
812+
op, "zero points could not be statically determined");
813+
return failure();
814+
}
815+
807816
// Apply padding as necessary.
808817
llvm::SmallVector<int64_t> pad;
809818
pad.resize(2, 0);
@@ -923,9 +932,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
923932

924933
// If we have quantization information we need to apply an offset
925934
// for the input zp value.
926-
if (op.getInputZp()) {
927-
auto inputZp =
928-
rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
935+
if (inputZpVal != 0) {
936+
auto inputZp = rewriter.create<arith::ConstantOp>(
937+
loc, b.getIntegerAttr(accETy, inputZpVal));
929938
Value offset =
930939
rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
931940
poolVal =
@@ -977,9 +986,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
977986

978987
// If we have quantization information we need to apply output
979988
// zeropoint.
980-
if (op.getOutputZp()) {
981-
auto outputZp =
982-
rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
989+
if (outputZpVal != 0) {
990+
auto outputZp = rewriter.create<arith::ConstantOp>(
991+
loc, b.getIntegerAttr(scaled.getType(), outputZpVal));
983992
scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
984993
.getResult();
985994
}

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 64 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -455,18 +455,10 @@ LogicalResult tosa::ArgMaxOp::verify() {
455455
}
456456

457457
LogicalResult tosa::AvgPool2dOp::verify() {
458-
auto inputType = llvm::cast<ShapedType>(getInput().getType());
459-
460-
auto inputETy = inputType.getElementType();
461-
auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
462-
463-
if (auto quantType =
464-
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
465-
inputETy = quantType.getStorageType();
466-
467-
if (auto quantType =
468-
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
469-
resultETy = quantType.getStorageType();
458+
const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
459+
const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
460+
const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
461+
const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
470462

471463
auto accType = getAccType();
472464
if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
@@ -481,6 +473,28 @@ LogicalResult tosa::AvgPool2dOp::verify() {
481473
if (inputETy.isF32() && !accType.isF32())
482474
return emitOpError("accumulator type for f32 tensor is not f32");
483475

476+
if (inputETy != inputZpETy)
477+
return emitOpError("expect both input and its zero point are the same "
478+
"element type, got ")
479+
<< inputETy << " and " << inputZpETy;
480+
481+
if (resultETy != outputZpETy)
482+
return emitOpError("expect both output and its zero point are the same "
483+
"element type, got ")
484+
<< resultETy << " and " << outputZpETy;
485+
486+
int64_t inputZpVal;
487+
if (getInputZeroPoint(inputZpVal).succeeded() &&
488+
verifyInputZeroPoint(inputZpVal).failed())
489+
return emitOpError(
490+
"input zero point must be zero for non-int8 integer types");
491+
492+
int64_t outputZpVal;
493+
if (getOutputZeroPoint(outputZpVal).succeeded() &&
494+
verifyOutputZeroPoint(outputZpVal).failed())
495+
return emitOpError(
496+
"output zero point must be zero for non-int8 integer types");
497+
484498
if ((inputETy.isF32() && resultETy.isF32()) ||
485499
(inputETy.isF16() && resultETy.isF16()) ||
486500
(inputETy.isBF16() && resultETy.isBF16()) ||
@@ -629,27 +643,37 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
629643
}
630644

631645
/// Both the tosa.avg_pool2d and unary ops use the same
632-
/// UnaruOpQuantizationAttr but avg_pool operator has its own builder as it
646+
/// UnaryOpQuantizationAttr but avg_pool operator has its own builder as it
633647
/// has additional parameters not part of the unary ops.
634648
static void
635649
buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
636650
Type outputType, Value input,
637651
DenseArrayAttr kernel, DenseArrayAttr stride,
638652
DenseArrayAttr pad, TypeAttr accType) {
639-
result.addOperands(input);
653+
const Location loc{result.location};
654+
int64_t inputZp{0};
655+
int64_t outputZp{0};
656+
657+
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
658+
if (quantAttr) {
659+
inputZp = quantAttr.getInputZp();
660+
outputZp = quantAttr.getOutputZp();
661+
}
662+
const std::optional<Value> inputZpOp =
663+
createZeroPointTensor(builder, loc, input.getType(), inputZp);
664+
assert(
665+
inputZpOp.has_value() &&
666+
"Failed to create input zero point tensor for quantized AVG_POOL2D op");
667+
const std::optional<Value> outputZpOp =
668+
createZeroPointTensor(builder, loc, outputType, outputZp);
669+
assert(
670+
outputZpOp.has_value() &&
671+
"Failed to create output zero point tensor for quantized AVG_POOL2D op");
672+
result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
640673
result.addAttribute("kernel", kernel);
641674
result.addAttribute("stride", stride);
642675
result.addAttribute("pad", pad);
643676
result.addAttribute("acc_type", accType);
644-
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
645-
if (quantAttr) {
646-
result.addAttribute("input_zp",
647-
builder.getI32IntegerAttr(
648-
static_cast<int32_t>(quantAttr.getInputZp())));
649-
result.addAttribute("output_zp",
650-
builder.getI32IntegerAttr(
651-
static_cast<int32_t>(quantAttr.getOutputZp())));
652-
}
653677
result.types.push_back(outputType);
654678
}
655679

@@ -1425,13 +1449,6 @@ static LogicalResult getZeroPoint(T op, Value val, int64_t &zp) {
14251449

14261450
template <typename T>
14271451
static LogicalResult verifyZeroPoint(T op, Value val, int64_t &zp) {
1428-
// TODO clean it up when the entire zero point (attribute -> input tensor
1429-
// type) change is done. Remaining Matmul, Rescale, Negate, and AvgPool2D.
1430-
if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
1431-
!std::is_same_v<T, DepthwiseConv2DOp> &&
1432-
!std::is_same_v<T, TransposeConv2DOp>)
1433-
return failure();
1434-
14351452
Type zpElemType = getElementTypeOrSelf(val);
14361453

14371454
if (!zpElemType.isIntOrFloat())
@@ -1446,24 +1463,24 @@ static LogicalResult verifyZeroPoint(T op, Value val, int64_t &zp) {
14461463
return success();
14471464
}
14481465

1449-
#define ZERO_POINT_HELPER(OP) \
1450-
LogicalResult tosa::OP::getInputZeroPoint(int64_t &zp) { \
1451-
return getZeroPoint(*this, getInputZp(), zp); \
1466+
#define ZERO_POINT_HELPER(OP, OPERAND_NAME) \
1467+
LogicalResult tosa::OP::get##OPERAND_NAME##ZeroPoint(int64_t &zp) { \
1468+
return getZeroPoint(*this, get##OPERAND_NAME##Zp(), zp); \
14521469
} \
1453-
LogicalResult tosa::OP::getWeightZeroPoint(int64_t &zp) { \
1454-
return getZeroPoint(*this, getWeightZp(), zp); \
1455-
} \
1456-
LogicalResult tosa::OP::verifyInputZeroPoint(int64_t zp) { \
1457-
return verifyZeroPoint(*this, getInputZp(), zp); \
1458-
} \
1459-
LogicalResult tosa::OP::verifyWeightZeroPoint(int64_t zp) { \
1460-
return verifyZeroPoint(*this, getWeightZp(), zp); \
1461-
}
1462-
1463-
ZERO_POINT_HELPER(Conv2DOp)
1464-
ZERO_POINT_HELPER(Conv3DOp)
1465-
ZERO_POINT_HELPER(DepthwiseConv2DOp)
1466-
ZERO_POINT_HELPER(TransposeConv2DOp)
1470+
LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
1471+
return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp); \
1472+
}
1473+
1474+
ZERO_POINT_HELPER(Conv2DOp, Input)
1475+
ZERO_POINT_HELPER(Conv2DOp, Weight)
1476+
ZERO_POINT_HELPER(Conv3DOp, Input)
1477+
ZERO_POINT_HELPER(Conv3DOp, Weight)
1478+
ZERO_POINT_HELPER(DepthwiseConv2DOp, Input)
1479+
ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight)
1480+
ZERO_POINT_HELPER(TransposeConv2DOp, Input)
1481+
ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
1482+
ZERO_POINT_HELPER(AvgPool2dOp, Input)
1483+
ZERO_POINT_HELPER(AvgPool2dOp, Output)
14671484
#undef ZERO_POINT_HELPER
14681485

14691486
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ void ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
5858
template <>
5959
void ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
6060
addValue(op.getInput());
61+
addValue(op.getInputZp());
62+
addValue(op.getOutputZp());
6163
addType(op.getAccType());
6264
addValue(op.getOutput());
6365
}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -verify-diagnostics
22

33
// CHECK-LABEL: @avg_pool2d_with_unsupported_quant_type
4-
func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
4+
func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
55
// expected-error@+1 {{failed to legalize operation 'tosa.avg_pool2d'}}
6-
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
6+
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
77
return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
88
}
99

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,9 @@ func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
290290
// CHECK: %[[FLT:.+]] = arith.sitofp %[[CAST]]
291291
// CHECK: %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
292292
// CHECK: linalg.yield %[[DIV]]
293-
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>) -> tensor<1x5x33x62xf32>
293+
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
294+
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
295+
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x33x62xf32>
294296
return %0 : tensor<1x5x33x62xf32>
295297
}
296298

@@ -375,7 +377,9 @@ func.func @avg_pool_f16_f32acc(%arg0: tensor<1x6x34x62xf16>) -> (tensor<1x5x33x6
375377
// CHECK: %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
376378
// CHECK: %[[TRUNC:.+]] = arith.truncf %[[DIV]]
377379
// CHECK: linalg.yield %[[TRUNC]]
378-
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf16>) -> tensor<1x5x33x62xf16>
380+
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
381+
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
382+
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x5x33x62xf16>
379383
return %0 : tensor<1x5x33x62xf16>
380384
}
381385

@@ -416,7 +420,9 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
416420
// CHECK: %[[CLAMP:.+]] = arith.minsi %[[CMAX]], %[[LOW]]
417421
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLAMP]]
418422
// CHECK: linalg.yield %[[TRUNC]]
419-
%0 = tosa.avg_pool2d %arg0 {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>) -> tensor<1x5x33x62xi8>
423+
%input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
424+
%output_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
425+
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x33x62xi8>
420426
return %0 : tensor<1x5x33x62xi8>
421427
}
422428

@@ -439,7 +445,9 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
439445
// CHECK-SAME: outs(%[[FILL]] : tensor<?x5x33x62xf32>) -> tensor<?x5x33x62xf32>
440446
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x5x33x62xf32>
441447
// CHECK: %[[GENERIC:.+]] = linalg.generic
442-
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>) -> tensor<?x5x33x62xf32>
448+
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
449+
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
450+
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x5x33x62xf32>
443451
return %0 : tensor<?x5x33x62xf32>
444452
}
445453

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,18 @@ func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
2323
// -----
2424

2525
// check that tosa verify kick in
26-
func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
26+
func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x7x7x9xf32> {
2727
// expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}}
28-
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
29-
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
28+
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
29+
: (tensor<1x0x?x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
3030
return %0 : tensor<1x7x7x9xf32>
3131
}
3232

3333
// -----
3434

3535
// check that --tosa-to-linalg kick in
36-
func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
36+
func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
3737
// expected-error@+1 {{failed to legalize operation 'tosa.avg_pool2d'}}
38-
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
38+
%0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
3939
return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
4040
}

mlir/test/Dialect/Tosa/availability.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ func.func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> {
1919
func.func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
2020
// CHECK: profiles: [ [pro_int, pro_fp] ]
2121
// CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ]
22-
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
22+
%input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
23+
%output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
24+
%0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
2325
return %0 : tensor<1x7x7x9xf32>
2426
}
2527

0 commit comments

Comments
 (0)