
Commit e4cf47f

Merge pull request #446 from Xilinx/chaitany.fix_convtranpose_4x4_mismatch
Fixing a mismatch in the 4x4 kernel use case.
2 parents: 8798e80 + de4f902

File tree

2 files changed: +48, -44 lines

src/Dialect/ONNX/Transforms/Decompose.cpp
Lines changed: 39 additions & 37 deletions

@@ -1648,7 +1648,7 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
   };
   auto stridesArrayAttr = rewriter.getI64ArrayAttr({1, 1});
   Value conv;
-  if (needWeightsPadding) {
+  if (needWeightsPadding || (kernelShape[0] == 4)) {
     Value conv1 = getActivationAppliedToConv(
         addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
             convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[3]),
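
With this change, the 4x4-kernel case takes the phased-conv path even when needWeightsPadding is false. The intuition, sketched below as standalone illustrative C++ (the pass builds its weightSlices elsewhere; the strided sub-sampling shown here is an assumption about that step): sub-sampling a KxK deconvolution kernel by the stride S yields S*S phase kernels of size (K/S)x(K/S), and K = 4 with S = 2 divides evenly, so every tap lands in exactly one phase and no weight padding is needed.

#include <cstdio>

int main() {
  constexpr int K = 4, S = 2; // 4x4 kernel, stride 2, K % S == 0
  float w[K][K];
  for (int i = 0; i < K; ++i)
    for (int j = 0; j < K; ++j)
      w[i][j] = static_cast<float>(i * K + j); // distinct taps for inspection

  // Phase (a, b) keeps the taps at rows i with i % S == a and columns j
  // with j % S == b, producing S*S sub-kernels of size (K/S) x (K/S).
  for (int a = 0; a < S; ++a)
    for (int b = 0; b < S; ++b) {
      std::printf("phase (%d,%d):", a, b);
      for (int i = a; i < K; i += S)
        for (int j = b; j < K; j += S)
          std::printf(" %2g", w[i][j]);
      std::printf("\n");
    }
  return 0;
}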
@@ -1683,42 +1683,44 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
             convOutputType);
     // Need to remove the excess ofm when weights are padded.

-    auto startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 1});
-    auto endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
-        {convOutputShape[convOutputShape.size() - 2] + 2,
-            convOutputShape[convOutputShape.size() - 1] + 2});
-    auto axisOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {2, 3});
-    auto stepOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 1});
-    auto convSliceOutputType = RankedTensorType::get(
-        convOutputShape, convTransposeOutputType.getElementType());
-    conv1 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv1,
-        startOnnxConstant, endOnnxConstant, axisOnnxConstant,
-        stepOnnxConstant);
-
-    startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {0, 0});
-    endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
-        {convOutputShape[convOutputShape.size() - 2],
-            convOutputShape[convOutputShape.size() - 1]});
-    conv2 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv2,
-        startOnnxConstant, endOnnxConstant, axisOnnxConstant,
-        stepOnnxConstant);
-
-    startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 0});
-    endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
-        {convOutputShape[convOutputShape.size() - 2] + 2,
-            convOutputShape[convOutputShape.size() - 1]});
-    conv3 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv3,
-        startOnnxConstant, endOnnxConstant, axisOnnxConstant,
-        stepOnnxConstant);
-
-    startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {0, 1});
-    endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
-        {convOutputShape[convOutputShape.size() - 2],
-            convOutputShape[convOutputShape.size() - 1] + 2});
-    conv4 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv4,
-        startOnnxConstant, endOnnxConstant, axisOnnxConstant,
-        stepOnnxConstant);
-
+    if (needWeightsPadding) {
+      auto startOnnxConstant =
+          getONNXConstOpFromVector(rewriter, loc, {1, 1});
+      auto endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
+          {convOutputShape[convOutputShape.size() - 2] + 2,
+              convOutputShape[convOutputShape.size() - 1] + 2});
+      auto axisOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {2, 3});
+      auto stepOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 1});
+      auto convSliceOutputType = RankedTensorType::get(
+          convOutputShape, convTransposeOutputType.getElementType());
+      conv1 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv1,
+          startOnnxConstant, endOnnxConstant, axisOnnxConstant,
+          stepOnnxConstant);
+
+      startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {0, 0});
+      endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
+          {convOutputShape[convOutputShape.size() - 2],
+              convOutputShape[convOutputShape.size() - 1]});
+      conv2 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv2,
+          startOnnxConstant, endOnnxConstant, axisOnnxConstant,
+          stepOnnxConstant);
+
+      startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 0});
+      endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
+          {convOutputShape[convOutputShape.size() - 2] + 2,
+              convOutputShape[convOutputShape.size() - 1]});
+      conv3 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv3,
+          startOnnxConstant, endOnnxConstant, axisOnnxConstant,
+          stepOnnxConstant);
+
+      startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {0, 1});
+      endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
+          {convOutputShape[convOutputShape.size() - 2],
+              convOutputShape[convOutputShape.size() - 1] + 2});
+      conv4 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv4,
+          startOnnxConstant, endOnnxConstant, axisOnnxConstant,
+          stepOnnxConstant);
+    }
     // Four conv outputs are merged in channel dim
     SmallVector<int64_t> outputShapeOfConcat = {
         1, convOutputShape[1] * 4, convOutputShape[2], convOutputShape[3]};
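
Why the four phase convolutions plus an interleave reproduce the stride-2 ConvTranspose, and why the onnx.Slice trimming above is needed only under needWeightsPadding, is easiest to see in 1-D. The following is a minimal numeric sketch (illustrative code, not part of the pass), assuming stride 2, a 4-tap kernel, no padding, and unit batch/channel dims:

#include <cassert>
#include <cmath>
#include <vector>

int main() {
  const std::vector<double> x = {1, -2, 3, 4, -5};  // input, length N
  const std::vector<double> w = {0.5, -1, 2, 0.25}; // 4-tap kernel
  const int N = (int)x.size(), S = 2, K = 4, M = (N - 1) * S + K;

  // Reference: direct 1-D transposed convolution with stride S.
  std::vector<double> ref(M, 0.0);
  for (int n = 0; n < N; ++n)
    for (int k = 0; k < K; ++k)
      ref[n * S + k] += x[n] * w[k];

  // Phase p (p = 0, 1) owns the outputs at positions m = S * t + p:
  //   y[S*t + p] = x[t] * w[p] + x[t - 1] * w[p + S],
  // an ordinary 2-tap convolution over x. Because K is a multiple of S,
  // interleaving the phase outputs rebuilds y exactly, with no excess
  // border to slice away.
  std::vector<double> out(M, 0.0);
  for (int p = 0; p < S; ++p)
    for (int t = 0; t * S + p < M; ++t) {
      double v = 0.0;
      if (t < N)
        v += x[t] * w[p];
      if (t >= 1)
        v += x[t - 1] * w[p + S];
      out[t * S + p] = v;
    }

  for (int m = 0; m < M; ++m)
    assert(std::fabs(out[m] - ref[m]) < 1e-12);
  return 0;
}

When the kernel length is not a multiple of the stride (say, 3 taps zero-padded up to 4), the padded taps push extra values into the border outputs, and trimming them is exactly the role of the guarded Slice block; with a genuine 4x4 kernel the decomposition is already exact, which is why the new guard can skip the slicing.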

test/mlir/onnx/onnx_decompose_convtranspose_phased_conv.mlir
Lines changed: 9 additions & 7 deletions

@@ -502,13 +502,15 @@ func.func @test_convtrans_4phase_kernel_shape_44(%arg0: tensor<1x512x8x8xf32>, %
 // CHECK: %[[VAL_22:.*]] = "onnx.Slice"(%[[VAL_20]], %[[VAL_9]], %[[VAL_8]], %[[VAL_13]], %[[VAL_12]]) : (tensor<512x512x4x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<512x512x2x2xf32>
 // CHECK: %[[VAL_23:.*]] = "onnx.Slice"(%[[VAL_20]], %[[VAL_7]], %[[VAL_6]], %[[VAL_13]], %[[VAL_12]]) : (tensor<512x512x4x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<512x512x2x2xf32>
 // CHECK: %[[VAL_24:.*]] = "onnx.Slice"(%[[VAL_20]], %[[VAL_5]], %[[VAL_4]], %[[VAL_13]], %[[VAL_12]]) : (tensor<512x512x4x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<512x512x2x2xf32>
-// CHECK: %[[VAL_25:.*]] = "onnx.Concat"(%[[VAL_21]], %[[VAL_23]], %[[VAL_22]], %[[VAL_24]]) {axis = 0 : si64} : (tensor<512x512x2x2xf32>, tensor<512x512x2x2xf32>, tensor<512x512x2x2xf32>, tensor<512x512x2x2xf32>) -> tensor<2048x512x2x2xf32>
-// CHECK: %[[VAL_26:.*]] = "onnx.Concat"(%[[VAL_15]], %[[VAL_15]], %[[VAL_15]], %[[VAL_15]]) {axis = 0 : si64} : (tensor<512xf32>, tensor<512xf32>, tensor<512xf32>, tensor<512xf32>) -> tensor<2048xf32>
-// CHECK: %[[VAL_27:.*]] = "onnx.Conv"(%[[VAL_0]], %[[VAL_25]], %[[VAL_26]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [2, 2], pads = [0, 0, 1, 1], strides = [1, 1]} : (tensor<1x512x8x8xf32>, tensor<2048x512x2x2xf32>, tensor<2048xf32>) -> tensor<1x2048x8x8xf32>
-// CHECK: %[[VAL_28:.*]] = "onnx.Reshape"(%[[VAL_27]], %[[VAL_3]]) {allowzero = 0 : si64} : (tensor<1x2048x8x8xf32>, tensor<5xi64>) -> tensor<2x2x512x8x8xf32>
-// CHECK: %[[VAL_29:.*]] = "onnx.Transpose"(%[[VAL_28]]) {perm = [2, 3, 0, 4, 1]} : (tensor<2x2x512x8x8xf32>) -> tensor<512x8x2x8x2xf32>
-// CHECK: %[[VAL_30:.*]] = "onnx.Reshape"(%[[VAL_29]], %[[VAL_2]]) {allowzero = 0 : si64} : (tensor<512x8x2x8x2xf32>, tensor<4xi64>) -> tensor<1x512x16x16xf32>
-// CHECK: onnx.Return %[[VAL_30]] : tensor<1x512x16x16xf32>
+// CHECK: %[[VAL_25:.*]] = "onnx.Conv"(%[[VAL_0]], %[[VAL_24]], %[[VAL_15]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [2, 2], pads = [0, 0, 1, 1], strides = [1, 1]} : (tensor<1x512x8x8xf32>, tensor<512x512x2x2xf32>, tensor<512xf32>) -> tensor<1x512x8x8xf32>
+// CHECK: %[[VAL_26:.*]] = "onnx.Conv"(%[[VAL_0]], %[[VAL_21]], %[[VAL_15]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [2, 2], pads = [1, 1, 0, 0], strides = [1, 1]} : (tensor<1x512x8x8xf32>, tensor<512x512x2x2xf32>, tensor<512xf32>) -> tensor<1x512x8x8xf32>
+// CHECK: %[[VAL_27:.*]] = "onnx.Conv"(%[[VAL_0]], %[[VAL_22]], %[[VAL_15]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [2, 2], pads = [0, 1, 1, 0], strides = [1, 1]} : (tensor<1x512x8x8xf32>, tensor<512x512x2x2xf32>, tensor<512xf32>) -> tensor<1x512x8x8xf32>
+// CHECK: %[[VAL_28:.*]] = "onnx.Conv"(%[[VAL_0]], %[[VAL_23]], %[[VAL_15]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [2, 2], pads = [1, 0, 0, 1], strides = [1, 1]} : (tensor<1x512x8x8xf32>, tensor<512x512x2x2xf32>, tensor<512xf32>) -> tensor<1x512x8x8xf32>
+// CHECK: %[[VAL_29:.*]] = "onnx.Concat"(%[[VAL_26]], %[[VAL_28]], %[[VAL_27]], %[[VAL_25]]) {axis = 1 : si64} : (tensor<1x512x8x8xf32>, tensor<1x512x8x8xf32>, tensor<1x512x8x8xf32>, tensor<1x512x8x8xf32>) -> tensor<1x2048x8x8xf32>
+// CHECK: %[[VAL_30:.*]] = "onnx.Reshape"(%[[VAL_29]], %[[VAL_3]]) {allowzero = 0 : si64} : (tensor<1x2048x8x8xf32>, tensor<5xi64>) -> tensor<2x2x512x8x8xf32>
+// CHECK: %[[VAL_31:.*]] = "onnx.Transpose"(%[[VAL_30]]) {perm = [2, 3, 0, 4, 1]} : (tensor<2x2x512x8x8xf32>) -> tensor<512x8x2x8x2xf32>
+// CHECK: %[[VAL_32:.*]] = "onnx.Reshape"(%[[VAL_31]], %[[VAL_2]]) {allowzero = 0 : si64} : (tensor<512x8x2x8x2xf32>, tensor<4xi64>) -> tensor<1x512x16x16xf32>
+// CHECK: onnx.Return %[[VAL_32]] : tensor<1x512x16x16xf32>
 // CHECK: }
 }
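
Read together, the new CHECK lines spell out the phased reconstruction: four 1x512x8x8 convolutions with complementary pads ([1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1] in concat order) are concatenated on axis 1, and the Reshape to 2x2x512x8x8, Transpose with perm = [2, 3, 0, 4, 1], and final Reshape place phase (a, b) at output pixels (2h + a, 2w + b). A minimal C++ sketch of that index mapping, with shapes scaled down to C = 1, H = W = 2 for readability (illustrative only):

#include <cstdio>

int main() {
  constexpr int C = 1, H = 2, W = 2;
  // phase[2a + b][c][h][w], filled with a recognizable pattern.
  double phase[4][C][H][W];
  for (int p = 0; p < 4; ++p)
    for (int c = 0; c < C; ++c)
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
          phase[p][c][h][w] = 100 * p + 10 * h + w;

  // The Reshape -> Transpose(perm = [2,3,0,4,1]) -> Reshape chain is
  // equivalent to this gather into a 2x-upsampled spatial grid:
  double out[C][2 * H][2 * W];
  for (int c = 0; c < C; ++c)
    for (int h = 0; h < H; ++h)
      for (int w = 0; w < W; ++w)
        for (int a = 0; a < 2; ++a)
          for (int b = 0; b < 2; ++b)
            out[c][2 * h + a][2 * w + b] = phase[2 * a + b][c][h][w];

  for (int i = 0; i < 2 * H; ++i) {
    for (int j = 0; j < 2 * W; ++j)
      std::printf("%4g ", out[0][i][j]);
    std::printf("\n");
  }
  return 0;
}

Each 2x2 block of the printed grid draws one value from each phase, which is the stride-2 upsampling pattern the original ConvTranspose produces.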
