@@ -1648,7 +1648,7 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
16481648 };
16491649 auto stridesArrayAttr = rewriter.getI64ArrayAttr ({1 , 1 });
16501650 Value conv;
1651- if (needWeightsPadding) {
1651+ if (needWeightsPadding || (kernelShape[ 0 ] == 4 ) ) {
16521652 Value conv1 = getActivationAppliedToConv (
16531653 addQDQNodesForActivationIfNeeded (rewriter.create <ONNXConvOp>(loc,
16541654 convOutputType, input, addDequantizeNodeIfNeeded (weightSlices[3 ]),
@@ -1683,42 +1683,44 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
16831683 convOutputType);
16841684 // Need to remove excess the ofm when weights are padded.
16851685
1686- auto startOnnxConstant = getONNXConstOpFromVector (rewriter, loc, {1 , 1 });
1687- auto endOnnxConstant = getONNXConstOpFromVector (rewriter, loc,
1688- {convOutputShape[convOutputShape.size () - 2 ] + 2 ,
1689- convOutputShape[convOutputShape.size () - 1 ] + 2 });
1690- auto axisOnnxConstant = getONNXConstOpFromVector (rewriter, loc, {2 , 3 });
1691- auto stepOnnxConstant = getONNXConstOpFromVector (rewriter, loc, {1 , 1 });
1692- auto convSliceOutputType = RankedTensorType::get (
1693- convOutputShape, convTransposeOutputType.getElementType ());
1694- conv1 = rewriter.create <ONNXSliceOp>(loc, convSliceOutputType, conv1,
1695- startOnnxConstant, endOnnxConstant, axisOnnxConstant,
1696- stepOnnxConstant);
1697-
1698- startOnnxConstant = getONNXConstOpFromVector (rewriter, loc, {0 , 0 });
1699- endOnnxConstant = getONNXConstOpFromVector (rewriter, loc,
1700- {convOutputShape[convOutputShape.size () - 2 ],
1701- convOutputShape[convOutputShape.size () - 1 ]});
1702- conv2 = rewriter.create <ONNXSliceOp>(loc, convSliceOutputType, conv2,
1703- startOnnxConstant, endOnnxConstant, axisOnnxConstant,
1704- stepOnnxConstant);
1705-
1706- startOnnxConstant = getONNXConstOpFromVector (rewriter, loc, {1 , 0 });
1707- endOnnxConstant = getONNXConstOpFromVector (rewriter, loc,
1708- {convOutputShape[convOutputShape.size () - 2 ] + 2 ,
1709- convOutputShape[convOutputShape.size () - 1 ]});
1710- conv3 = rewriter.create <ONNXSliceOp>(loc, convSliceOutputType, conv3,
1711- startOnnxConstant, endOnnxConstant, axisOnnxConstant,
1712- stepOnnxConstant);
1713-
1714- startOnnxConstant = getONNXConstOpFromVector (rewriter, loc, {0 , 1 });
1715- endOnnxConstant = getONNXConstOpFromVector (rewriter, loc,
1716- {convOutputShape[convOutputShape.size () - 2 ],
1717- convOutputShape[convOutputShape.size () - 1 ] + 2 });
1718- conv4 = rewriter.create <ONNXSliceOp>(loc, convSliceOutputType, conv4,
1719- startOnnxConstant, endOnnxConstant, axisOnnxConstant,
1720- stepOnnxConstant);
1721-
1686+ if (needWeightsPadding) {
1687+ auto startOnnxConstant =
1688+ getONNXConstOpFromVector (rewriter, loc, {1 , 1 });
1689+ auto endOnnxConstant = getONNXConstOpFromVector (rewriter, loc,
1690+ {convOutputShape[convOutputShape.size () - 2 ] + 2 ,
1691+ convOutputShape[convOutputShape.size () - 1 ] + 2 });
1692+ auto axisOnnxConstant = getONNXConstOpFromVector (rewriter, loc, {2 , 3 });
1693+ auto stepOnnxConstant = getONNXConstOpFromVector (rewriter, loc, {1 , 1 });
1694+ auto convSliceOutputType = RankedTensorType::get (
1695+ convOutputShape, convTransposeOutputType.getElementType ());
1696+ conv1 = rewriter.create <ONNXSliceOp>(loc, convSliceOutputType, conv1,
1697+ startOnnxConstant, endOnnxConstant, axisOnnxConstant,
1698+ stepOnnxConstant);
1699+
1700+ startOnnxConstant = getONNXConstOpFromVector (rewriter, loc, {0 , 0 });
1701+ endOnnxConstant = getONNXConstOpFromVector (rewriter, loc,
1702+ {convOutputShape[convOutputShape.size () - 2 ],
1703+ convOutputShape[convOutputShape.size () - 1 ]});
1704+ conv2 = rewriter.create <ONNXSliceOp>(loc, convSliceOutputType, conv2,
1705+ startOnnxConstant, endOnnxConstant, axisOnnxConstant,
1706+ stepOnnxConstant);
1707+
1708+ startOnnxConstant = getONNXConstOpFromVector (rewriter, loc, {1 , 0 });
1709+ endOnnxConstant = getONNXConstOpFromVector (rewriter, loc,
1710+ {convOutputShape[convOutputShape.size () - 2 ] + 2 ,
1711+ convOutputShape[convOutputShape.size () - 1 ]});
1712+ conv3 = rewriter.create <ONNXSliceOp>(loc, convSliceOutputType, conv3,
1713+ startOnnxConstant, endOnnxConstant, axisOnnxConstant,
1714+ stepOnnxConstant);
1715+
1716+ startOnnxConstant = getONNXConstOpFromVector (rewriter, loc, {0 , 1 });
1717+ endOnnxConstant = getONNXConstOpFromVector (rewriter, loc,
1718+ {convOutputShape[convOutputShape.size () - 2 ],
1719+ convOutputShape[convOutputShape.size () - 1 ] + 2 });
1720+ conv4 = rewriter.create <ONNXSliceOp>(loc, convSliceOutputType, conv4,
1721+ startOnnxConstant, endOnnxConstant, axisOnnxConstant,
1722+ stepOnnxConstant);
1723+ }
17221724 // Four conv outputs are merged in channel dim
17231725 SmallVector<int64_t > outputShapeOfConcat = {
17241726 1 , convOutputShape[1 ] * 4 , convOutputShape[2 ], convOutputShape[3 ]};
0 commit comments