Commit b5e9386

Check for static shapes in parallel conv recomposition

Signed-off-by: Rickert, Jonas <[email protected]>

1 parent: f261660
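Summary: CombineParallelConv2DPattern now refuses to fire unless every type it later relies on has a static shape. The output type of the seed convolution, and the weight and output types of every parallel candidate, are guarded with notifyMatchFailure; the loop that collects weight values and sums output channels is hoisted from the rewrite stage into these early checks, so the pattern bails out before it starts building the combined weights. A new lit test pins down the behavior for an unranked weight and a dynamically shaped convolution result.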

2 files changed: +38 −8 lines
src/Dialect/ONNX/Transforms/Recompose.cpp

Lines changed: 18 additions & 8 deletions
@@ -1057,6 +1057,10 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
       return rewriter.notifyMatchFailure(
           convOp1, "input must be a ranked tensor with static shape");
 
+    if (!cast<ShapedType>(convOp1.getType()).hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          convOp1, "output type must be a ranked tensor with static shape");
+
     // Collect all ONNXConvOps using this input.
     SmallVector<ONNXConvOp> candidateConvs;
     for (auto user : input.getUsers()) {
@@ -1084,6 +1088,20 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
 
     SmallVector<ONNXConvOp> parallelConvs = candidateConvs;
 
+    SmallVector<Value> weightValues;
+    int64_t totalOutputChannels = 0;
+    for (auto conv : parallelConvs) {
+      auto weightType = mlir::cast<ShapedType>(conv.getW().getType());
+      if (!weightType.hasStaticShape())
+        return rewriter.notifyMatchFailure(
+            conv, "weight must be a ranked tensor with static shape");
+      if (!cast<ShapedType>(conv.getType()).hasStaticShape())
+        return rewriter.notifyMatchFailure(
+            conv, "output type must be a ranked tensor with static shape");
+      weightValues.push_back(conv.getW());
+      totalOutputChannels += weightType.getShape()[0];
+    }
+
     auto *latestConv =
         llvm::max_element(parallelConvs, [](ONNXConvOp a, ONNXConvOp b) {
           return a->isBeforeInBlock(b.getOperation());
@@ -1132,14 +1150,6 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
 
     int64_t concatAxis = 1;
 
-    SmallVector<Value> weightValues;
-    int64_t totalOutputChannels = 0;
-    for (auto conv : parallelConvs) {
-      auto weightType = mlir::cast<ShapedType>(conv.getW().getType());
-      weightValues.push_back(conv.getW());
-      totalOutputChannels += weightType.getShape()[0];
-    }
-
     auto firstWeightType =
         mlir::cast<ShapedType>(parallelConvs[0].getW().getType());
     SmallVector<int64_t> newWeightShape(
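
Context for the guards (not part of the commit): ShapedType::getShape() encodes a dynamic dimension as the ShapedType::kDynamic sentinel rather than a real extent, so summing getShape()[0] over weights that are not statically shaped would silently fold the sentinel into totalOutputChannels. Below is a minimal standalone sketch of the guarded accumulation, assuming a recent MLIR API; the tensor shapes and the file name are hypothetical.

// guard_sketch.cpp -- illustrative only, not code from this commit.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::Type f32 = mlir::Float32Type::get(&ctx);

  // Two hypothetical conv weight types: one fully static, one whose
  // output-channel dimension is dynamic (prints as tensor<?x1x3x3xf32>).
  llvm::SmallVector<mlir::RankedTensorType> weightTypes = {
      mlir::RankedTensorType::get({32, 1, 3, 3}, f32),
      mlir::RankedTensorType::get({mlir::ShapedType::kDynamic, 1, 3, 3}, f32)};

  int64_t totalOutputChannels = 0;
  for (mlir::RankedTensorType weightType : weightTypes) {
    // Without this check, getShape()[0] would yield the kDynamic
    // sentinel (a negative value), not a usable channel count; the
    // pattern above returns notifyMatchFailure at this point instead.
    if (!weightType.hasStaticShape())
      continue;
    totalOutputChannels += weightType.getShape()[0];
  }
  llvm::outs() << "total static output channels: " << totalOutputChannels
               << "\n"; // prints 32: the dynamic weight was skipped
  return 0;
}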

test/mlir/onnx/onnx_recompose_combine_parallel_conv.mlir

Lines changed: 20 additions & 0 deletions
@@ -141,4 +141,24 @@ func.func @test_conv_concat_dependency(%arg0: tensor<1x1x512x512xf32>) -> tensor
 // CHECK:           [[VAR_6_:%.+]] = "onnx.Concat"([[VAR_3_]], [[VAR_5_]]) {axis = 1 : si64, onnx_node_name = "onnx.Concat_11"} : (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) -> tensor<1x64x512x512xf32>
 // CHECK:           return [[VAR_6_]] : tensor<1x64x512x512xf32>
 // CHECK:         }
+}
+
+func.func @test_conv_concat_not_static_shape(%arg0: tensor<1x1x512x512xf32>, %0: tensor<*xf32>) -> tensor<1x64x512x512xf32> {
+  %1 = onnx.Constant dense<0.00999999977> : tensor<32xf32>
+  %2 = onnx.Constant dense<0.00999999977> : tensor<32x1x3x3xf32>
+  %3 = onnx.Constant dense<0.00999999977> : tensor<32xf32>
+  %4 = "onnx.Conv"(%arg0, %0, %1) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<*xf32>, tensor<32xf32>) -> tensor<1x32x512x512xf32>
+  %5 = "onnx.Conv"(%arg0, %2, %3) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>) -> tensor<1x?x512x512xf32>
+  %6 = "onnx.Concat"(%4, %5) {axis = 1 : si64} : (tensor<1x32x512x512xf32>, tensor<1x?x512x512xf32>) -> tensor<1x64x512x512xf32>
+  return %6 : tensor<1x64x512x512xf32>
+
+// CHECK-LABEL:  func.func @test_conv_concat_not_static_shape
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<1x1x512x512xf32>, [[PARAM_1_:%.+]]: tensor<*xf32>) -> tensor<1x64x512x512xf32> {
+// CHECK-DAG:    [[VAR_0_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32xf32>
+// CHECK-DAG:    [[VAR_1_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32x1x3x3xf32>
+// CHECK-DAG:    [[VAR_2_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {auto_pad = "NOTSET", group = 1 : si64, onnx_node_name = "onnx.Conv_12", pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<*xf32>, tensor<32xf32>) -> tensor<1x32x512x512xf32>
+// CHECK-DAG:    [[VAR_3_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_1_]], [[VAR_0_]]) {auto_pad = "NOTSET", group = 1 : si64, onnx_node_name = "onnx.Conv_13", pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>) -> tensor<1x32x512x512xf32>
+// CHECK:        [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_3_]]) {axis = 1 : si64, onnx_node_name = "onnx.Concat_14"} : (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) -> tensor<1x64x512x512xf32>
+// CHECK:        return [[VAR_4_]] : tensor<1x64x512x512xf32>
+// CHECK:        }
 }
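
The new test drives both failure paths added in Recompose.cpp: %4 consumes the unranked weight %0 (tensor<*xf32>), and %5 has a dynamically shaped result (tensor<1x?x512x512xf32>). The CHECK lines confirm that the two onnx.Conv ops and the trailing onnx.Concat survive the recomposition pass unfused.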
