Skip to content

Commit 05dad07

Browse files
simplifying the 4 phased conv merge operation
1 parent f1ecdb1 commit 05dad07

File tree

2 files changed

+343
-450
lines changed

2 files changed

+343
-450
lines changed

src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 29 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1716,27 +1716,10 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
17161716

17171717
// The four convOutputs are adjusted to add an extra dimension at the
17181718
// innermost level.
1719-
SmallVector<int64_t> outputShapePlusOneDim(convOutputShape);
1720-
outputShapePlusOneDim.push_back(1);
1721-
auto onnxConstForReshapeAddOneDim =
1722-
getONNXConstOpFromVector(rewriter, loc, outputShapePlusOneDim);
1723-
1724-
auto reshapeOutputType =
1725-
RankedTensorType::get(outputShapePlusOneDim, elementType);
1726-
1727-
auto reshapeOutputAddOneDimConv1 = rewriter.create<ONNXReshapeOp>(
1728-
loc, reshapeOutputType, conv1, onnxConstForReshapeAddOneDim);
1729-
auto reshapeOutputAddOneDimConv2 = rewriter.create<ONNXReshapeOp>(
1730-
loc, reshapeOutputType, conv2, onnxConstForReshapeAddOneDim);
1731-
auto reshapeOutputAddOneDimConv3 = rewriter.create<ONNXReshapeOp>(
1732-
loc, reshapeOutputType, conv3, onnxConstForReshapeAddOneDim);
1733-
auto reshapeOutputAddOneDimConv4 = rewriter.create<ONNXReshapeOp>(
1734-
loc, reshapeOutputType, conv4, onnxConstForReshapeAddOneDim);
1735-
1736-
SmallVector<int64_t> outputShapeLevel1Concat(outputShapePlusOneDim);
1737-
outputShapeLevel1Concat[outputShapeLevel1Concat.size() - 1] = 2;
1738-
auto level1ConcatOutputType =
1739-
RankedTensorType::get(outputShapeLevel1Concat, elementType);
1719+
SmallVector<int64_t> outputShapeOfConcat = {
1720+
1, convOutputShape[1] * 4, convOutputShape[2], convOutputShape[3]};
1721+
auto concatOutputType =
1722+
RankedTensorType::get(outputShapeOfConcat, elementType);
17401723
// for the case where convtranspose kernel is [4, 4] and with pads [1, 1, 1,
17411724
// 1], the phased conv outputs are to be concatenated in reverse order.
17421725
// This is observed by looking at the phased conv outputs with respect to
@@ -1745,70 +1728,43 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
17451728
// The concat below joins the four conv outputs along axis 1 (4x channels).
17461729
auto firstConcat =
17471730
(reverseConcatOrder)
1748-
? rewriter.create<ONNXConcatOp>(loc, level1ConcatOutputType,
1749-
ValueRange{
1750-
reshapeOutputAddOneDimConv3, reshapeOutputAddOneDimConv1},
1751-
-1)
1752-
: rewriter.create<ONNXConcatOp>(loc, level1ConcatOutputType,
1753-
ValueRange{
1754-
reshapeOutputAddOneDimConv1, reshapeOutputAddOneDimConv3},
1755-
-1);
1756-
auto secondConcat =
1757-
(reverseConcatOrder)
1758-
? rewriter.create<ONNXConcatOp>(loc, level1ConcatOutputType,
1759-
ValueRange{
1760-
reshapeOutputAddOneDimConv2, reshapeOutputAddOneDimConv4},
1761-
-1)
1762-
: rewriter.create<ONNXConcatOp>(loc, level1ConcatOutputType,
1763-
ValueRange{
1764-
reshapeOutputAddOneDimConv4, reshapeOutputAddOneDimConv2},
1765-
-1);
1731+
? rewriter.create<ONNXConcatOp>(loc, concatOutputType,
1732+
ValueRange{conv2, conv4, conv3, conv1}, 1)
1733+
: rewriter.create<ONNXConcatOp>(loc, concatOutputType,
1734+
ValueRange{conv1, conv3, conv4, conv2}, 1);
17661735

