Skip to content

Commit def8ecb

Browse files
authored
[tosa] : Relax dynamic dimension checks on the batch dimension for conv decompositions (#168764)
This PR relaxes the validation checks to allow input/output data to have dynamic batch dimensions.
1 parent 2c3aa92 commit def8ecb

File tree

4 files changed

+65
-6
lines changed

4 files changed

+65
-6
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,13 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
3333
ShapedType weightType = cast<ShapedType>(weight.getType());
3434
ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
3535

36-
if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
37-
resultType.hasStaticShape())) {
36+
// Any dimensions other than batchSize cannot be dynamic for input/output
37+
for (unsigned int i = 1; i < 4; ++i) {
38+
if (inputType.isDynamicDim(i) || resultType.isDynamicDim(i))
39+
return failure();
40+
}
41+
42+
if (!weightType.hasStaticShape()) {
3843
return failure();
3944
}
4045

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,13 @@ class TransposeConvNonStridedConverter
4949
if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
5050
return failure();
5151

52-
if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
53-
!biasTy.hasStaticShape() || !resultTy.hasStaticShape())
52+
// Any dimensions other than batchSize cannot be dynamic for input/output
53+
for (unsigned int i = 1; i < 4; ++i) {
54+
if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
55+
return failure();
56+
}
57+
58+
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
5459
return failure();
5560

5661
int64_t kernelHeight = weightTy.getDimSize(1);
@@ -113,8 +118,13 @@ class TransposeConvStridedConverter
113118
if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
114119
return rewriter.notifyMatchFailure(op, "non-one stride found.");
115120

116-
if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
117-
!biasTy.hasStaticShape() || !resultTy.hasStaticShape())
121+
// Any dimensions other than batchSize cannot be dynamic for input/output
122+
for (unsigned int i = 1; i < 4; ++i) {
123+
if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
124+
return failure();
125+
}
126+
127+
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
118128
return failure();
119129

120130
int64_t batch = inputTy.getDimSize(0);

mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,26 @@ func.func @depthwise_conv2d_no_const_zero_point(%arg0: tensor<4x10x10x2xi8>, %ar
9898
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6xi32>
9999
return %0 : tensor<4x10x10x6xi32>
100100
}
101+
102+
// -----
103+
// CHECK-LABEL: func.func @depthwise_conv2d_as_mul_dynamic_batch_bias(
104+
// CHECK-SAME: %[[INP:.*]]: tensor<?x10x10x2xf32>,
105+
// CHECK-SAME: %[[WTS:.*]]: tensor<1x1x2x3xf32>,
106+
// CHECK-SAME: %[[BIAS:.*]]: tensor<?xf32>) -> tensor<?x10x10x6xf32> {
107+
// CHECK: %[[BIAS_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 1, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
108+
// CHECK: %[[RES_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[-1, 10, 10, 6]> : tensor<4xindex>} : () -> !tosa.shape<4>
109+
// CHECK: %[[MUL_SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
110+
// CHECK: %[[WTS_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>} : () -> !tosa.shape<5>
111+
// CHECK: %[[INP_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[-1, 10, 10, 2, 1]> : tensor<5xindex>} : () -> !tosa.shape<5>
112+
// CHECK: %[[INP_RESHAPED:.*]] = tosa.reshape %[[INP]], %[[INP_EXPANDED_SHAPE]] : (tensor<?x10x10x2xf32>, !tosa.shape<5>) -> tensor<?x10x10x2x1xf32>
113+
// CHECK: %[[WTS_RESHAPED:.*]] = tosa.reshape %[[WTS]], %[[WTS_EXPANDED_SHAPE]] : (tensor<1x1x2x3xf32>, !tosa.shape<5>) -> tensor<1x1x1x2x3xf32>
114+
// CHECK: %[[MUL:.*]] = tosa.mul %[[INP_RESHAPED]], %[[WTS_RESHAPED]], %[[MUL_SHIFT]] : (tensor<?x10x10x2x1xf32>, tensor<1x1x1x2x3xf32>, tensor<1xi8>) -> tensor<?x10x10x2x3xf32>
115+
// CHECK: %[[RES_RESHAPED:.*]] = tosa.reshape %[[MUL]], %[[RES_EXPANDED_SHAPE]] : (tensor<?x10x10x2x3xf32>, !tosa.shape<4>) -> tensor<?x10x10x6xf32>
116+
// CHECK: %[[BIAS_RESHAPED:.*]] = tosa.reshape %[[BIAS]], %[[BIAS_EXPANDED_SHAPE]] : (tensor<?xf32>, !tosa.shape<4>) -> tensor<1x1x1x?xf32>
117+
// CHECK: %[[RES:.*]] = tosa.add %[[RES_RESHAPED]], %[[BIAS_RESHAPED]] : (tensor<?x10x10x6xf32>, tensor<1x1x1x?xf32>) -> tensor<?x10x10x6xf32>
118+
// CHECK: return %[[RES]]
119+
func.func @depthwise_conv2d_as_mul_dynamic_batch_bias(%arg0: tensor<?x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<?xf32>) -> tensor<?x10x10x6xf32> {
120+
%zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
121+
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<?x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x10x10x6xf32>
122+
return %0 : tensor<?x10x10x6xf32>
123+
}

mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,24 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
181181
(tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32>
182182
"func.return" (%2) : (tensor<1x19x2x1xi32>) -> ()
183183
}
184+
185+
186+
// -----
187+
// CHECK-LABEL: @transpose_conv2d_non_strided_dynamic_batch
188+
// CHECK: tosa.conv2d
189+
// CHECK-NOT: tosa.transpose_conv2d
190+
func.func @transpose_conv2d_non_strided_dynamic_batch(%arg0: tensor<?x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<?x18x19x5xf32> {
191+
%zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
192+
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x18x19x5xf32>
193+
return %0 : tensor<?x18x19x5xf32>
194+
}
195+
196+
// -----
197+
// CHECK-LABEL: @transpose_conv2d_strided_dynamic_batch
198+
// CHECK: tosa.conv2d
199+
// CHECK-NOT: tosa.transpose_conv2d
200+
func.func @transpose_conv2d_strided_dynamic_batch(%arg0: tensor<?x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<?x35x47x5xf32> {
201+
%zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
202+
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<?x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x35x47x5xf32>
203+
return %0 : tensor<?x35x47x5xf32>
204+
}

0 commit comments

Comments (0)