@@ -1716,27 +1716,10 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
17161716
17171717 // The four convOutputs are adjusted to add an extra dimension at the
17181718 // innermost level.
1719- SmallVector<int64_t > outputShapePlusOneDim (convOutputShape);
1720- outputShapePlusOneDim.push_back (1 );
1721- auto onnxConstForReshapeAddOneDim =
1722- getONNXConstOpFromVector (rewriter, loc, outputShapePlusOneDim);
1723-
1724- auto reshapeOutputType =
1725- RankedTensorType::get (outputShapePlusOneDim, elementType);
1726-
1727- auto reshapeOutputAddOneDimConv1 = rewriter.create <ONNXReshapeOp>(
1728- loc, reshapeOutputType, conv1, onnxConstForReshapeAddOneDim);
1729- auto reshapeOutputAddOneDimConv2 = rewriter.create <ONNXReshapeOp>(
1730- loc, reshapeOutputType, conv2, onnxConstForReshapeAddOneDim);
1731- auto reshapeOutputAddOneDimConv3 = rewriter.create <ONNXReshapeOp>(
1732- loc, reshapeOutputType, conv3, onnxConstForReshapeAddOneDim);
1733- auto reshapeOutputAddOneDimConv4 = rewriter.create <ONNXReshapeOp>(
1734- loc, reshapeOutputType, conv4, onnxConstForReshapeAddOneDim);
1735-
1736- SmallVector<int64_t > outputShapeLevel1Concat (outputShapePlusOneDim);
1737- outputShapeLevel1Concat[outputShapeLevel1Concat.size () - 1 ] = 2 ;
1738- auto level1ConcatOutputType =
1739- RankedTensorType::get (outputShapeLevel1Concat, elementType);
1719+ SmallVector<int64_t > outputShapeOfConcat = {
1720+ 1 , convOutputShape[1 ] * 4 , convOutputShape[2 ], convOutputShape[3 ]};
1721+ auto concatOutputType =
1722+ RankedTensorType::get (outputShapeOfConcat, elementType);
17401723 // for the case where convtranspose kernel is [4, 4] and with pads [1, 1, 1,
17411724 // 1] The phased convs output are to be concatenated in the reverse order.
17421725 // This is observed by looking at the phased conv outputs with respect to
@@ -1745,70 +1728,43 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
17451728 // Below concats result will have the innermost dim as 2.
17461729 auto firstConcat =
17471730 (reverseConcatOrder)
1748- ? rewriter.create <ONNXConcatOp>(loc, level1ConcatOutputType,
1749- ValueRange{
1750- reshapeOutputAddOneDimConv3, reshapeOutputAddOneDimConv1},
1751- -1 )
1752- : rewriter.create <ONNXConcatOp>(loc, level1ConcatOutputType,
1753- ValueRange{
1754- reshapeOutputAddOneDimConv1, reshapeOutputAddOneDimConv3},
1755- -1 );
1756- auto secondConcat =
1757- (reverseConcatOrder)
1758- ? rewriter.create <ONNXConcatOp>(loc, level1ConcatOutputType,
1759- ValueRange{
1760- reshapeOutputAddOneDimConv2, reshapeOutputAddOneDimConv4},
1761- -1 )
1762- : rewriter.create <ONNXConcatOp>(loc, level1ConcatOutputType,
1763- ValueRange{
1764- reshapeOutputAddOneDimConv4, reshapeOutputAddOneDimConv2},
1765- -1 );
1731+ ? rewriter.create <ONNXConcatOp>(loc, concatOutputType,
1732+ ValueRange{conv2, conv4, conv3, conv1}, 1 )
1733+ : rewriter.create <ONNXConcatOp>(loc, concatOutputType,
1734+ ValueRange{conv1, conv3, conv4, conv2}, 1 );
17661735
1767- // Reshaping to modify the two innermost levels,ensuring the second
1768- // innermost level is set to 1
1769- SmallVector<int64_t > outputShapeForDimAdjust (convOutputShape);
1770- auto dimValueAtLastIndex = convOutputShape[convOutputShape.size () - 1 ] * 2 ;
1771- outputShapeForDimAdjust[outputShapeForDimAdjust.size () - 1 ] = 1 ;
1772- outputShapeForDimAdjust.push_back (dimValueAtLastIndex);
1736+ // Here we are reshaping the concatenated conv channels of 4*Conv_channels
1737+ // into groups of 2x2 channels. This can be visualized as
1738+ // H_chan(2) * W_Chan(2) * C_real, then doing the transpose into
1739+ // Conv_channels H H_chan W W_chan. The adjacent H and H_chan will be merged
1740+ // into H; in the same way, W and W_chan will be merged into W. This doubles
1741+ // H and W while keeping the number of channels the same.
1742+
1743+ SmallVector<int64_t > outputShapeForDimAdjust = {
1744+ 2 , 2 , convOutputShape[1 ], convOutputShape[2 ], convOutputShape[3 ]};
17731745
17741746 auto onnxConstForReshapeDimAdjust =
17751747 getONNXConstOpFromVector (rewriter, loc, outputShapeForDimAdjust);
17761748
17771749 auto reshapeOutputForDimAdjustType =
17781750 RankedTensorType::get (outputShapeForDimAdjust, elementType);
1779- auto reshapeOutputDimAdjustOfFirstConcat =
1751+ auto reshapeOutputDimAdjust =
17801752 rewriter.create <ONNXReshapeOp>(loc, reshapeOutputForDimAdjustType,
17811753 firstConcat, onnxConstForReshapeDimAdjust);
1782- auto reshapeOutputDimAdjustOfSecondConcat =
1783- rewriter.create <ONNXReshapeOp>(loc, reshapeOutputForDimAdjustType,
1784- secondConcat, onnxConstForReshapeDimAdjust);
17851754
1786- SmallVector<int64_t > outputShapeForFinalConcat (outputShapeForDimAdjust);
1787- outputShapeForFinalConcat[outputShapeForFinalConcat. size () - 2 ] = 2 ;
1755+ SmallVector<int64_t > transposeOuputShape = {
1756+ convOutputShape[ 1 ], convOutputShape[ 2 ], 2 , convOutputShape[ 3 ], 2 } ;
17881757
1789- auto finalConcatOutputType =
1790- RankedTensorType::get (outputShapeForFinalConcat , elementType);
1758+ auto transposeOutputType =
1759+ RankedTensorType::get (transposeOuputShape , elementType);
17911760
1792- // Final Concat is performed on the two reshaped outputs at the
1793- // second innermost level
1794- auto finalConcat =
1795- (reverseConcatOrder)
1796- ? rewriter.create <ONNXConcatOp>(loc, finalConcatOutputType,
1797- ValueRange{reshapeOutputDimAdjustOfSecondConcat,
1798- reshapeOutputDimAdjustOfFirstConcat},
1799- -2 )
1800- : rewriter.create <ONNXConcatOp>(loc, finalConcatOutputType,
1801- ValueRange{reshapeOutputDimAdjustOfFirstConcat,
1802- reshapeOutputDimAdjustOfSecondConcat},
1803- -2 );
1761+ auto permArrayAttr = rewriter.getI64ArrayAttr ({2 , 3 , 0 , 4 , 1 });
18041762
1805- SmallVector<int64_t > outputShapeForResult (convOutputShape);
1806- dimValueAtLastIndex = convOutputShape[convOutputShape.size () - 1 ] * 2 ;
1807- auto dimValueAtSecondLastIndex =
1808- convOutputShape[convOutputShape.size () - 2 ] * 2 ;
1809- outputShapeForResult[outputShapeForResult.size () - 2 ] =
1810- dimValueAtSecondLastIndex;
1811- outputShapeForResult[outputShapeForResult.size () - 1 ] = dimValueAtLastIndex;
1763+ auto transpose = rewriter.create <ONNXTransposeOp>(
1764+ loc, transposeOutputType, reshapeOutputDimAdjust, permArrayAttr);
1765+
1766+ SmallVector<int64_t > outputShapeForResult = {
1767+ 1 , convOutputShape[1 ], convOutputShape[2 ] * 2 , convOutputShape[3 ] * 2 };
18121768
18131769 auto onnxConstForLastReshape =
18141770 getONNXConstOpFromVector (rewriter, loc, outputShapeForResult);
@@ -1818,7 +1774,7 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
18181774 // Result is reshaped back to match the original convtranspose output
18191775 // dimensions
18201776 auto finalOutput = rewriter.create <ONNXReshapeOp>(
1821- loc, finalOutputType, finalConcat , onnxConstForLastReshape);
1777+ loc, finalOutputType, transpose , onnxConstForLastReshape);
18221778 return finalOutput;
18231779 }
18241780 if (numPhases == 9 ) {
0 commit comments