Conversation

@Jerry-Ge (Member)

Removed out_shape from transpose_conv2d to match the latest TOSA v1.0 specification (https://www.mlplatform.org/tosa/tosa_spec.html#_transpose_conv2d).

@llvmbot (Member) commented Feb 27, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Jerry-Ge (Jerry-Ge)

Changes

Removed out_shape from transpose_conv2d to match the latest TOSA v1.0 specification (https://www.mlplatform.org/tosa/tosa_spec.html#_transpose_conv2d).
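
For illustration, here is the op's printed form before and after this change, adapted from the @transpose_conv2d_static test updated in the diff below. The only difference is the dropped out_shape attribute; the output shape is now carried by the result type and recovered by shape inference:

  // Before: out_shape spelled as an explicit attribute.
  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4
         {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>,
          out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>}
       : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>,
          tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>

  // After: out_shape removed; the dimensions are inferred from the
  // input, weight, stride, and out_pad.
  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4
         {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>,
          stride = array<i64: 1, 1>}
       : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>,
          tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>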


Full diff: https://github.com/llvm/llvm-project/pull/129133.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+1-3)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+6-8)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+1-10)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+9-14)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+8-8)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 23692478755c6..ce17ad9362227 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -148,13 +148,11 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
        "::mlir::Value":$weight, "mlir::Value":$bias,
        "::mlir::DenseI64ArrayAttr":$outpad,
        "::mlir::DenseI64ArrayAttr":$stride,
-       "::mlir::DenseI64ArrayAttr":$outputShape,
        "::mlir::TypeAttr":$acc_type),
   [{
     buildTransConvOpWithQuantInfo($_builder, $_state, outputType,
                                   input, weight, bias,
-                                  outpad, stride,
-                                  outputShape, acc_type);
+                                  outpad, stride, acc_type);
   }]>;
 
 // The tosa.matmul op is also intended to be generated where a fully_connected
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ddfec2c9bfcd3..a9bf18f4af6b9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -408,7 +408,6 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
 
     Tosa_IntArrayAttr4:$out_pad,
     Tosa_IntArrayAttr2:$stride,
-    Tosa_IntArrayAttr4:$out_shape,
     TypeAttrOf<Tosa_AccType>:$acc_type,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7b50eceb081dd..54f9fa917f2e0 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -569,15 +569,15 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
 
 /// Handles tosa.transpose_conv2d which has outpad and output shape
 /// attributes.
-static void buildTransConvOpWithQuantInfo(
-    OpBuilder &builder, OperationState &result, Type outputType, Value input,
-    Value weight, Value bias, DenseI64ArrayAttr outpad,
-    DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
+static void
+buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
+                              Type outputType, Value input, Value weight,
+                              Value bias, DenseI64ArrayAttr outpad,
+                              DenseI64ArrayAttr stride, TypeAttr accType) {
   auto zps = createZPsAsConst(builder, input, weight);
   result.addOperands({input, weight, bias, zps.first, zps.second});
   result.addAttribute("out_pad", outpad);
   result.addAttribute("stride", stride);
-  result.addAttribute("out_shape", outputShape);
   result.addAttribute("acc_type", accType);
   Type finalOutputType = outputType;
   auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
@@ -2327,9 +2327,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TransposeConv2DOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  // outputShape is mutable.
-  llvm::SmallVector<int64_t> outputShape =
-      convertToMlirShape(adaptor.getOutShape());
+  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
 
   int64_t inputWidth = ShapedType::kDynamic;
   int64_t inputHeight = ShapedType::kDynamic;
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 123c65e1b4fcd..5b928a2489eea 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -168,7 +168,7 @@ func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tens
 func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1x1x8xi8>, %arg2: tensor<16xi8>) -> tensor<1x32x32x16xi8> {
   %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32}}
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16xi8>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16xi8>
   return %0 : tensor<1x32x32x16xi8>
 }
 
@@ -741,15 +741,6 @@ func.func @test_table_io_shape_mismatch(%arg0: tensor<?x16xi16>, %arg1: tensor<6
 
 // -----
 
-// CHECK-LABEL: test_transpose_conv2d_invalid_outshape
-func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
-  // expected-error@+1 {{'tosa.transpose_conv2d' op attribute 'out_shape' failed to satisfy constraint: i64 dense array attribute with exactly 4 elements}}
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
-  return %0 : tensor<1x32x32x16xf32>
-}
-
-// -----
-
 // CHECK-LABEL: test_mul_type_mismatch
 func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf16>) -> tensor<13x21x3xf32> {
   %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index bb3c16cf52d63..7235c52c5dd05 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -30,21 +30,17 @@ func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor
 
 // -----
 
-// CHECK-LABEL: @transpose_conv2d_quantized_padded
 func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x21x26x5xi32>) {
-  // CHECK-DAG: %[[INPUT_ZP:.+]]  = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}
-  // CHECK-DAG: %[[WEIGHT_ZP:.+]]  = "tosa.const"() <{value = dense<42> : tensor<1xi8>}
-  // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %2 {axis = 2 : i32}
-  // CHECK-DAG: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
-  // CHECK: tosa.conv2d %arg0, %3, %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]]
-  // CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 3, 4, 8, 9>,
-  // CHECK-SAME: stride = array<i64: 1, 1>}
-  %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
-  %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+  // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
+  // CHECK-DAG: %[[REV1:.+]] = tosa.reverse %[[REV0]] {axis = 2 : i32}
+  // CHECK: tosa.conv2d %arg0, %[[REV1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 3, 4, 8, 9>, stride = array<i64: 1, 1>}
+  %input_zp = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() <{value = dense<42> : tensor<1xi8>}> : () -> tensor<1xi8>
   %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {
     acc_type = i32,
     out_pad = array<i64: 1, 2, 3, 4>,
-    out_shape = array<i64: -1, -1, -1, -1>,
     stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x21x26x5xi32>
   return %0 : tensor<2x21x26x5xi32>
 }
@@ -160,12 +156,11 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
   // CHECK: %[[PAD_RESULT:.+]] = tosa.pad %[[RESHAPE_RESULT_1]], %[[RESULT_PAD]]
   // CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2, %[[CONST10]]
   // CHECK: %[[ADD:.+]] = tosa.add %[[PAD_RESULT]], %[[RESHAPE_ARG2]]
-  %input_zp = "tosa.const"() {value = dense<-103> : tensor<1xi8>} : () -> tensor<1xi8>
-  %weight_zp = "tosa.const"() {value = dense<93> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() <{value = dense<-103> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() <{value = dense<93> : tensor<1xi8>}> : () -> tensor<1xi8>
   %2 =  tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {
     acc_type = i32,
     out_pad = array<i64: 2, 0, 0, 1>,
-    out_shape = array<i64: 1, -1, -1, 1>,
     stride = array<i64: 1, 2>} :
     (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32>
   "func.return" (%2) : (tensor<1x19x2x1xi32>) -> ()
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index b87e9a78bf144..8a3dbfe17d686 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -907,7 +907,7 @@ func.func @depthwise_conv2d_strided(%arg0: tensor<1x13x14x1xf32>, %arg1: tensor<
 // CHECK-LABEL: @transpose_conv2d_out_shape
 func.func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
   // CHECK: -> tensor<2x8x9x5xf32>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, 8, 9, -1>, stride = array<i64: 1, 1>} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x8x9x5xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x8x9x5xf32>
   return
 }
 
@@ -916,7 +916,7 @@ func.func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<
 // CHECK-LABEL: @transpose_conv2d_static
 func.func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
   // CHECK: -> tensor<2x18x19x5xf32>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
   return
 }
 
@@ -925,7 +925,7 @@ func.func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5
 // CHECK-LABEL: @transpose_conv2d_static_strided
 func.func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
   // CHECK: -> tensor<2x33x45x5xf32>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
   return
 }
 
@@ -934,7 +934,7 @@ func.func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1:
 // CHECK-LABEL: @transpose_conv2d_dynamic_input
 func.func @transpose_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
   // CHECK: -> tensor<?x?x?x5xf32>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x5xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x5xf32>
   return
 }
 
@@ -943,7 +943,7 @@ func.func @transpose_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: ten
 // CHECK-LABEL: @transpose_conv2d_dynamic_weights
 func.func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
   // CHECK: -> tensor<2x?x?x5xf32>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x?x5xf32>
   return
 }
 
@@ -952,7 +952,7 @@ func.func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: t
 // CHECK-LABEL: @transpose_conv2d_dynamic_bias
 func.func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<?xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
   // CHECK: -> tensor<2x8x9x5xf32>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x8x9x5xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x8x9x5xf32>
   return
 }
 
@@ -961,14 +961,14 @@ func.func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tens
 // CHECK-LABEL: @transpose_conv2d_padded
 func.func @transpose_conv2d_padded(%arg0: tensor<2x9x11x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
   // CHECK: -> tensor<2x10x13x5xf32>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 1, 0, 3, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x10x13x5xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 1, 0, 3, 0>, stride = array<i64: 1, 1>} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x10x13x5xf32>
   return
 }
 
 // CHECK-LABEL: @transpose_conv2d_strided
 func.func @transpose_conv2d_strided(%arg0: tensor<1x5x7x1xf32>, %arg1: tensor<1x1x1x1xf32>, %arg2: tensor<1xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) {
   // CHECK: -> tensor<1x13x13x1xf32>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 3, 2>} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 3, 2>} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32>
   return
 }
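
With the attribute removed, the tosa-infer-shapes tests above exercise the spec's shape derivation directly rather than reading out_shape. As a sanity check, the expected result types follow the TOSA v1.0 output-size formula (my paraphrase of the spec pseudocode, not a quote of the MLIR implementation):

  OH = (IH - 1) * stride_y + out_pad_top + out_pad_bottom + KH
  OW = (IW - 1) * stride_x + out_pad_left + out_pad_right + KW

  @transpose_conv2d_static (input 2x16x14x3, weight 5x3x6x3 as OC x KH x KW x IC, stride 1x1):
    OH = (16 - 1) * 1 + 0 + 0 + 3 = 18
    OW = (14 - 1) * 1 + 0 + 0 + 6 = 19   ->  tensor<2x18x19x5xf32>

  @transpose_conv2d_static_strided (same operands, stride 2x3):
    OH = (16 - 1) * 2 + 0 + 0 + 3 = 33
    OW = (14 - 1) * 3 + 0 + 0 + 6 = 45   ->  tensor<2x33x45x5xf32>

Both match the CHECK lines in the updated tests.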
 

@lhutton1 (Contributor) left a comment

LGTM!

@GeorgeARM (Contributor) left a comment

LGTM as well. There's a single comment from @lhutton1 to be addressed.

Signed-off-by: Suraj Sudhir <[email protected]>
Change-Id: I654a2489572859dafc0d0928cad8b4086ef1ba30
@GeorgeARM merged commit f5749e7 into llvm:main Feb 28, 2025
11 checks passed
cheezeburglar pushed a commit to cheezeburglar/llvm-project that referenced this pull request Feb 28, 2025
@Jerry-Ge deleted the remove_outshape branch March 20, 2025 21:41