Skip to content

Commit b14c9d3

Browse files
committed
Use rewriter methods instead of 'raw' IR manipulation. Bypassing the rewriter can lead to subtle bugs as listeners do not get notified.
Modify the insertion point and use a final topological sort to ensure the IR/graph order is valid, no matter how th einput orders are ordered and from which conv the recomposition starts. Signed-off-by: Rickert, Jonas <[email protected]>
1 parent b5e9386 commit b14c9d3

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

src/Dialect/ONNX/Transforms/Recompose.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include <numeric>
2424

25+
#include "mlir/Analysis/TopologicalSortUtils.h"
2526
#include "mlir/IR/PatternMatch.h"
2627
#include "mlir/Pass/Pass.h"
2728
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -75,7 +76,6 @@ ValueRange emitSplitByChannels(PatternRewriter &rewriter, Location loc,
7576
splitShape[axis] = size;
7677
resultTypes.push_back(RankedTensorType::get(splitShape, elementType));
7778
}
78-
rewriter.setInsertionPointAfter(input.getDefiningOp());
7979
// Perform Split Operation
8080
ValueRange results =
8181
create.onnx.split(ArrayRef(resultTypes), input, splitConstant, axis);
@@ -1182,6 +1182,8 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
11821182
newOutputShape[concatAxis] = totalOutputChannels;
11831183
auto newOutputType = RankedTensorType::get(newOutputShape, elementType);
11841184

1185+
OpBuilder::InsertionGuard guard(rewriter);
1186+
rewriter.setInsertionPointAfter(*latestConv);
11851187
auto newConv =
11861188
rewriter.create<ONNXConvOp>(loc, newOutputType, input, newWeight,
11871189
newBias, convOp1.getAutoPadAttr(), convOp1.getDilationsAttr(),
@@ -1216,8 +1218,7 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
12161218

12171219
if (allOutputsUsedInCommonConcat && commonConcatOp &&
12181220
commonConcatOp.getAxis() == 1) {
1219-
commonConcatOp.getResult().replaceAllUsesWith(newConv.getResult());
1220-
rewriter.eraseOp(commonConcatOp);
1221+
rewriter.replaceOp(commonConcatOp, newConv);
12211222
} else {
12221223
SmallVector<int64_t> splitSizesVec;
12231224
for (auto conv : parallelConvs) {
@@ -1226,15 +1227,15 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
12261227
splitSizesVec.push_back(channels);
12271228
}
12281229

1229-
rewriter.setInsertionPointAfter(newConv);
12301230
ValueRange splitResults = onnx_mlir::emitSplitByChannels(
12311231
rewriter, loc, newConv.getResult(), splitSizesVec, concatAxis);
1232-
12331232
for (size_t i = 0; i < parallelConvs.size(); ++i) {
1234-
parallelConvs[i].getResult().replaceAllUsesWith(splitResults[i]);
1233+
rewriter.replaceAllOpUsesWith(parallelConvs[i], splitResults[i]);
12351234
}
1235+
// Sort the block topological, as the operations after the split may be in
1236+
// the wrong place otherwise
1237+
mlir::sortTopologically(newConv->getBlock());
12361238
}
1237-
12381239
for (auto conv : parallelConvs) {
12391240
rewriter.eraseOp(conv);
12401241
}

0 commit comments

Comments
 (0)