
Commit edd9f5d

Merge remote-tracking branch 'origin/feature/onnx-to-tosa' into xiao.add_remove_qdq_aroundop

2 parents: 5a200dd + a2be298
File tree

6 files changed (+795, -957 lines changed)

src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 35 additions & 77 deletions
@@ -851,6 +851,10 @@ bool ShouldDecomposeConvTransposeOpToPhasedConvs(Value convTransposeResult,
   bool fourPhaseDecomposition = (stridesShape[0] == 2);
   bool ninePhaseDecomposition = (stridesShape[0] == 3);
   if (fourPhaseDecomposition) {
+    if (outputShape[0] != 1) {
+      // Currently support batch=1
+      return false;
+    }
     if (kernelShape[0] == 6 && padsShape[0] == 2 &&
         llvm::all_equal(padsShape)) {
       // Currently support only with pads [2, 2, 2, 2]
@@ -1713,102 +1717,56 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
         startOnnxConstant, endOnnxConstant, axisOnnxConstant,
         stepOnnxConstant);
   }
-
-  // The four convOutputs are adjusted to add an extra dimension at the
-  // innermost level.
-  SmallVector<int64_t> outputShapePlusOneDim(convOutputShape);
-  outputShapePlusOneDim.push_back(1);
-  auto onnxConstForReshapeAddOneDim =
-      getONNXConstOpFromVector(rewriter, loc, outputShapePlusOneDim);
-
-  auto reshapeOutputType =
-      RankedTensorType::get(outputShapePlusOneDim, elementType);
-
-  auto reshapeOutputAddOneDimConv1 = rewriter.create<ONNXReshapeOp>(
-      loc, reshapeOutputType, conv1, onnxConstForReshapeAddOneDim);
-  auto reshapeOutputAddOneDimConv2 = rewriter.create<ONNXReshapeOp>(
-      loc, reshapeOutputType, conv2, onnxConstForReshapeAddOneDim);
-  auto reshapeOutputAddOneDimConv3 = rewriter.create<ONNXReshapeOp>(
-      loc, reshapeOutputType, conv3, onnxConstForReshapeAddOneDim);
-  auto reshapeOutputAddOneDimConv4 = rewriter.create<ONNXReshapeOp>(
-      loc, reshapeOutputType, conv4, onnxConstForReshapeAddOneDim);
-
-  SmallVector<int64_t> outputShapeLevel1Concat(outputShapePlusOneDim);
-  outputShapeLevel1Concat[outputShapeLevel1Concat.size() - 1] = 2;
-  auto level1ConcatOutputType =
-      RankedTensorType::get(outputShapeLevel1Concat, elementType);
+  // Four conv outputs are merged in the channel dim.
+  SmallVector<int64_t> outputShapeOfConcat = {
+      1, convOutputShape[1] * 4, convOutputShape[2], convOutputShape[3]};
+  auto concatOutputType =
+      RankedTensorType::get(outputShapeOfConcat, elementType);
   // for the case where convtranspose kernel is [4, 4] and with pads [1, 1, 1,
   // 1] The phased convs output are to be concatenated in the reverse order.
   // This is observed by looking at the phased conv outputs with respect to
   // convtranspose output.
   bool reverseConcatOrder = (needWeightsPadding || (kernelShape[0] == 4));
-  // Below concats result will have the innermost dim as 2.
+  // The concat output will have 4 times the channels of a single conv.
   auto firstConcat =
       (reverseConcatOrder)
-          ? rewriter.create<ONNXConcatOp>(loc, level1ConcatOutputType,
-                ValueRange{
-                    reshapeOutputAddOneDimConv3, reshapeOutputAddOneDimConv1},
-                -1)
-          : rewriter.create<ONNXConcatOp>(loc, level1ConcatOutputType,
-                ValueRange{
-                    reshapeOutputAddOneDimConv1, reshapeOutputAddOneDimConv3},
-                -1);
-  auto secondConcat =
-      (reverseConcatOrder)
-          ? rewriter.create<ONNXConcatOp>(loc, level1ConcatOutputType,
-                ValueRange{
-                    reshapeOutputAddOneDimConv2, reshapeOutputAddOneDimConv4},
-                -1)
-          : rewriter.create<ONNXConcatOp>(loc, level1ConcatOutputType,
-                ValueRange{
-                    reshapeOutputAddOneDimConv4, reshapeOutputAddOneDimConv2},
-                -1);
+          ? rewriter.create<ONNXConcatOp>(loc, concatOutputType,
+                ValueRange{conv2, conv4, conv3, conv1}, 1)
+          : rewriter.create<ONNXConcatOp>(loc, concatOutputType,
+                ValueRange{conv1, conv3, conv4, conv2}, 1);
 
-  // Reshaping to modify the two innermost levels, ensuring the second
-  // innermost level is set to 1
-  SmallVector<int64_t> outputShapeForDimAdjust(convOutputShape);
-  auto dimValueAtLastIndex = convOutputShape[convOutputShape.size() - 1] * 2;
-  outputShapeForDimAdjust[outputShapeForDimAdjust.size() - 1] = 1;
-  outputShapeForDimAdjust.push_back(dimValueAtLastIndex);
+  // Here we reshape the concatenated channels (4 * Conv_channels) into
+  // groups of 2x2 channels. This can be visualized as
+  // H_chan(2) * W_chan(2) * C_real, followed by a transpose into
+  // Conv_channels H H_chan W W_chan. The adjacent H and H_chan are then
+  // merged into H, and likewise W and W_chan are merged into W. This
+  // doubles H and W while keeping the channel count the same.
+
+  SmallVector<int64_t> outputShapeForDimAdjust = {
+      2, 2, convOutputShape[1], convOutputShape[2], convOutputShape[3]};
 
   auto onnxConstForReshapeDimAdjust =
       getONNXConstOpFromVector(rewriter, loc, outputShapeForDimAdjust);
 
   auto reshapeOutputForDimAdjustType =
       RankedTensorType::get(outputShapeForDimAdjust, elementType);
-  auto reshapeOutputDimAdjustOfFirstConcat =
+  auto reshapeOutputDimAdjust =
       rewriter.create<ONNXReshapeOp>(loc, reshapeOutputForDimAdjustType,
           firstConcat, onnxConstForReshapeDimAdjust);
-  auto reshapeOutputDimAdjustOfSecondConcat =
-      rewriter.create<ONNXReshapeOp>(loc, reshapeOutputForDimAdjustType,
-          secondConcat, onnxConstForReshapeDimAdjust);
 
-  SmallVector<int64_t> outputShapeForFinalConcat(outputShapeForDimAdjust);
-  outputShapeForFinalConcat[outputShapeForFinalConcat.size() - 2] = 2;
+  SmallVector<int64_t> transposeOuputShape = {
+      convOutputShape[1], convOutputShape[2], 2, convOutputShape[3], 2};
 
-  auto finalConcatOutputType =
-      RankedTensorType::get(outputShapeForFinalConcat, elementType);
+  auto transposeOutputType =
+      RankedTensorType::get(transposeOuputShape, elementType);
 
-  // Final Concat is performed on the two reshaped outputs at the
-  // second innermost level
-  auto finalConcat =
-      (reverseConcatOrder)
-          ? rewriter.create<ONNXConcatOp>(loc, finalConcatOutputType,
-                ValueRange{reshapeOutputDimAdjustOfSecondConcat,
-                    reshapeOutputDimAdjustOfFirstConcat},
-                -2)
-          : rewriter.create<ONNXConcatOp>(loc, finalConcatOutputType,
-                ValueRange{reshapeOutputDimAdjustOfFirstConcat,
-                    reshapeOutputDimAdjustOfSecondConcat},
-                -2);
+  auto permArrayAttr = rewriter.getI64ArrayAttr({2, 3, 0, 4, 1});
 
-  SmallVector<int64_t> outputShapeForResult(convOutputShape);
-  dimValueAtLastIndex = convOutputShape[convOutputShape.size() - 1] * 2;
-  auto dimValueAtSecondLastIndex =
-      convOutputShape[convOutputShape.size() - 2] * 2;
-  outputShapeForResult[outputShapeForResult.size() - 2] =
-      dimValueAtSecondLastIndex;
-  outputShapeForResult[outputShapeForResult.size() - 1] = dimValueAtLastIndex;
+  auto transpose = rewriter.create<ONNXTransposeOp>(
+      loc, transposeOutputType, reshapeOutputDimAdjust, permArrayAttr);
+
+  SmallVector<int64_t> outputShapeForResult = {
+      1, convOutputShape[1], convOutputShape[2] * 2, convOutputShape[3] * 2};
 
   auto onnxConstForLastReshape =
       getONNXConstOpFromVector(rewriter, loc, outputShapeForResult);
@@ -1818,7 +1776,7 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
     // Result is reshaped back to match the original convtranspose output
     // dimensions
     auto finalOutput = rewriter.create<ONNXReshapeOp>(
-        loc, finalOutputType, finalConcat, onnxConstForLastReshape);
+        loc, finalOutputType, transpose, onnxConstForLastReshape);
     return finalOutput;
   }
   if (numPhases == 9) {
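
Note on the rewritten 4-phase path above: the new code builds one channel-axis Concat of the four phase convs, a Reshape to a 5-D view, a Transpose with perm {2, 3, 0, 4, 1}, and a final Reshape. The standalone C++ sketch below is a minimal illustration of that index math, not onnx-mlir code; the tiny shapes (C=2, H=2, W=3) and plain loops are purely illustrative, batch is assumed to be 1, and it does not model which conv feeds which channel group, since that depends on reverseConcatOrder.

// Minimal sketch of the Concat -> Reshape(5-D) -> Transpose{2,3,0,4,1} ->
// Reshape sequence used by the 4-phase ConvTranspose decomposition.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t C = 2, H = 2, W = 3; // illustrative per-phase conv output

  // Four phase outputs already concatenated on channels: shape [4*C, H, W].
  // Each value encodes (phase, c, h, w) so the interleave is easy to see.
  std::vector<int64_t> concat(4 * C * H * W);
  for (int64_t p = 0; p < 4; ++p)
    for (int64_t c = 0; c < C; ++c)
      for (int64_t h = 0; h < H; ++h)
        for (int64_t w = 0; w < W; ++w)
          concat[((p * C + c) * H + h) * W + w] =
              p * 1000 + c * 100 + h * 10 + w;

  // View the channels as [2, 2, C, H, W] = (h_phase, w_phase, c, h, w).
  // Transposing with perm {2, 3, 0, 4, 1} gives [C, H, 2, W, 2]; flattening
  // (H, 2) -> 2H and (W, 2) -> 2W doubles H and W, channels stay the same.
  std::vector<int64_t> out(C * 2 * H * 2 * W);
  for (int64_t hp = 0; hp < 2; ++hp)        // dim 0 of the 5-D view
    for (int64_t wp = 0; wp < 2; ++wp)      // dim 1
      for (int64_t c = 0; c < C; ++c)       // dim 2
        for (int64_t h = 0; h < H; ++h)     // dim 3
          for (int64_t w = 0; w < W; ++w) { // dim 4
            int64_t src = (((hp * 2 + wp) * C + c) * H + h) * W + w;
            // After transpose + reshape the element lands at (c, 2h+hp, 2w+wp).
            int64_t dst = (c * 2 * H + 2 * h + hp) * 2 * W + (2 * w + wp);
            out[dst] = concat[src];
          }

  // Print channel 0: rows alternate between h-phases, columns between
  // w-phases, i.e. a 2x pixel-shuffle interleave of the four phase outputs.
  for (int64_t h2 = 0; h2 < 2 * H; ++h2) {
    for (int64_t w2 = 0; w2 < 2 * W; ++w2)
      std::cout << out[h2 * 2 * W + w2] << '\t';
    std::cout << '\n';
  }
  return 0;
}

Printing channel 0 shows the four channel groups interleaved in a 2x2 checker pattern, so H and W are doubled with the channel count unchanged, which is what the removed two-level Concat sequence was assembling, now expressed with one Concat, one Transpose, and two Reshapes.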

src/Dialect/ONNX/Transforms/Recompose.cpp

Lines changed: 1 addition & 1 deletion
@@ -1315,7 +1315,7 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
   patterns.insert<RecomposeDepthToSpaceDCR>(context);
   // AMD Disabled as downstream has no special support for it
   // patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
-  patterns.insert<CombineParallelConv2DPattern>(context);
+  // patterns.insert<CombineParallelConv2DPattern>(context);
 }
 
 /*!
