-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][tosa] Remove out_shape from transpose_conv2d #129133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Jerry-Ge (Jerry-Ge) ChangesRemoved out_shape from transpose_conv2d to match the latest TOSA v1.0 specification (https://www.mlplatform.org/tosa/tosa_spec.html#_transpose_conv2d). Full diff: https://github.com/llvm/llvm-project/pull/129133.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 23692478755c6..ce17ad9362227 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -148,13 +148,11 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
"::mlir::Value":$weight, "mlir::Value":$bias,
"::mlir::DenseI64ArrayAttr":$outpad,
"::mlir::DenseI64ArrayAttr":$stride,
- "::mlir::DenseI64ArrayAttr":$outputShape,
"::mlir::TypeAttr":$acc_type),
[{
buildTransConvOpWithQuantInfo($_builder, $_state, outputType,
input, weight, bias,
- outpad, stride,
- outputShape, acc_type);
+ outpad, stride, acc_type);
}]>;
// The tosa.matmul op is also intended to be generated where a fully_connected
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ddfec2c9bfcd3..a9bf18f4af6b9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -408,7 +408,6 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
- Tosa_IntArrayAttr4:$out_shape,
TypeAttrOf<Tosa_AccType>:$acc_type,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7b50eceb081dd..54f9fa917f2e0 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -569,15 +569,15 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
/// Handles tosa.transpose_conv2d which has outpad and output shape
/// attributes.
-static void buildTransConvOpWithQuantInfo(
- OpBuilder &builder, OperationState &result, Type outputType, Value input,
- Value weight, Value bias, DenseI64ArrayAttr outpad,
- DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
+static void
+buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
+ Type outputType, Value input, Value weight,
+ Value bias, DenseI64ArrayAttr outpad,
+ DenseI64ArrayAttr stride, TypeAttr accType) {
auto zps = createZPsAsConst(builder, input, weight);
result.addOperands({input, weight, bias, zps.first, zps.second});
result.addAttribute("out_pad", outpad);
result.addAttribute("stride", stride);
- result.addAttribute("out_shape", outputShape);
result.addAttribute("acc_type", accType);
Type finalOutputType = outputType;
auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
@@ -2327,9 +2327,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TransposeConv2DOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- // outputShape is mutable.
- llvm::SmallVector<int64_t> outputShape =
- convertToMlirShape(adaptor.getOutShape());
+ llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
int64_t inputWidth = ShapedType::kDynamic;
int64_t inputHeight = ShapedType::kDynamic;
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 123c65e1b4fcd..5b928a2489eea 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -168,7 +168,7 @@ func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tens
func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1x1x8xi8>, %arg2: tensor<16xi8>) -> tensor<1x32x32x16xi8> {
%zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32}}
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16xi8>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16xi8>
return %0 : tensor<1x32x32x16xi8>
}
@@ -741,15 +741,6 @@ func.func @test_table_io_shape_mismatch(%arg0: tensor<?x16xi16>, %arg1: tensor<6
// -----
-// CHECK-LABEL: test_transpose_conv2d_invalid_outshape
-func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
- // expected-error@+1 {{'tosa.transpose_conv2d' op attribute 'out_shape' failed to satisfy constraint: i64 dense array attribute with exactly 4 elements}}
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
- return %0 : tensor<1x32x32x16xf32>
-}
-
-// -----
-
// CHECK-LABEL: test_mul_type_mismatch
func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf16>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index bb3c16cf52d63..7235c52c5dd05 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -30,21 +30,17 @@ func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor
// -----
-// CHECK-LABEL: @transpose_conv2d_quantized_padded
func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x21x26x5xi32>) {
- // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}
- // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi8>}
- // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %2 {axis = 2 : i32}
- // CHECK-DAG: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
- // CHECK: tosa.conv2d %arg0, %3, %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]]
- // CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 3, 4, 8, 9>,
- // CHECK-SAME: stride = array<i64: 1, 1>}
- %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
- %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+ // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
+ // CHECK-DAG: %[[REV1:.+]] = tosa.reverse %[[REV0]] {axis = 2 : i32}
+ // CHECK: tosa.conv2d %arg0, %[[REV1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 3, 4, 8, 9>, stride = array<i64: 1, 1>}
+ %input_zp = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() <{value = dense<42> : tensor<1xi8>}> : () -> tensor<1xi8>
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {
acc_type = i32,
out_pad = array<i64: 1, 2, 3, 4>,
- out_shape = array<i64: -1, -1, -1, -1>,
stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x21x26x5xi32>
return %0 : tensor<2x21x26x5xi32>
}
@@ -160,12 +156,11 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
// CHECK: %[[PAD_RESULT:.+]] = tosa.pad %[[RESHAPE_RESULT_1]], %[[RESULT_PAD]]
// CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2, %[[CONST10]]
// CHECK: %[[ADD:.+]] = tosa.add %[[PAD_RESULT]], %[[RESHAPE_ARG2]]
- %input_zp = "tosa.const"() {value = dense<-103> : tensor<1xi8>} : () -> tensor<1xi8>
- %weight_zp = "tosa.const"() {value = dense<93> : tensor<1xi8>} : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() <{value = dense<-103> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %weight_zp = "tosa.const"() <{value = dense<93> : tensor<1xi8>}> : () -> tensor<1xi8>
%2 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {
acc_type = i32,
out_pad = array<i64: 2, 0, 0, 1>,
- out_shape = array<i64: 1, -1, -1, 1>,
stride = array<i64: 1, 2>} :
(tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32>
"func.return" (%2) : (tensor<1x19x2x1xi32>) -> ()
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index b87e9a78bf144..8a3dbfe17d686 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -907,7 +907,7 @@ func.func @depthwise_conv2d_strided(%arg0: tensor<1x13x14x1xf32>, %arg1: tensor<
// CHECK-LABEL: @transpose_conv2d_out_shape
func.func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<2x8x9x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, 8, 9, -1>, stride = array<i64: 1, 1>} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x8x9x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x8x9x5xf32>
return
}
@@ -916,7 +916,7 @@ func.func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<
// CHECK-LABEL: @transpose_conv2d_static
func.func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<2x18x19x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
return
}
@@ -925,7 +925,7 @@ func.func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5
// CHECK-LABEL: @transpose_conv2d_static_strided
func.func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<2x33x45x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
return
}
@@ -934,7 +934,7 @@ func.func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1:
// CHECK-LABEL: @transpose_conv2d_dynamic_input
func.func @transpose_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<?x?x?x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x5xf32>
return
}
@@ -943,7 +943,7 @@ func.func @transpose_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: ten
// CHECK-LABEL: @transpose_conv2d_dynamic_weights
func.func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<2x?x?x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
return
}
@@ -952,7 +952,7 @@ func.func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: t
// CHECK-LABEL: @transpose_conv2d_dynamic_bias
func.func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<?xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<2x8x9x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x8x9x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x8x9x5xf32>
return
}
@@ -961,14 +961,14 @@ func.func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tens
// CHECK-LABEL: @transpose_conv2d_padded
func.func @transpose_conv2d_padded(%arg0: tensor<2x9x11x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<2x10x13x5xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 1, 0, 3, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x10x13x5xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 1, 0, 3, 0>, stride = array<i64: 1, 1>} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x10x13x5xf32>
return
}
// CHECK-LABEL: @transpose_conv2d_strided
func.func @transpose_conv2d_strided(%arg0: tensor<1x5x7x1xf32>, %arg1: tensor<1x1x1x1xf32>, %arg2: tensor<1xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
// CHECK: -> tensor<1x13x13x1xf32>
- %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 3, 2>} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 3, 2>} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32>
return
}
|
lhutton1
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
GeorgeARM
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM as well. Single comment from @lhutton1 to be addressed
Signed-off-by: Suraj Sudhir <[email protected]> Change-Id: I654a2489572859dafc0d0928cad8b4086ef1ba30
9bad506 to
afbe504
Compare
Removed out_shape from transpose_conv2d to match the latest TOSA v1.0 specification (https://www.mlplatform.org/tosa/tosa_spec.html#_transpose_conv2d).