@@ -1648,7 +1648,7 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
   };
   auto stridesArrayAttr = rewriter.getI64ArrayAttr({1, 1});
   Value conv;
-  if (needWeightsPadding) {
+  if (needWeightsPadding || (kernelShape[0] == 4)) {
     Value conv1 = getActivationAppliedToConv(
         addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
             convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[3]),
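Background for the guard change in the hunk above (an editorial sketch, not part of the patch): a stride-2 ConvTranspose can be decomposed into four ordinary convolutions by splitting its kernel into four phase sub-kernels, one per output phase, which is presumably what weightSlices[0..3] hold. The snippet below only illustrates that splitting for a 4x4 kernel under an assumed (dh, dw) phase indexing; the actual slice order and pad handling used by the rewriter are not reproduced here.

// Hypothetical illustration only: split a 4x4 kernel into four 2x2 phase
// sub-kernels. The phase-to-weightSlices mapping in the pass may differ.
#include <cstdio>

int main() {
  float k[4][4];
  for (int r = 0; r < 4; ++r)
    for (int c = 0; c < 4; ++c)
      k[r][c] = static_cast<float>(r * 4 + c);

  // Phase (dh, dw) takes every second row/column starting at (dh, dw),
  // yielding one 2x2 sub-kernel per phase.
  for (int dh = 0; dh < 2; ++dh)
    for (int dw = 0; dw < 2; ++dw) {
      std::printf("phase (%d, %d):\n", dh, dw);
      for (int r = dh; r < 4; r += 2) {
        for (int c = dw; c < 4; c += 2)
          std::printf("%5.1f ", k[r][c]);
        std::printf("\n");
      }
    }
  return 0;
}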
@@ -1683,91 +1683,44 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
         convOutputType);
     // Need to remove the excess ofm when weights are padded.

-    auto startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 1});
-    auto endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
-        {convOutputShape[convOutputShape.size() - 2] + 2,
-            convOutputShape[convOutputShape.size() - 1] + 2});
-    auto axisOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {2, 3});
-    auto stepOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 1});
-    auto convSliceOutputType = RankedTensorType::get(
-        convOutputShape, convTransposeOutputType.getElementType());
-    conv1 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv1,
-        startOnnxConstant, endOnnxConstant, axisOnnxConstant,
-        stepOnnxConstant);
-
-    startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {0, 0});
-    endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
-        {convOutputShape[convOutputShape.size() - 2],
-            convOutputShape[convOutputShape.size() - 1]});
-    conv2 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv2,
-        startOnnxConstant, endOnnxConstant, axisOnnxConstant,
-        stepOnnxConstant);
-
-    startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 0});
-    endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
-        {convOutputShape[convOutputShape.size() - 2] + 2,
-            convOutputShape[convOutputShape.size() - 1]});
-    conv3 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv3,
-        startOnnxConstant, endOnnxConstant, axisOnnxConstant,
-        stepOnnxConstant);
-
-    startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {0, 1});
-    endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
-        {convOutputShape[convOutputShape.size() - 2],
-            convOutputShape[convOutputShape.size() - 1] + 2});
-    conv4 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv4,
-        startOnnxConstant, endOnnxConstant, axisOnnxConstant,
-        stepOnnxConstant);
-
-    // Four conv outputs are merged in channel dim
-    SmallVector<int64_t> outputShapeOfConcat = {
-        1, convOutputShape[1] * 4, convOutputShape[2], convOutputShape[3]};
-    auto concatOutputType =
-        RankedTensorType::get(outputShapeOfConcat, elementType);
-    // For the case where the convtranspose kernel is [4, 4] with pads
-    // [1, 1, 1, 1], the phased conv outputs are concatenated in reverse
-    // order. This is observed by comparing the phased conv outputs with
-    // the convtranspose output.
-    bool reverseConcatOrder = (needWeightsPadding || (kernelShape[0] == 4));
-    // The concat output will have 4 times the channels of a single conv.
-    conv = (reverseConcatOrder)
-               ? rewriter.create<ONNXConcatOp>(loc, concatOutputType,
-                     ValueRange{conv2, conv4, conv3, conv1}, 1)
-               : rewriter.create<ONNXConcatOp>(loc, concatOutputType,
-                     ValueRange{conv1, conv3, conv4, conv2}, 1);
-  } else if (kernelShape[0] == 4) {
-    Value conv1 = getActivationAppliedToConv(
-        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
-            convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[3]),
-            bias, mlir::StringAttr(), dilations, group,
-            convKernelShapeArrayAttr,
-            getPadsArrayAttr(kernelShape[0], 1, needWeightsPadding),
-            stridesArrayAttr)),
-        convOutputType);
-    Value conv2 = getActivationAppliedToConv(
-        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
-            convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[0]),
-            bias, mlir::StringAttr(), dilations, group,
-            convKernelShapeArrayAttr,
-            getPadsArrayAttr(kernelShape[0], 2, needWeightsPadding),
-            stridesArrayAttr)),
-        convOutputType);
-    Value conv3 = getActivationAppliedToConv(
-        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
-            convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[1]),
-            bias, mlir::StringAttr(), dilations, group,
-            convKernelShapeArrayAttr,
-            getPadsArrayAttr(kernelShape[0], 3, needWeightsPadding),
-            stridesArrayAttr)),
-        convOutputType);
-    Value conv4 = getActivationAppliedToConv(
-        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
-            convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[2]),
-            bias, mlir::StringAttr(), dilations, group,
-            convKernelShapeArrayAttr,
-            getPadsArrayAttr(kernelShape[0], 4, needWeightsPadding),
-            stridesArrayAttr)),
-        convOutputType);
+    if (needWeightsPadding) {
+      auto startOnnxConstant =
+          getONNXConstOpFromVector(rewriter, loc, {1, 1});
+      auto endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
+          {convOutputShape[convOutputShape.size() - 2] + 2,
+              convOutputShape[convOutputShape.size() - 1] + 2});
+      auto axisOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {2, 3});
+      auto stepOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 1});
+      auto convSliceOutputType = RankedTensorType::get(
+          convOutputShape, convTransposeOutputType.getElementType());
+      conv1 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv1,
+          startOnnxConstant, endOnnxConstant, axisOnnxConstant,
+          stepOnnxConstant);
+
+      startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {0, 0});
+      endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
+          {convOutputShape[convOutputShape.size() - 2],
+              convOutputShape[convOutputShape.size() - 1]});
+      conv2 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv2,
+          startOnnxConstant, endOnnxConstant, axisOnnxConstant,
+          stepOnnxConstant);
+
+      startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 0});
+      endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
+          {convOutputShape[convOutputShape.size() - 2] + 2,
+              convOutputShape[convOutputShape.size() - 1]});
+      conv3 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv3,
+          startOnnxConstant, endOnnxConstant, axisOnnxConstant,
+          stepOnnxConstant);
+
+      startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {0, 1});
+      endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
+          {convOutputShape[convOutputShape.size() - 2],
+              convOutputShape[convOutputShape.size() - 1] + 2});
+      conv4 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv4,
+          startOnnxConstant, endOnnxConstant, axisOnnxConstant,
+          stepOnnxConstant);
+    }
     // Four conv outputs are merged in channel dim
     SmallVector<int64_t> outputShapeOfConcat = {
         1, convOutputShape[1] * 4, convOutputShape[2], convOutputShape[3]};
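A note on why the concat order matters (an illustrative sketch, not code from this pass): the four phased conv outputs, stacked along the channel dimension above, are assumed to be rearranged spatially later (a depth-to-space style interleave that is outside this hunk), so the position of each phase in the ValueRange determines where its pixels land in the ConvTranspose result. The phase-to-offset mapping below is hypothetical and only demonstrates the mechanism.

// Minimal sketch, assuming a 2x2 interleave of four phase planes; values and
// the phase ordering are made up for illustration.
#include <cstdio>
#include <vector>

int main() {
  const int H = 2, W = 2; // spatial size of one phased-conv output
  // Four phase planes standing in for conv1..conv4.
  std::vector<std::vector<float>> phase(4, std::vector<float>(H * W));
  for (int p = 0; p < 4; ++p)
    for (int i = 0; i < H * W; ++i)
      phase[p][i] = 10.0f * p + i;

  // Interleave: out(2h + dh, 2w + dw) = phase[dh * 2 + dw](h, w).
  std::vector<float> out(2 * H * 2 * W);
  for (int h = 0; h < H; ++h)
    for (int w = 0; w < W; ++w)
      for (int dh = 0; dh < 2; ++dh)
        for (int dw = 0; dw < 2; ++dw)
          out[(2 * h + dh) * (2 * W) + (2 * w + dw)] =
              phase[dh * 2 + dw][h * W + w];

  // Swapping any two phase planes changes the result, which is why the
  // concat order (reversed for the 4x4 / pads [1, 1, 1, 1] case) matters.
  for (int y = 0; y < 2 * H; ++y) {
    for (int x = 0; x < 2 * W; ++x)
      std::printf("%5.1f ", out[y * (2 * W) + x]);
    std::printf("\n");
  }
  return 0;
}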