1767-
// Reshaping to modify the two innermost levels,ensuring the second
1768-
// innermost level is set to 1
1769-
SmallVector<int64_t> outputShapeForDimAdjust(convOutputShape);
1770-
auto dimValueAtLastIndex = convOutputShape[convOutputShape.size() - 1] * 2;
1771-
outputShapeForDimAdjust[outputShapeForDimAdjust.size() - 1] = 1;
1772-
outputShapeForDimAdjust.push_back(dimValueAtLastIndex);
1736+
// Here we are reshaping the concatenated conv channels of 4*Conv_channels
1737+
// into groups of 2x2 channels. This can be visualized as
1738+
// H_chan(2) * W_Chan(2) * C_real, then doing the transpose into
1739+
// Conv_channels H H_chan W W_chan. The adjacent H and H_chan will be merged
1740+
// into H, same way W and W_chan will be merged into W. This leads to
1741+
// doubling of H and W, while the channel count stays the same.
1742+
1743+
SmallVector<int64_t> outputShapeForDimAdjust = {
1744+
2, 2, convOutputShape[1], convOutputShape[2], convOutputShape[3]};
17731745

17741746
auto onnxConstForReshapeDimAdjust =
17751747
getONNXConstOpFromVector(rewriter, loc, outputShapeForDimAdjust);
17761748

17771749
auto reshapeOutputForDimAdjustType =
17781750
RankedTensorType::get(outputShapeForDimAdjust, elementType);
1779-
auto reshapeOutputDimAdjustOfFirstConcat =
1751+
auto reshapeOutputDimAdjust =
17801752
rewriter.create<ONNXReshapeOp>(loc, reshapeOutputForDimAdjustType,
17811753
firstConcat, onnxConstForReshapeDimAdjust);
1782-
auto reshapeOutputDimAdjustOfSecondConcat =
1783-
rewriter.create<ONNXReshapeOp>(loc, reshapeOutputForDimAdjustType,
1784-
secondConcat, onnxConstForReshapeDimAdjust);
17851754

1786-
SmallVector<int64_t> outputShapeForFinalConcat(outputShapeForDimAdjust);
1787-
outputShapeForFinalConcat[outputShapeForFinalConcat.size() - 2] = 2;
1755+
SmallVector<int64_t> transposeOuputShape = {
1756+
convOutputShape[1], convOutputShape[2], 2, convOutputShape[3], 2};
17881757

1789-
auto finalConcatOutputType =
1790-
RankedTensorType::get(outputShapeForFinalConcat, elementType);
1758+
auto transposeOutputType =
1759+
RankedTensorType::get(transposeOuputShape, elementType);
17911760

1792-
// Final Concat is performed on the two reshaped outputs at the
1793-
// second innermost level
1794-
auto finalConcat =
1795-
(reverseConcatOrder)
1796-
? rewriter.create<ONNXConcatOp>(loc, finalConcatOutputType,
1797-
ValueRange{reshapeOutputDimAdjustOfSecondConcat,
1798-
reshapeOutputDimAdjustOfFirstConcat},
1799-
-2)
1800-
: rewriter.create<ONNXConcatOp>(loc, finalConcatOutputType,
1801-
ValueRange{reshapeOutputDimAdjustOfFirstConcat,
1802-
reshapeOutputDimAdjustOfSecondConcat},
1803-
-2);
1761+
auto permArrayAttr = rewriter.getI64ArrayAttr({2, 3, 0, 4, 1});
18041762

1805-
SmallVector<int64_t> outputShapeForResult(convOutputShape);
1806-
dimValueAtLastIndex = convOutputShape[convOutputShape.size() - 1] * 2;
1807-
auto dimValueAtSecondLastIndex =
1808-
convOutputShape[convOutputShape.size() - 2] * 2;
1809-
outputShapeForResult[outputShapeForResult.size() - 2] =
1810-
dimValueAtSecondLastIndex;
1811-
outputShapeForResult[outputShapeForResult.size() - 1] = dimValueAtLastIndex;
1763+
auto transpose = rewriter.create<ONNXTransposeOp>(
1764+
loc, transposeOutputType, reshapeOutputDimAdjust, permArrayAttr);
1765+
1766+
SmallVector<int64_t> outputShapeForResult = {
1767+
1, convOutputShape[1], convOutputShape[2] * 2, convOutputShape[3] * 2};
18121768

18131769
auto onnxConstForLastReshape =
18141770
getONNXConstOpFromVector(rewriter, loc, outputShapeForResult);
@@ -1818,7 +1774,7 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
18181774
// Result is reshaped back to match the original convtranspose output
18191775
// dimensions
18201776
auto finalOutput = rewriter.create<ONNXReshapeOp>(
1821-
loc, finalOutputType, finalConcat, onnxConstForLastReshape);
1777+
loc, finalOutputType, transpose, onnxConstForLastReshape);
18221778
return finalOutput;
18231779
}
18241780
if (numPhases == 9) {

0 commit comments

Comments
 (0)