@@ -1647,41 +1647,42 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
16471647 }
16481648 };
16491649 auto stridesArrayAttr = rewriter.getI64ArrayAttr ({1 , 1 });
1650-
1651- Value conv1 = getActivationAppliedToConv (
1652- addQDQNodesForActivationIfNeeded (
1653- rewriter.create <ONNXConvOp>(loc, convOutputType, input,
1654- addDequantizeNodeIfNeeded (weightSlices[3 ]), bias,
1655- mlir::StringAttr (), dilations, group, convKernelShapeArrayAttr,
1656- getPadsArrayAttr (kernelShape[0 ], 1 , needWeightsPadding),
1657- stridesArrayAttr)),
1658- convOutputType);
1659- Value conv2 = getActivationAppliedToConv (
1660- addQDQNodesForActivationIfNeeded (
1661- rewriter.create <ONNXConvOp>(loc, convOutputType, input,
1662- addDequantizeNodeIfNeeded (weightSlices[0 ]), bias,
1663- mlir::StringAttr (), dilations, group, convKernelShapeArrayAttr,
1664- getPadsArrayAttr (kernelShape[0 ], 2 , needWeightsPadding),
1665- stridesArrayAttr)),
1666- convOutputType);
1667- Value conv3 = getActivationAppliedToConv (
1668- addQDQNodesForActivationIfNeeded (
1669- rewriter.create <ONNXConvOp>(loc, convOutputType, input,
1670- addDequantizeNodeIfNeeded (weightSlices[1 ]), bias,
1671- mlir::StringAttr (), dilations, group, convKernelShapeArrayAttr,
1672- getPadsArrayAttr (kernelShape[0 ], 3 , needWeightsPadding),
1673- stridesArrayAttr)),
1674- convOutputType);
1675- Value conv4 = getActivationAppliedToConv (
1676- addQDQNodesForActivationIfNeeded (
1677- rewriter.create <ONNXConvOp>(loc, convOutputType, input,
1678- addDequantizeNodeIfNeeded (weightSlices[2 ]), bias,
1679- mlir::StringAttr (), dilations, group, convKernelShapeArrayAttr,
1680- getPadsArrayAttr (kernelShape[0 ], 4 , needWeightsPadding),
1681- stridesArrayAttr)),
1682- convOutputType);
1683- // Need to remove excess the ofm when weights are padded.
1650+ Value conv;
16841651 if (needWeightsPadding) {
1652+ Value conv1 = getActivationAppliedToConv (
1653+ addQDQNodesForActivationIfNeeded (rewriter.create <ONNXConvOp>(loc,
1654+ convOutputType, input, addDequantizeNodeIfNeeded (weightSlices[3 ]),
1655+ bias, mlir::StringAttr (), dilations, group,
1656+ convKernelShapeArrayAttr,
1657+ getPadsArrayAttr (kernelShape[0 ], 1 , needWeightsPadding),
1658+ stridesArrayAttr)),
1659+ convOutputType);
1660+ Value conv2 = getActivationAppliedToConv (
1661+ addQDQNodesForActivationIfNeeded (rewriter.create <ONNXConvOp>(loc,
1662+ convOutputType, input, addDequantizeNodeIfNeeded (weightSlices[0 ]),
1663+ bias, mlir::StringAttr (), dilations, group,
1664+ convKernelShapeArrayAttr,
1665+ getPadsArrayAttr (kernelShape[0 ], 2 , needWeightsPadding),
1666+ stridesArrayAttr)),
1667+ convOutputType);
1668+ Value conv3 = getActivationAppliedToConv (
1669+ addQDQNodesForActivationIfNeeded (rewriter.create <ONNXConvOp>(loc,
1670+ convOutputType, input, addDequantizeNodeIfNeeded (weightSlices[1 ]),
1671+ bias, mlir::StringAttr (), dilations, group,
1672+ convKernelShapeArrayAttr,
1673+ getPadsArrayAttr (kernelShape[0 ], 3 , needWeightsPadding),
1674+ stridesArrayAttr)),
1675+ convOutputType);
1676+ Value conv4 = getActivationAppliedToConv (
1677+ addQDQNodesForActivationIfNeeded (rewriter.create <ONNXConvOp>(loc,
1678+ convOutputType, input, addDequantizeNodeIfNeeded (weightSlices[2 ]),
1679+ bias, mlir::StringAttr (), dilations, group,
1680+ convKernelShapeArrayAttr,
1681+ getPadsArrayAttr (kernelShape[0 ], 4 , needWeightsPadding),
1682+ stridesArrayAttr)),
1683+ convOutputType);
1685+      // Need to remove the excess OFM when weights are padded.
1685+
16851686 auto startOnnxConstant = getONNXConstOpFromVector (rewriter, loc, {1 , 1 });
16861687 auto endOnnxConstant = getONNXConstOpFromVector (rewriter, loc,
16871688 {convOutputShape[convOutputShape.size () - 2 ] + 2 ,
@@ -1717,24 +1718,68 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
17171718 conv4 = rewriter.create <ONNXSliceOp>(loc, convSliceOutputType, conv4,
17181719 startOnnxConstant, endOnnxConstant, axisOnnxConstant,
17191720 stepOnnxConstant);
1721+
1722+ // Four conv outputs are merged in channel dim
1723+ SmallVector<int64_t > outputShapeOfConcat = {
1724+ 1 , convOutputShape[1 ] * 4 , convOutputShape[2 ], convOutputShape[3 ]};
1725+ auto concatOutputType =
1726+ RankedTensorType::get (outputShapeOfConcat, elementType);
1728+      // For the case where the ConvTranspose kernel is [4, 4] with pads
1729+      // [1, 1, 1, 1], the phased conv outputs must be concatenated in
1730+      // reverse order. This is observed by comparing the phased conv
1731+      // outputs against the ConvTranspose output.
1731+ bool reverseConcatOrder = (needWeightsPadding || (kernelShape[0 ] == 4 ));
1732+ // The concat output will have 4 times the channels of a single conv.
1733+ conv = (reverseConcatOrder)
1734+ ? rewriter.create <ONNXConcatOp>(loc, concatOutputType,
1735+ ValueRange{conv2, conv4, conv3, conv1}, 1 )
1736+ : rewriter.create <ONNXConcatOp>(loc, concatOutputType,
1737+ ValueRange{conv1, conv3, conv4, conv2}, 1 );
1738+ } else {
1739+ // Combining the 4 phased weights into single weight.
1740+ bool reverseOrder = (kernelShape[0 ] == 4 );
1741+ auto combinedConvWeightsShapedType =
1742+ weightsType.get ({weightsShape[0 ] * 4 , weightsShape[1 ], convKernelSize,
1743+ convKernelSize},
1744+ weightsType.getElementType ());
1745+
1746+ Value combinedWeights =
1747+ (reverseOrder) ? rewriter.create <ONNXConcatOp>(loc,
1748+ combinedConvWeightsShapedType,
1749+ ValueRange{weightSlices[0 ], weightSlices[2 ],
1750+ weightSlices[1 ], weightSlices[3 ]},
1751+ 0 )
1752+ : rewriter.create <ONNXConcatOp>(loc,
1753+ combinedConvWeightsShapedType,
1754+ ValueRange{weightSlices[3 ], weightSlices[1 ],
1755+ weightSlices[2 ], weightSlices[0 ]},
1756+ 0 );
1757+
1758+ if (!bias.getDefiningOp <ONNXNoneOp>()) {
1759+ RankedTensorType biasType =
1760+ mlir::cast<RankedTensorType>(bias.getType ());
1761+ auto biasShape = biasType.getShape ();
1762+
1763+ auto combinedBiasShapedType =
1764+ biasType.get ({biasShape[0 ] * 4 }, biasType.getElementType ());
1765+
1766+ bias = rewriter.create <ONNXConcatOp>(
1767+ loc, combinedBiasShapedType, ValueRange{bias, bias, bias, bias}, 0 );
1768+ }
1769+
1770+ auto combinedConvOutputType = RankedTensorType::get (
1771+ SmallVector<int64_t >({convOutputShape[0 ], convOutputShape[1 ] * 4 ,
1772+ convOutputShape[2 ], convOutputShape[3 ]}),
1773+ convTransposeOutputType.getElementType ());
1774+ conv = getActivationAppliedToConv (
1775+ addQDQNodesForActivationIfNeeded (rewriter.create <ONNXConvOp>(loc,
1776+ combinedConvOutputType, input,
1777+ addDequantizeNodeIfNeeded (combinedWeights), bias,
1778+ mlir::StringAttr (), dilations, group, convKernelShapeArrayAttr,
1779+ getPadsArrayAttr (kernelShape[0 ], 1 , needWeightsPadding),
1780+ stridesArrayAttr)),
1781+ combinedConvOutputType);
17201782 }
1721- // Four conv outputs are merged in channel dim
1722- SmallVector<int64_t > outputShapeOfConcat = {
1723- 1 , convOutputShape[1 ] * 4 , convOutputShape[2 ], convOutputShape[3 ]};
1724- auto concatOutputType =
1725- RankedTensorType::get (outputShapeOfConcat, elementType);
1726- // for the case where convtranspose kernel is [4, 4] and with pads [1, 1, 1,
1727- // 1] The phased convs output are to be concatenated in the reverse order.
1728- // This is observed by looking at the phased conv outputs with respect to
1729- // convtranspose output.
1730- bool reverseConcatOrder = (needWeightsPadding || (kernelShape[0 ] == 4 ));
1731- // The concat output will have 4 times the channels of a single conv.
1732- auto firstConcat =
1733- (reverseConcatOrder)
1734- ? rewriter.create <ONNXConcatOp>(loc, concatOutputType,
1735- ValueRange{conv2, conv4, conv3, conv1}, 1 )
1736- : rewriter.create <ONNXConcatOp>(loc, concatOutputType,
1737- ValueRange{conv1, conv3, conv4, conv2}, 1 );
17381783
17391784 // Here we are reshaping the concatenated conv channels of 4*Conv_channels
17401785 // into groups of 2x2 channels. This can be visualized as
@@ -1751,9 +1796,8 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
17511796
17521797 auto reshapeOutputForDimAdjustType =
17531798 RankedTensorType::get (outputShapeForDimAdjust, elementType);
1754- auto reshapeOutputDimAdjust =
1755- rewriter.create <ONNXReshapeOp>(loc, reshapeOutputForDimAdjustType,
1756- firstConcat, onnxConstForReshapeDimAdjust);
1799+ auto reshapeOutputDimAdjust = rewriter.create <ONNXReshapeOp>(
1800+ loc, reshapeOutputForDimAdjustType, conv, onnxConstForReshapeDimAdjust);
17571801
17581802 SmallVector<int64_t > transposeOuputShape = {
17591803 convOutputShape[1 ], convOutputShape[2 ], 2 , convOutputShape[3 ], 2 };
0 commit comments