Skip to content

Commit 4e3672a

Browse files
Fixing a mismatch in the 4x4 kernel use case
1 parent 189c863 commit 4e3672a

File tree

2 files changed

+58
-7
lines changed

2 files changed

+58
-7
lines changed

src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1735,6 +1735,55 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
17351735
ValueRange{conv2, conv4, conv3, conv1}, 1)
17361736
: rewriter.create<ONNXConcatOp>(loc, concatOutputType,
17371737
ValueRange{conv1, conv3, conv4, conv2}, 1);
1738+
} else if (kernelShape[0] == 4) {
1739+
Value conv1 = getActivationAppliedToConv(
1740+
addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
1741+
convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[3]),
1742+
bias, mlir::StringAttr(), dilations, group,
1743+
convKernelShapeArrayAttr,
1744+
getPadsArrayAttr(kernelShape[0], 1, needWeightsPadding),
1745+
stridesArrayAttr)),
1746+
convOutputType);
1747+
Value conv2 = getActivationAppliedToConv(
1748+
addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
1749+
convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[0]),
1750+
bias, mlir::StringAttr(), dilations, group,
1751+
convKernelShapeArrayAttr,
1752+
getPadsArrayAttr(kernelShape[0], 2, needWeightsPadding),
1753+
stridesArrayAttr)),
1754+
convOutputType);
1755+
Value conv3 = getActivationAppliedToConv(
1756+
addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
1757+
convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[1]),
1758+
bias, mlir::StringAttr(), dilations, group,
1759+
convKernelShapeArrayAttr,
1760+
getPadsArrayAttr(kernelShape[0], 3, needWeightsPadding),
1761+
stridesArrayAttr)),
1762+
convOutputType);
1763+
Value conv4 = getActivationAppliedToConv(
1764+
addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
1765+
convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[2]),
1766+
bias, mlir::StringAttr(), dilations, group,
1767+
convKernelShapeArrayAttr,
1768+
getPadsArrayAttr(kernelShape[0], 4, needWeightsPadding),
1769+
stridesArrayAttr)),
1770+
convOutputType);
1771+
// Four conv outputs are merged in channel dim
1772+
SmallVector<int64_t> outputShapeOfConcat = {
1773+
1, convOutputShape[1] * 4, convOutputShape[2], convOutputShape[3]};
1774+
auto concatOutputType =
1775+
RankedTensorType::get(outputShapeOfConcat, elementType);
1776+
// For the case where the ConvTranspose kernel is [4, 4] with pads [1, 1,
1777+
// 1, 1], the phased conv outputs are to be concatenated in the reverse
1778+
// order. This is observed by comparing the phased conv outputs with
1779+
// respect to the ConvTranspose output.
1780+
bool reverseConcatOrder = (needWeightsPadding || (kernelShape[0] == 4));
1781+
// The concat output will have 4 times the channels of a single conv.
1782+
conv = (reverseConcatOrder)
1783+
? rewriter.create<ONNXConcatOp>(loc, concatOutputType,
1784+
ValueRange{conv2, conv4, conv3, conv1}, 1)
1785+
: rewriter.create<ONNXConcatOp>(loc, concatOutputType,
1786+
ValueRange{conv1, conv3, conv4, conv2}, 1);
17381787
} else {
17391788
// Combining the 4 phased weights into single weight.
17401789
bool reverseOrder = (kernelShape[0] == 4);

test/mlir/onnx/onnx_decompose_convtranspose_phased_conv.mlir

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -502,13 +502,15 @@ func.func @test_convtrans_4phase_kernel_shape_44(%arg0: tensor<1x512x8x8xf32>, %
502502
// CHECK: %[[VAL_22:.*]] = "onnx.Slice"(%[[VAL_20]], %[[VAL_9]], %[[VAL_8]], %[[VAL_13]], %[[VAL_12]]) : (tensor<512x512x4x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<512x512x2x2xf32>
503503
// CHECK: %[[VAL_23:.*]] = "onnx.Slice"(%[[VAL_20]], %[[VAL_7]], %[[VAL_6]], %[[VAL_13]], %[[VAL_12]]) : (tensor<512x512x4x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<512x512x2x2xf32>
504504
// CHECK: %[[VAL_24:.*]] = "onnx.Slice"(%[[VAL_20]], %[[VAL_5]], %[[VAL_4]], %[[VAL_13]], %[[VAL_12]]) : (tensor<512x512x4x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<512x512x2x2xf32>
505-
// CHECK: %[[VAL_25:.*]] = "onnx.Concat"(%[[VAL_21]], %[[VAL_23]], %[[VAL_22]], %[[VAL_24]]) {axis = 0 : si64} : (tensor<512x512x2x2xf32>, tensor<512x512x2x2xf32>, tensor<512x512x2x2xf32>, tensor<512x512x2x2xf32>) -> tensor<2048x512x2x2xf32>
506-
// CHECK: %[[VAL_26:.*]] = "onnx.Concat"(%[[VAL_15]], %[[VAL_15]], %[[VAL_15]], %[[VAL_15]]) {axis = 0 : si64} : (tensor<512xf32>, tensor<512xf32>, tensor<512xf32>, tensor<512xf32>) -> tensor<2048xf32>
507-
// CHECK: %[[VAL_27:.*]] = "onnx.Conv"(%[[VAL_0]], %[[VAL_25]], %[[VAL_26]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [2, 2], pads = [0, 0, 1, 1], strides = [1, 1]} : (tensor<1x512x8x8xf32>, tensor<2048x512x2x2xf32>, tensor<2048xf32>) -> tensor<1x2048x8x8xf32>
508-
// CHECK: %[[VAL_28:.*]] = "onnx.Reshape"(%[[VAL_27]], %[[VAL_3]]) {allowzero = 0 : si64} : (tensor<1x2048x8x8xf32>, tensor<5xi64>) -> tensor<2x2x512x8x8xf32>
509-
// CHECK: %[[VAL_29:.*]] = "onnx.Transpose"(%[[VAL_28]]) {perm = [2, 3, 0, 4, 1]} : (tensor<2x2x512x8x8xf32>) -> tensor<512x8x2x8x2xf32>
510-
// CHECK: %[[VAL_30:.*]] = "onnx.Reshape"(%[[VAL_29]], %[[VAL_2]]) {allowzero = 0 : si64} : (tensor<512x8x2x8x2xf32>, tensor<4xi64>) -> tensor<1x512x16x16xf32>
511-
// CHECK: onnx.Return %[[VAL_30]] : tensor<1x512x16x16xf32>
505+
// CHECK: %[[VAL_25:.*]] = "onnx.Conv"(%[[VAL_0]], %[[VAL_24]], %[[VAL_15]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [2, 2], pads = [0, 0, 1, 1], strides = [1, 1]} : (tensor<1x512x8x8xf32>, tensor<512x512x2x2xf32>, tensor<512xf32>) -> tensor<1x512x8x8xf32>
506+
// CHECK: %[[VAL_26:.*]] = "onnx.Conv"(%[[VAL_0]], %[[VAL_21]], %[[VAL_15]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [2, 2], pads = [1, 1, 0, 0], strides = [1, 1]} : (tensor<1x512x8x8xf32>, tensor<512x512x2x2xf32>, tensor<512xf32>) -> tensor<1x512x8x8xf32>
507+
// CHECK: %[[VAL_27:.*]] = "onnx.Conv"(%[[VAL_0]], %[[VAL_22]], %[[VAL_15]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [2, 2], pads = [0, 1, 1, 0], strides = [1, 1]} : (tensor<1x512x8x8xf32>, tensor<512x512x2x2xf32>, tensor<512xf32>) -> tensor<1x512x8x8xf32>
508+
// CHECK: %[[VAL_28:.*]] = "onnx.Conv"(%[[VAL_0]], %[[VAL_23]], %[[VAL_15]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [2, 2], pads = [1, 0, 0, 1], strides = [1, 1]} : (tensor<1x512x8x8xf32>, tensor<512x512x2x2xf32>, tensor<512xf32>) -> tensor<1x512x8x8xf32>
509+
// CHECK: %[[VAL_29:.*]] = "onnx.Concat"(%[[VAL_26]], %[[VAL_28]], %[[VAL_27]], %[[VAL_25]]) {axis = 1 : si64} : (tensor<1x512x8x8xf32>, tensor<1x512x8x8xf32>, tensor<1x512x8x8xf32>, tensor<1x512x8x8xf32>) -> tensor<1x2048x8x8xf32>
510+
// CHECK: %[[VAL_30:.*]] = "onnx.Reshape"(%[[VAL_29]], %[[VAL_3]]) {allowzero = 0 : si64} : (tensor<1x2048x8x8xf32>, tensor<5xi64>) -> tensor<2x2x512x8x8xf32>
511+
// CHECK: %[[VAL_31:.*]] = "onnx.Transpose"(%[[VAL_30]]) {perm = [2, 3, 0, 4, 1]} : (tensor<2x2x512x8x8xf32>) -> tensor<512x8x2x8x2xf32>
512+
// CHECK: %[[VAL_32:.*]] = "onnx.Reshape"(%[[VAL_31]], %[[VAL_2]]) {allowzero = 0 : si64} : (tensor<512x8x2x8x2xf32>, tensor<4xi64>) -> tensor<1x512x16x16xf32>
513+
// CHECK: onnx.Return %[[VAL_32]] : tensor<1x512x16x16xf32>
512514
// CHECK: }
513515
}
514516

0 commit comments

Comments
 (0)