
Commit 8f70a00

simplifying
1 parent 4e3672a commit 8f70a00

File tree: 1 file changed, +39 -86 lines changed


src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 39 additions & 86 deletions
@@ -1648,7 +1648,7 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
   };
   auto stridesArrayAttr = rewriter.getI64ArrayAttr({1, 1});
   Value conv;
-  if (needWeightsPadding) {
+  if (needWeightsPadding || (kernelShape[0] == 4)) {
     Value conv1 = getActivationAppliedToConv(
         addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
             convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[3]),
@@ -1683,91 +1683,44 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
         convOutputType);
     // Need to remove excess the ofm when weights are padded.

-    auto startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 1});
-    auto endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
-        {convOutputShape[convOutputShape.size() - 2] + 2,
-            convOutputShape[convOutputShape.size() - 1] + 2});
-    auto axisOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {2, 3});
-    auto stepOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 1});
-    auto convSliceOutputType = RankedTensorType::get(
-        convOutputShape, convTransposeOutputType.getElementType());
-    conv1 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv1,
-        startOnnxConstant, endOnnxConstant, axisOnnxConstant,
-        stepOnnxConstant);
-
-    startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {0, 0});
-    endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
-        {convOutputShape[convOutputShape.size() - 2],
-            convOutputShape[convOutputShape.size() - 1]});
-    conv2 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv2,
-        startOnnxConstant, endOnnxConstant, axisOnnxConstant,
-        stepOnnxConstant);
-
-    startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 0});
-    endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
-        {convOutputShape[convOutputShape.size() - 2] + 2,
-            convOutputShape[convOutputShape.size() - 1]});
-    conv3 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv3,
-        startOnnxConstant, endOnnxConstant, axisOnnxConstant,
-        stepOnnxConstant);
-
-    startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {0, 1});
-    endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
-        {convOutputShape[convOutputShape.size() - 2],
-            convOutputShape[convOutputShape.size() - 1] + 2});
-    conv4 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv4,
-        startOnnxConstant, endOnnxConstant, axisOnnxConstant,
-        stepOnnxConstant);
-
-    // Four conv outputs are merged in channel dim
-    SmallVector<int64_t> outputShapeOfConcat = {
-        1, convOutputShape[1] * 4, convOutputShape[2], convOutputShape[3]};
-    auto concatOutputType =
-        RankedTensorType::get(outputShapeOfConcat, elementType);
-    // for the case where convtranspose kernel is [4, 4] and with pads [1, 1,
-    // 1, 1] The phased convs output are to be concatenated in the reverse
-    // order. This is observed by looking at the phased conv outputs with
-    // respect to convtranspose output.
-    bool reverseConcatOrder = (needWeightsPadding || (kernelShape[0] == 4));
-    // The concat output will have 4 times the channels of a single conv.
-    conv = (reverseConcatOrder)
-               ? rewriter.create<ONNXConcatOp>(loc, concatOutputType,
-                     ValueRange{conv2, conv4, conv3, conv1}, 1)
-               : rewriter.create<ONNXConcatOp>(loc, concatOutputType,
-                     ValueRange{conv1, conv3, conv4, conv2}, 1);
-  } else if (kernelShape[0] == 4) {
-    Value conv1 = getActivationAppliedToConv(
-        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
-            convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[3]),
-            bias, mlir::StringAttr(), dilations, group,
-            convKernelShapeArrayAttr,
-            getPadsArrayAttr(kernelShape[0], 1, needWeightsPadding),
-            stridesArrayAttr)),
-        convOutputType);
-    Value conv2 = getActivationAppliedToConv(
-        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
-            convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[0]),
-            bias, mlir::StringAttr(), dilations, group,
-            convKernelShapeArrayAttr,
-            getPadsArrayAttr(kernelShape[0], 2, needWeightsPadding),
-            stridesArrayAttr)),
-        convOutputType);
-    Value conv3 = getActivationAppliedToConv(
-        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
-            convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[1]),
-            bias, mlir::StringAttr(), dilations, group,
-            convKernelShapeArrayAttr,
-            getPadsArrayAttr(kernelShape[0], 3, needWeightsPadding),
-            stridesArrayAttr)),
-        convOutputType);
-    Value conv4 = getActivationAppliedToConv(
-        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
-            convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[2]),
-            bias, mlir::StringAttr(), dilations, group,
-            convKernelShapeArrayAttr,
-            getPadsArrayAttr(kernelShape[0], 4, needWeightsPadding),
-            stridesArrayAttr)),
-        convOutputType);
+    if (needWeightsPadding) {
+      auto startOnnxConstant =
+          getONNXConstOpFromVector(rewriter, loc, {1, 1});
+      auto endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
+          {convOutputShape[convOutputShape.size() - 2] + 2,
+              convOutputShape[convOutputShape.size() - 1] + 2});
+      auto axisOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {2, 3});
+      auto stepOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 1});
+      auto convSliceOutputType = RankedTensorType::get(
+          convOutputShape, convTransposeOutputType.getElementType());
+      conv1 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv1,
+          startOnnxConstant, endOnnxConstant, axisOnnxConstant,
+          stepOnnxConstant);
+
+      startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {0, 0});
+      endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
+          {convOutputShape[convOutputShape.size() - 2],
+              convOutputShape[convOutputShape.size() - 1]});
+      conv2 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv2,
+          startOnnxConstant, endOnnxConstant, axisOnnxConstant,
+          stepOnnxConstant);
+
+      startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 0});
+      endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
+          {convOutputShape[convOutputShape.size() - 2] + 2,
+              convOutputShape[convOutputShape.size() - 1]});
+      conv3 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv3,
+          startOnnxConstant, endOnnxConstant, axisOnnxConstant,
+          stepOnnxConstant);
+
+      startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {0, 1});
+      endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
+          {convOutputShape[convOutputShape.size() - 2],
+              convOutputShape[convOutputShape.size() - 1] + 2});
+      conv4 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv4,
+          startOnnxConstant, endOnnxConstant, axisOnnxConstant,
+          stepOnnxConstant);
+    }
   // Four conv outputs are merged in channel dim
   SmallVector<int64_t> outputShapeOfConcat = {
       1, convOutputShape[1] * 4, convOutputShape[2], convOutputShape[3]};
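
For reference, a small standalone sketch (plain C++17, independent of MLIR and onnx-mlir) of the slice windows that the new guarded `if (needWeightsPadding)` block applies to each phased conv output. The start/end pairs and the {2, 3} axes are copied from the change above; the example shape, and the assumption that each padded-weights conv yields an (H + 1) x (W + 1) map with ONNX Slice clamping out-of-range ends, are illustrative only.

#include <algorithm>
#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical conv output shape {N, C, H, W}; only H and W matter here.
  const std::vector<int64_t> convOutputShape = {1, 8, 16, 16};
  const int64_t H = convOutputShape[convOutputShape.size() - 2];
  const int64_t W = convOutputShape[convOutputShape.size() - 1];
  // Assumption (not taken from this commit): with padded weights each phase
  // produces one extra row and column, i.e. an (H + 1) x (W + 1) map, and
  // ONNX Slice clamps out-of-range ends to the actual dimension size.
  const int64_t paddedH = H + 1;
  const int64_t paddedW = W + 1;

  struct SliceWindow {
    const char *phase;
    std::array<int64_t, 2> start; // starts on axes {2, 3}
    std::array<int64_t, 2> end;   // ends on axes {2, 3}
  };

  // Start/end pairs as written in the Decompose.cpp change.
  const SliceWindow windows[] = {
      {"conv1", {1, 1}, {H + 2, W + 2}},
      {"conv2", {0, 0}, {H, W}},
      {"conv3", {1, 0}, {H + 2, W}},
      {"conv4", {0, 1}, {H, W + 2}},
  };

  for (const auto &w : windows) {
    // Clamp the ends, then count how many rows/cols each phase keeps.
    const int64_t keptH = std::min(w.end[0], paddedH) - w.start[0];
    const int64_t keptW = std::min(w.end[1], paddedW) - w.start[1];
    std::cout << w.phase << ": start {" << w.start[0] << ", " << w.start[1]
              << "}, end {" << w.end[0] << ", " << w.end[1] << "} -> keeps "
              << keptH << " x " << keptW << " (target " << H << " x " << W
              << ")\n";
  }
  return 0;
}

Under these assumptions every phase is trimmed back to the H x W convOutputShape, which is what lets the four phase outputs share the same convSliceOutputType before the channel-dim concat.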
