@@ -1647,41 +1647,43 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
     }
   };
   auto stridesArrayAttr = rewriter.getI64ArrayAttr({1, 1});
-
-  Value conv1 = getActivationAppliedToConv(
-      addQDQNodesForActivationIfNeeded(
-          rewriter.create<ONNXConvOp>(loc, convOutputType, input,
-              addDequantizeNodeIfNeeded(weightSlices[3]), bias,
-              mlir::StringAttr(), dilations, group, convKernelShapeArrayAttr,
-              getPadsArrayAttr(kernelShape[0], 1, needWeightsPadding),
-              stridesArrayAttr)),
-      convOutputType);
-  Value conv2 = getActivationAppliedToConv(
-      addQDQNodesForActivationIfNeeded(
-          rewriter.create<ONNXConvOp>(loc, convOutputType, input,
-              addDequantizeNodeIfNeeded(weightSlices[0]), bias,
-              mlir::StringAttr(), dilations, group, convKernelShapeArrayAttr,
-              getPadsArrayAttr(kernelShape[0], 2, needWeightsPadding),
-              stridesArrayAttr)),
-      convOutputType);
-  Value conv3 = getActivationAppliedToConv(
-      addQDQNodesForActivationIfNeeded(
-          rewriter.create<ONNXConvOp>(loc, convOutputType, input,
-              addDequantizeNodeIfNeeded(weightSlices[1]), bias,
-              mlir::StringAttr(), dilations, group, convKernelShapeArrayAttr,
-              getPadsArrayAttr(kernelShape[0], 3, needWeightsPadding),
-              stridesArrayAttr)),
-      convOutputType);
-  Value conv4 = getActivationAppliedToConv(
-      addQDQNodesForActivationIfNeeded(
-          rewriter.create<ONNXConvOp>(loc, convOutputType, input,
-              addDequantizeNodeIfNeeded(weightSlices[2]), bias,
-              mlir::StringAttr(), dilations, group, convKernelShapeArrayAttr,
-              getPadsArrayAttr(kernelShape[0], 4, needWeightsPadding),
-              stridesArrayAttr)),
-      convOutputType);
-  // Need to remove the excess ofm when weights are padded.
+  Value conv;
   if (needWeightsPadding) {
+    // With padded weights the four phased convs are created separately;
+    // their outputs are sliced and concatenated below.
+    Value conv1 = getActivationAppliedToConv(
+        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
+            convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[3]),
+            bias, mlir::StringAttr(), dilations, group,
+            convKernelShapeArrayAttr,
+            getPadsArrayAttr(kernelShape[0], 1, needWeightsPadding),
+            stridesArrayAttr)),
+        convOutputType);
+    Value conv2 = getActivationAppliedToConv(
+        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
+            convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[0]),
+            bias, mlir::StringAttr(), dilations, group,
+            convKernelShapeArrayAttr,
+            getPadsArrayAttr(kernelShape[0], 2, needWeightsPadding),
+            stridesArrayAttr)),
+        convOutputType);
+    Value conv3 = getActivationAppliedToConv(
+        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
+            convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[1]),
+            bias, mlir::StringAttr(), dilations, group,
+            convKernelShapeArrayAttr,
+            getPadsArrayAttr(kernelShape[0], 3, needWeightsPadding),
+            stridesArrayAttr)),
+        convOutputType);
+    Value conv4 = getActivationAppliedToConv(
+        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
+            convOutputType, input, addDequantizeNodeIfNeeded(weightSlices[2]),
+            bias, mlir::StringAttr(), dilations, group,
+            convKernelShapeArrayAttr,
+            getPadsArrayAttr(kernelShape[0], 4, needWeightsPadding),
+            stridesArrayAttr)),
+        convOutputType);
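+    // Note: the 1..4 passed to getPadsArrayAttr picks the per-phase padding,
+    // so conv1..conv4 each compute one sub-pixel phase of the convtranspose
+    // output.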
+    // Need to remove the excess ofm when weights are padded.
+
     auto startOnnxConstant = getONNXConstOpFromVector(rewriter, loc, {1, 1});
     auto endOnnxConstant = getONNXConstOpFromVector(rewriter, loc,
         {convOutputShape[convOutputShape.size() - 2] + 2,
@@ -1717,24 +1719,72 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
     conv4 = rewriter.create<ONNXSliceOp>(loc, convSliceOutputType, conv4,
         startOnnxConstant, endOnnxConstant, axisOnnxConstant,
         stepOnnxConstant);
+
+    // The four conv outputs are merged along the channel dim.
+    SmallVector<int64_t> outputShapeOfConcat = {
+        1, convOutputShape[1] * 4, convOutputShape[2], convOutputShape[3]};
+    auto concatOutputType =
+        RankedTensorType::get(outputShapeOfConcat, elementType);
+    // For the case where the convtranspose kernel is [4, 4] with pads
+    // [1, 1, 1, 1], the phased conv outputs have to be concatenated in
+    // reverse order. This is observed by comparing the phased conv outputs
+    // with the convtranspose output.
+    bool reverseConcatOrder = (needWeightsPadding || (kernelShape[0] == 4));
+    // The concat output has 4 times the channels of a single conv.
+    conv = (reverseConcatOrder)
+               ? rewriter.create<ONNXConcatOp>(loc, concatOutputType,
+                     ValueRange{conv2, conv4, conv3, conv1}, 1)
+               : rewriter.create<ONNXConcatOp>(loc, concatOutputType,
+                     ValueRange{conv1, conv3, conv4, conv2}, 1);
+  } else {
+    bool reverseOrder = (kernelShape[0] == 4);
+    auto combinedConvWeightsShapedType = weightsType.get(
+        {weightsShape[0] * 4, weightsShape[1], convKernelSize, convKernelSize},
+        weightsType.getElementType());
+
+    Value combinedWeights =
+        (reverseOrder) ? rewriter.create<ONNXConcatOp>(loc,
+                             combinedConvWeightsShapedType,
+                             ValueRange{weightSlices[0], weightSlices[2],
+                                 weightSlices[1], weightSlices[3]},
+                             0)
+                       : rewriter.create<ONNXConcatOp>(loc,
+                             combinedConvWeightsShapedType,
+                             ValueRange{weightSlices[3], weightSlices[1],
+                                 weightSlices[2], weightSlices[0]},
+                             0);
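+    // Concatenating the phased weights along the output-channel dim lets a
+    // single conv produce all four phase outputs at once, equivalent to the
+    // four separate convs plus channel concat used in the padded branch.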
+
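+    // The combined conv has 4x the output channels, so any bias must be
+    // tiled four times to match.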
+    if (!bias.getDefiningOp<ONNXNoneOp>()) {
+      RankedTensorType biasType = mlir::cast<RankedTensorType>(bias.getType());
+      auto biasShape = biasType.getShape();
+
+      auto combinedBiasShapedType =
+          biasType.get({biasShape[0] * 4}, biasType.getElementType());
+
+      bias = rewriter.create<ONNXConcatOp>(
+          loc, combinedBiasShapedType, ValueRange{bias, bias, bias, bias}, 0);
+    }
+
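+    // Output type of the combined conv: 4x the channels of a single phased
+    // conv; the spatial dims grow by one only in the padded-weights case.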
+    auto combinedConvOutputType = RankedTensorType::get(
+        (needWeightsPadding)
+            ? SmallVector<int64_t>({convOutputShape[0], convOutputShape[1] * 4,
+                  convOutputShape[2] + 1, convOutputShape[3] + 1})
+            : SmallVector<int64_t>({convOutputShape[0], convOutputShape[1] * 4,
+                  convOutputShape[2], convOutputShape[3]}),
+        convTransposeOutputType.getElementType());
+    conv = getActivationAppliedToConv(
+        addQDQNodesForActivationIfNeeded(rewriter.create<ONNXConvOp>(loc,
+            combinedConvOutputType, input,
+            addDequantizeNodeIfNeeded(combinedWeights), bias,
+            mlir::StringAttr(), dilations, group, convKernelShapeArrayAttr,
+            getPadsArrayAttr(kernelShape[0], 1, needWeightsPadding),
+            stridesArrayAttr)),
+        combinedConvOutputType);
   }
-  // The four conv outputs are merged along the channel dim.
-  SmallVector<int64_t> outputShapeOfConcat = {
-      1, convOutputShape[1] * 4, convOutputShape[2], convOutputShape[3]};
-  auto concatOutputType =
-      RankedTensorType::get(outputShapeOfConcat, elementType);
-  // For the case where the convtranspose kernel is [4, 4] with pads
-  // [1, 1, 1, 1], the phased conv outputs have to be concatenated in reverse
-  // order. This is observed by comparing the phased conv outputs with the
-  // convtranspose output.
-  bool reverseConcatOrder = (needWeightsPadding || (kernelShape[0] == 4));
-  // The concat output will have 4 times the channels of a single conv.
-  auto firstConcat =
-      (reverseConcatOrder)
-          ? rewriter.create<ONNXConcatOp>(loc, concatOutputType,
-                ValueRange{conv2, conv4, conv3, conv1}, 1)
-          : rewriter.create<ONNXConcatOp>(loc, concatOutputType,
-                ValueRange{conv1, conv3, conv4, conv2}, 1);
 
   // Here we reshape the concatenated 4*Conv_channels conv channels into
   // groups of 2x2 channels. This can be visualized as
@@ -1751,9 +1801,8 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
 
   auto reshapeOutputForDimAdjustType =
       RankedTensorType::get(outputShapeForDimAdjust, elementType);
-  auto reshapeOutputDimAdjust =
-      rewriter.create<ONNXReshapeOp>(loc, reshapeOutputForDimAdjustType,
-          firstConcat, onnxConstForReshapeDimAdjust);
+  auto reshapeOutputDimAdjust = rewriter.create<ONNXReshapeOp>(
+      loc, reshapeOutputForDimAdjustType, conv, onnxConstForReshapeDimAdjust);
 
   SmallVector<int64_t> transposeOuputShape = {
       convOutputShape[1], convOutputShape[2], 2, convOutputShape[3], 2};
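+  // Rough shape walk-through of the rearrangement, assuming the merged conv
+  // output is {1, 4*C, H, W} and outputShapeForDimAdjust splits the channels
+  // into the 2x2 phase grid:
+  //   reshape:   {1, 4*C, H, W} -> {C, 2, 2, H, W}
+  //   transpose: {C, 2, 2, H, W} -> {C, H, 2, W, 2}
+  // i.e. a depth-to-space (pixel shuffle) that doubles each spatial dim once
+  // the final reshape flattens the interleaved axes.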