@@ -851,6 +851,10 @@ bool ShouldDecomposeConvTransposeOpToPhasedConvs(Value convTransposeResult,
851851 bool fourPhaseDecomposition = (stridesShape[0 ] == 2 );
852852 bool ninePhaseDecomposition = (stridesShape[0 ] == 3 );
853853 if (fourPhaseDecomposition) {
854+ if (outputShape[0 ] != 1 ) {
855+ // Currently only a batch size of 1 is supported
856+ return false ;
857+ }
854858 if (kernelShape[0 ] == 6 && padsShape[0 ] == 2 &&
855859 llvm::all_equal (padsShape)) {
856860 // Currently supported only with pads [2, 2, 2, 2]
@@ -1713,102 +1717,56 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
17131717 startOnnxConstant, endOnnxConstant, axisOnnxConstant,
17141718 stepOnnxConstant);
17151719 }
1716-
1717- // The four convOutputs are adjusted to add an extra dimension at the
1718- // innermost level.
1719- SmallVector<int64_t > outputShapePlusOneDim (convOutputShape);
1720- outputShapePlusOneDim.push_back (1 );
1721- auto onnxConstForReshapeAddOneDim =
1722- getONNXConstOpFromVector (rewriter, loc, outputShapePlusOneDim);
1723-
1724- auto reshapeOutputType =
1725- RankedTensorType::get (outputShapePlusOneDim, elementType);
1726-
1727- auto reshapeOutputAddOneDimConv1 = rewriter.create <ONNXReshapeOp>(
1728- loc, reshapeOutputType, conv1, onnxConstForReshapeAddOneDim);
1729- auto reshapeOutputAddOneDimConv2 = rewriter.create <ONNXReshapeOp>(
1730- loc, reshapeOutputType, conv2, onnxConstForReshapeAddOneDim);
1731- auto reshapeOutputAddOneDimConv3 = rewriter.create <ONNXReshapeOp>(
1732- loc, reshapeOutputType, conv3, onnxConstForReshapeAddOneDim);
1733- auto reshapeOutputAddOneDimConv4 = rewriter.create <ONNXReshapeOp>(
1734- loc, reshapeOutputType, conv4, onnxConstForReshapeAddOneDim);
1735-
1736- SmallVector<int64_t > outputShapeLevel1Concat (outputShapePlusOneDim);
1737- outputShapeLevel1Concat[outputShapeLevel1Concat.size () - 1 ] = 2 ;
1738- auto level1ConcatOutputType =
1739- RankedTensorType::get (outputShapeLevel1Concat, elementType);
1720+ // Four conv outputs are merged in channel dim
1721+ SmallVector<int64_t > outputShapeOfConcat = {
1722+ 1 , convOutputShape[1 ] * 4 , convOutputShape[2 ], convOutputShape[3 ]};
1723+ auto concatOutputType =
1724+ RankedTensorType::get (outputShapeOfConcat, elementType);
17401725 // For the case where the convtranspose kernel is [4, 4] with pads [1, 1, 1,
17411726 // 1], the phased conv outputs are to be concatenated in the reverse order.
17421727 // This was observed by comparing the phased conv outputs against the
17431728 // convtranspose output.
17441729 bool reverseConcatOrder = (needWeightsPadding || (kernelShape[0 ] == 4 ));
1745- // Below concats result will have the innermost dim as 2 .
1730+ // The concat output will have 4 times the channels of a single conv .
17461731 auto firstConcat =
17471732 (reverseConcatOrder)
1748- ? rewriter.create <ONNXConcatOp>(loc, level1ConcatOutputType,
1749- ValueRange{
1750- reshapeOutputAddOneDimConv3, reshapeOutputAddOneDimConv1},
1751- -1 )
1752- : rewriter.create <ONNXConcatOp>(loc, level1ConcatOutputType,
1753- ValueRange{
1754- reshapeOutputAddOneDimConv1, reshapeOutputAddOneDimConv3},
1755- -1 );
1756- auto secondConcat =
1757- (reverseConcatOrder)
1758- ? rewriter.create <ONNXConcatOp>(loc, level1ConcatOutputType,
1759- ValueRange{
1760- reshapeOutputAddOneDimConv2, reshapeOutputAddOneDimConv4},
1761- -1 )
1762- : rewriter.create <ONNXConcatOp>(loc, level1ConcatOutputType,
1763- ValueRange{
1764- reshapeOutputAddOneDimConv4, reshapeOutputAddOneDimConv2},
1765- -1 );
1733+ ? rewriter.create <ONNXConcatOp>(loc, concatOutputType,
1734+ ValueRange{conv2, conv4, conv3, conv1}, 1 )
1735+ : rewriter.create <ONNXConcatOp>(loc, concatOutputType,
1736+ ValueRange{conv1, conv3, conv4, conv2}, 1 );
17661737
1767- // Reshaping to modify the two innermost levels,ensuring the second
1768- // innermost level is set to 1
1769- SmallVector<int64_t > outputShapeForDimAdjust (convOutputShape);
1770- auto dimValueAtLastIndex = convOutputShape[convOutputShape.size () - 1 ] * 2 ;
1771- outputShapeForDimAdjust[outputShapeForDimAdjust.size () - 1 ] = 1 ;
1772- outputShapeForDimAdjust.push_back (dimValueAtLastIndex);
1738+ // Reshape the concatenated output, which has 4 * Conv_channels channels,
1739+ // into 2x2 groups of channels. This can be visualized as
1740+ // H_chan(2) * W_chan(2) * C_real, which is then transposed into
1741+ // Conv_channels H H_chan W W_chan. The adjacent H and H_chan are merged
1742+ // into H, and likewise W and W_chan are merged into W. This doubles
1743+ // H and W while keeping the number of channels the same.
1744+
1745+ SmallVector<int64_t > outputShapeForDimAdjust = {
1746+ 2 , 2 , convOutputShape[1 ], convOutputShape[2 ], convOutputShape[3 ]};
17731747
17741748 auto onnxConstForReshapeDimAdjust =
17751749 getONNXConstOpFromVector (rewriter, loc, outputShapeForDimAdjust);
17761750
17771751 auto reshapeOutputForDimAdjustType =
17781752 RankedTensorType::get (outputShapeForDimAdjust, elementType);
1779- auto reshapeOutputDimAdjustOfFirstConcat =
1753+ auto reshapeOutputDimAdjust =
17801754 rewriter.create <ONNXReshapeOp>(loc, reshapeOutputForDimAdjustType,
17811755 firstConcat, onnxConstForReshapeDimAdjust);
1782- auto reshapeOutputDimAdjustOfSecondConcat =
1783- rewriter.create <ONNXReshapeOp>(loc, reshapeOutputForDimAdjustType,
1784- secondConcat, onnxConstForReshapeDimAdjust);
17851756
1786- SmallVector<int64_t > outputShapeForFinalConcat (outputShapeForDimAdjust);
1787- outputShapeForFinalConcat[outputShapeForFinalConcat. size () - 2 ] = 2 ;
1757+ SmallVector<int64_t > transposeOuputShape = {
1758+ convOutputShape[ 1 ], convOutputShape[ 2 ], 2 , convOutputShape[ 3 ], 2 } ;
17881759
1789- auto finalConcatOutputType =
1790- RankedTensorType::get (outputShapeForFinalConcat , elementType);
1760+ auto transposeOutputType =
1761+ RankedTensorType::get (transposeOuputShape , elementType);
17911762
1792- // Final Concat is performed on the two reshaped outputs at the
1793- // second innermost level
1794- auto finalConcat =
1795- (reverseConcatOrder)
1796- ? rewriter.create <ONNXConcatOp>(loc, finalConcatOutputType,
1797- ValueRange{reshapeOutputDimAdjustOfSecondConcat,
1798- reshapeOutputDimAdjustOfFirstConcat},
1799- -2 )
1800- : rewriter.create <ONNXConcatOp>(loc, finalConcatOutputType,
1801- ValueRange{reshapeOutputDimAdjustOfFirstConcat,
1802- reshapeOutputDimAdjustOfSecondConcat},
1803- -2 );
1763+ auto permArrayAttr = rewriter.getI64ArrayAttr ({2 , 3 , 0 , 4 , 1 });
18041764
1805- SmallVector<int64_t > outputShapeForResult (convOutputShape);
1806- dimValueAtLastIndex = convOutputShape[convOutputShape.size () - 1 ] * 2 ;
1807- auto dimValueAtSecondLastIndex =
1808- convOutputShape[convOutputShape.size () - 2 ] * 2 ;
1809- outputShapeForResult[outputShapeForResult.size () - 2 ] =
1810- dimValueAtSecondLastIndex;
1811- outputShapeForResult[outputShapeForResult.size () - 1 ] = dimValueAtLastIndex;
1765+ auto transpose = rewriter.create <ONNXTransposeOp>(
1766+ loc, transposeOutputType, reshapeOutputDimAdjust, permArrayAttr);
1767+
1768+ SmallVector<int64_t > outputShapeForResult = {
1769+ 1 , convOutputShape[1 ], convOutputShape[2 ] * 2 , convOutputShape[3 ] * 2 };
18121770
18131771 auto onnxConstForLastReshape =
18141772 getONNXConstOpFromVector (rewriter, loc, outputShapeForResult);
@@ -1818,7 +1776,7 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
18181776 // Result is reshaped back to match the original convtranspose output
18191777 // dimensions
18201778 auto finalOutput = rewriter.create <ONNXReshapeOp>(
1821- loc, finalOutputType, finalConcat , onnxConstForLastReshape);
1779+ loc, finalOutputType, transpose , onnxConstForLastReshape);
18221780 return finalOutput;
18231781 }
18241782 if (numPhases == 9 ) {
0 commit comments