
Commit 9314dbe

Merge pull request #427 from Xilinx/chaitany.improving_convtranspose_with_single_conv
Improving convtranspose decomposition for stride 2x2
2 parents 605ea3c + 6caf973
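
The improvement rests on a standard identity: convolutions applied to the same input can be fused into one convolution by concatenating their weight tensors along the output-channel axis, after which the fused output's channel blocks match the individual outputs. When the weights need no padding, the four phased convs of the stride-2x2 ConvTranspose decomposition therefore collapse into a single ONNXConvOp with four times the output channels. Below is a minimal NumPy sketch of that identity; the conv2d helper and all shapes are illustrative assumptions, not code from the repository.

import numpy as np

def conv2d(x, w):
    # Naive conv: x is (C_in, H, W), w is (C_out, C_in, K, K); stride 1, no pad.
    c_out, _, k, _ = w.shape
    h_out, w_out = x.shape[1] - k + 1, x.shape[2] - k + 1
    out = np.zeros((c_out, h_out, w_out))
    for o in range(c_out):
        for i in range(h_out):
            for j in range(w_out):
                out[o, i, j] = np.sum(x[:, i : i + k, j : j + k] * w[o])
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8, 8))
phases = [rng.standard_normal((4, 3, 3, 3)) for _ in range(4)]  # 4 phased weights

separate = [conv2d(x, w) for w in phases]             # four convs, one per phase
combined = conv2d(x, np.concatenate(phases, axis=0))  # one conv, merged weights
assert np.allclose(np.concatenate(separate, axis=0), combined)

The bias obeys the same rule, which is why the diff below tiles the original bias four times (ValueRange{bias, bias, bias, bias}) before the merged conv.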

File tree

3 files changed: +283 -349 lines changed

src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 98 additions & 54 deletions
@@ -1647,41 +1647,42 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
     }
   };
   auto stridesArrayAttr = rewriter.getI64ArrayAttr({1, 1});
-
-  Value conv1 = getActivationAppliedToConv(
-      addQDQNodesForActivationIfNeeded(
-          rewriter.create<ONNXConvOp>(loc, convOutputType, input,
-              addDequantizeNodeIfNeeded(weightSlices[3]), bias,
-              mlir::StringAttr(), dilations, group, convKernelShapeArrayAttr,
-              getPadsArrayAttr(kernelShape[0], 1, needWeightsPadding),
-              stridesArrayAttr)),
-      convOutputType);
-  Value conv2 = getActivationAppliedToConv(
-      addQDQNodesForActivationIfNeeded(
-          rewriter.create<ONNXConvOp>(loc, convOutputType, input,
-              addDequantizeNodeIfNeeded(weightSlices[0]), bias,
-              mlir::StringAttr(), dilations, group, convKernelShapeArrayAttr,
-              getPadsArrayAttr(kernelShape[0], 2, needWeightsPadding),
-              stridesArrayAttr)),
-      convOutputType);
-  Value conv3 = getActivationAppliedToConv(
-      addQDQNodesForActivationIfNeeded(
-          rewriter.create<ONNXConvOp>(loc, convOutputType, input,
-              addDequantizeNodeIfNeeded(weightSlices[1]), bias,
-              mlir::StringAttr(), dilations, group, convKernelShapeArrayAttr,
-              getPadsArrayAttr(kernelShape[0], 3, needWeightsPadding),
-              stridesArrayAttr)),
-      convOutputType);
-  Value conv4 = getActivationAppliedToConv(
-      addQDQNodesForActivationIfNeeded(
-          rewriter.create<ONNXConvOp>(loc, convOutputType, input,
-              addDequantizeNodeIfNeeded(weightSlices[2]), bias,
-              mlir::StringAttr(), dilations, group, convKernelShapeArrayAttr,
-              getPadsArrayAttr(kernelShape[0], 4, needWeightsPadding),
-              stridesArrayAttr)),
-      convOutputType);
-  // Need to remove the excess ofm when weights are padded.
+  Value conv;
   if (needWeightsPadding) {
+    Value conv1 = getActivationAppliedToConv(
+        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
+            convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[3]),
+            bias, mlir::StringAttr(), dilations, group,
+            convKernelShapeArrayAttr,
+            getPadsArrayAttr(kernelShape[0], 1, needWeightsPadding),
+            stridesArrayAttr)),
+        convOutputType);
+    Value conv2 = getActivationAppliedToConv(
+        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
+            convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[0]),
+            bias, mlir::StringAttr(), dilations, group,
+            convKernelShapeArrayAttr,
+            getPadsArrayAttr(kernelShape[0], 2, needWeightsPadding),
+            stridesArrayAttr)),
+        convOutputType);
+    Value conv3 = getActivationAppliedToConv(
+        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
+            convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[1]),
+            bias, mlir::StringAttr(), dilations, group,
+            convKernelShapeArrayAttr,
+            getPadsArrayAttr(kernelShape[0], 3, needWeightsPadding),
+            stridesArrayAttr)),
+        convOutputType);
+    Value conv4 = getActivationAppliedToConv(
+        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
+            convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[2]),
+            bias, mlir::StringAttr(), dilations, group,
+            convKernelShapeArrayAttr,
+            getPadsArrayAttr(kernelShape[0], 4, needWeightsPadding),
+            stridesArrayAttr)),
+        convOutputType);
+    // Need to remove the excess ofm when weights are padded.
+
     auto startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 1});
     auto endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
         {convOutputShape[convOutputShape.size() - 2] + 2,
@@ -1717,24 +1718,68 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
     conv4 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv4,
         startOnnxConstant, endOnnxConstant, axisOnnxConstant,
         stepOnnxConstant);
+
+    // Four conv outputs are merged in the channel dim.
+    SmallVector<int64_t> outputShapeOfConcat = {
+        1, convOutputShape[1] * 4, convOutputShape[2], convOutputShape[3]};
+    auto concatOutputType =
+        RankedTensorType::get(outputShapeOfConcat, elementType);
+    // For the case where the convtranspose kernel is [4, 4] with pads
+    // [1, 1, 1, 1], the phased conv outputs are to be concatenated in the
+    // reverse order. This is observed by comparing the phased conv outputs
+    // with the convtranspose output.
+    bool reverseConcatOrder = (needWeightsPadding || (kernelShape[0] == 4));
+    // The concat output will have 4 times the channels of a single conv.
+    conv = (reverseConcatOrder)
+               ? rewriter.create<ONNXConcatOp>(loc, concatOutputType,
+                     ValueRange{conv2, conv4, conv3, conv1}, 1)
+               : rewriter.create<ONNXConcatOp>(loc, concatOutputType,
+                     ValueRange{conv1, conv3, conv4, conv2}, 1);
+  } else {
+    // Combining the 4 phased weights into a single weight.
+    bool reverseOrder = (kernelShape[0] == 4);
+    auto combinedConvWeightsShapedType =
+        weightsType.get({weightsShape[0] * 4, weightsShape[1], convKernelSize,
+                            convKernelSize},
+            weightsType.getElementType());
+
+    Value combinedWeights =
+        (reverseOrder) ? rewriter.create<ONNXConcatOp>(loc,
+                             combinedConvWeightsShapedType,
+                             ValueRange{weightSlices[0], weightSlices[2],
+                                 weightSlices[1], weightSlices[3]},
+                             0)
+                       : rewriter.create<ONNXConcatOp>(loc,
+                             combinedConvWeightsShapedType,
+                             ValueRange{weightSlices[3], weightSlices[1],
+                                 weightSlices[2], weightSlices[0]},
+                             0);
+
+    if (!bias.getDefiningOp<ONNXNoneOp>()) {
+      RankedTensorType biasType =
+          mlir::cast<RankedTensorType>(bias.getType());
+      auto biasShape = biasType.getShape();
+
+      auto combinedBiasShapedType =
+          biasType.get({biasShape[0] * 4}, biasType.getElementType());
+
+      bias = rewriter.create<ONNXConcatOp>(
+          loc, combinedBiasShapedType, ValueRange{bias, bias, bias, bias}, 0);
+    }
+
+    auto combinedConvOutputType = RankedTensorType::get(
+        SmallVector<int64_t>({convOutputShape[0], convOutputShape[1] * 4,
+            convOutputShape[2], convOutputShape[3]}),
+        convTransposeOutputType.getElementType());
+    conv = getActivationAppliedToConv(
+        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
+            combinedConvOutputType, input,
+            addDequantizeNodeIfNeeded(combinedWeights), bias,
+            mlir::StringAttr(), dilations, group, convKernelShapeArrayAttr,
+            getPadsArrayAttr(kernelShape[0], 1, needWeightsPadding),
+            stridesArrayAttr)),
+        combinedConvOutputType);
   }
-  // Four conv outputs are merged in the channel dim.
-  SmallVector<int64_t> outputShapeOfConcat = {
-      1, convOutputShape[1] * 4, convOutputShape[2], convOutputShape[3]};
-  auto concatOutputType =
-      RankedTensorType::get(outputShapeOfConcat, elementType);
-  // For the case where the convtranspose kernel is [4, 4] with pads
-  // [1, 1, 1, 1], the phased conv outputs are to be concatenated in the
-  // reverse order. This is observed by comparing the phased conv outputs
-  // with the convtranspose output.
-  bool reverseConcatOrder = (needWeightsPadding || (kernelShape[0] == 4));
-  // The concat output will have 4 times the channels of a single conv.
-  auto firstConcat =
-      (reverseConcatOrder)
-          ? rewriter.create<ONNXConcatOp>(loc, concatOutputType,
-                ValueRange{conv2, conv4, conv3, conv1}, 1)
-          : rewriter.create<ONNXConcatOp>(loc, concatOutputType,
-                ValueRange{conv1, conv3, conv4, conv2}, 1);

   // Here we are reshaping the concatenated conv channels of 4*Conv_channels
   // into groups of 2x2 channels. This can be visualized as
@@ -1751,9 +1796,8 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,

   auto reshapeOutputForDimAdjustType =
       RankedTensorType::get(outputShapeForDimAdjust, elementType);
-  auto reshapeOutputDimAdjust =
-      rewriter.create<ONNXReshapeOp>(loc, reshapeOutputForDimAdjustType,
-          firstConcat, onnxConstForReshapeDimAdjust);
+  auto reshapeOutputDimAdjust = rewriter.create<ONNXReshapeOp>(
+      loc, reshapeOutputForDimAdjustType, conv, onnxConstForReshapeDimAdjust);

   SmallVector<int64_t> transposeOuputShape = {
       convOutputShape[1], convOutputShape[2], 2, convOutputShape[3], 2};
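
Either way, the decomposition ends in the reshape/transpose tail visible above (transposeOuputShape = {C, H, 2, W, 2}): the 4*C concatenated channels are interleaved back into the 2x-upsampled spatial grid, i.e. a 2x2 depth-to-space (pixel shuffle). A minimal NumPy sketch of that rearrangement, with the batch dimension omitted and assuming the concatenated channels are phase-major, (2, 2, C); in the pass itself the concat order is chosen so that this grouping lines up:

import numpy as np

c, h, w = 2, 3, 3
phased = np.arange(4 * c * h * w, dtype=float).reshape(4 * c, h, w)

# Group the 4*C channels as (2, 2, C, H, W), move the 2x2 phase grid next to
# the spatial dims to get (C, H, 2, W, 2), then merge into (C, 2H, 2W).
grouped = phased.reshape(2, 2, c, h, w)
out = grouped.transpose(2, 3, 0, 4, 1).reshape(c, 2 * h, 2 * w)

# Same result built directly: each phase lands on its stride-2 lattice.
expected = np.zeros_like(out)
for pr in range(2):
    for pc in range(2):
        expected[:, pr::2, pc::2] = grouped[pr, pc]
assert np.allclose(out, expected)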
