 
 #include <numeric>
 
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -75,7 +76,6 @@ ValueRange emitSplitByChannels(PatternRewriter &rewriter, Location loc,
     splitShape[axis] = size;
     resultTypes.push_back(RankedTensorType::get(splitShape, elementType));
   }
-  rewriter.setInsertionPointAfter(input.getDefiningOp());
   // Perform Split Operation
   ValueRange results =
       create.onnx.split(ArrayRef(resultTypes), input, splitConstant, axis);
@@ -1057,6 +1057,10 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
       return rewriter.notifyMatchFailure(
           convOp1, "input must be a ranked tensor with static shape");
 
+    if (!cast<ShapedType>(convOp1.getType()).hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          convOp1, "output type must be a ranked tensor with static shape");
+
     // Collect all ONNXConvOps using this input.
     SmallVector<ONNXConvOp> candidateConvs;
     for (auto user : input.getUsers()) {
@@ -1084,6 +1088,55 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
 
     SmallVector<ONNXConvOp> parallelConvs = candidateConvs;
 
+    SmallVector<Value> weightValues;
+    int64_t totalOutputChannels = 0;
+    for (auto conv : parallelConvs) {
+      auto weightType = mlir::cast<ShapedType>(conv.getW().getType());
+      if (!weightType.hasStaticShape())
+        return rewriter.notifyMatchFailure(
+            conv, "weight must be a ranked tensor with static shape");
+      if (!cast<ShapedType>(conv.getType()).hasStaticShape())
+        return rewriter.notifyMatchFailure(
+            conv, "output type must be a ranked tensor with static shape");
+      weightValues.push_back(conv.getW());
+      totalOutputChannels += weightType.getShape()[0];
+    }
+
+    auto *latestConv =
+        llvm::max_element(parallelConvs, [](ONNXConvOp a, ONNXConvOp b) {
+          return a->isBeforeInBlock(b.getOperation());
+        });
+
+    const auto checkIfOtherConvsReachable = [&](ONNXConvOp conv) {
+      SmallVector<Operation *> worklist;
+      DenseSet<Operation *> visited;
+      worklist.push_back(conv.getOperation());
+      while (!worklist.empty()) {
+        Operation *current = worklist.back();
+        worklist.pop_back();
+
+        for (auto *user : current->getUsers()) {
+          if (auto otherConv = dyn_cast<ONNXConvOp>(user)) {
+            if (llvm::is_contained(parallelConvs, otherConv)) {
+              // Found another conv that is part of the parallel convs.
+              return true;
+            }
+          }
+          if (visited.insert(user).second &&
+              user->isBeforeInBlock(*latestConv)) {
+            worklist.push_back(user);
+          }
+        }
+      }
+      return false;
+    };
+    // Ensure all convolutions are really parallel: none of them may feed,
+    // directly or transitively, the input of another convolution.
+    if (llvm::any_of(parallelConvs, checkIfOtherConvsReachable)) {
+      return rewriter.notifyMatchFailure(
+          convOp1, "conv ops are not parallel (reachable from each other)");
+    }
+
     bool allHaveBias = !mlir::isa<NoneType>(parallelConvs[0].getB().getType());
 
     Location loc = convOp1.getLoc();
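The check added above is a plain worklist walk over def-use chains, bounded by the block position of the latest candidate conv: a candidate is rejected if any other candidate is reachable downstream of it. A rough, MLIR-free sketch of the same idea, using a hypothetical `isAnyOtherTargetReachable` helper over an integer adjacency list instead of `Operation::getUsers()`:

```cpp
#include <cstdio>
#include <set>
#include <vector>

// Adjacency list: users[n] holds the nodes that consume node n's result.
using Graph = std::vector<std::vector<int>>;

// Returns true if any node in `targets` other than `start` is reachable from
// `start` by following user edges (a simplified stand-in for the worklist
// walk over op users in the pattern above).
static bool isAnyOtherTargetReachable(
    const Graph &users, int start, const std::set<int> &targets) {
  std::vector<int> worklist{start};
  std::set<int> visited;
  while (!worklist.empty()) {
    int current = worklist.back();
    worklist.pop_back();
    for (int user : users[current]) {
      if (user != start && targets.count(user))
        return true; // another candidate consumes this one's result
      if (visited.insert(user).second)
        worklist.push_back(user);
    }
  }
  return false;
}

int main() {
  // Node 0 is the shared input, nodes 1 and 2 are candidate "convs",
  // node 3 consumes 1 and feeds 2, so 1 and 2 are not truly parallel.
  Graph users = {{1, 2}, {3}, {}, {2}};
  std::set<int> candidates = {1, 2};
  std::printf("from conv 1: %s\n",
      isAnyOtherTargetReachable(users, 1, candidates) ? "reachable" : "not");
  std::printf("from conv 2: %s\n",
      isAnyOtherTargetReachable(users, 2, candidates) ? "reachable" : "not");
  return 0;
}
```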
@@ -1097,14 +1150,6 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
 
     int64_t concatAxis = 1;
 
-    SmallVector<Value> weightValues;
-    int64_t totalOutputChannels = 0;
-    for (auto conv : parallelConvs) {
-      auto weightType = mlir::cast<ShapedType>(conv.getW().getType());
-      weightValues.push_back(conv.getW());
-      totalOutputChannels += weightType.getShape()[0];
-    }
-
     auto firstWeightType =
         mlir::cast<ShapedType>(parallelConvs[0].getW().getType());
     SmallVector<int64_t> newWeightShape(
@@ -1137,6 +1182,8 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
     newOutputShape[concatAxis] = totalOutputChannels;
     auto newOutputType = RankedTensorType::get(newOutputShape, elementType);
 
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointAfter(*latestConv);
     auto newConv =
         rewriter.create<ONNXConvOp>(loc, newOutputType, input, newWeight,
             newBias, convOp1.getAutoPadAttr(), convOp1.getDilationsAttr(),
@@ -1171,8 +1218,7 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
 
     if (allOutputsUsedInCommonConcat && commonConcatOp &&
         commonConcatOp.getAxis() == 1) {
-      commonConcatOp.getResult().replaceAllUsesWith(newConv.getResult());
-      rewriter.eraseOp(commonConcatOp);
+      rewriter.replaceOp(commonConcatOp, newConv);
     } else {
       SmallVector<int64_t> splitSizesVec;
       for (auto conv : parallelConvs) {
@@ -1181,15 +1227,15 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
         splitSizesVec.push_back(channels);
       }
 
-      rewriter.setInsertionPointAfter(newConv);
       ValueRange splitResults = onnx_mlir::emitSplitByChannels(
           rewriter, loc, newConv.getResult(), splitSizesVec, concatAxis);
-
       for (size_t i = 0; i < parallelConvs.size(); ++i) {
-        parallelConvs[i].getResult().replaceAllUsesWith(splitResults[i]);
+        rewriter.replaceAllOpUsesWith(parallelConvs[i], splitResults[i]);
       }
+      // Sort the block topologically, as operations after the split may
+      // otherwise end up in the wrong place.
+      mlir::sortTopologically(newConv->getBlock());
     }
-
     for (auto conv : parallelConvs) {
       rewriter.eraseOp(conv);
     }
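`mlir::sortTopologically` re-orders the ops in the block so that every value is defined before its uses; without it, a consumer of one of the original conv results could sit above the newly inserted split ops after the replacement. As a rough, MLIR-free sketch of what a topological re-ordering computes (hypothetical integer-DAG encoding, Kahn's algorithm rather than MLIR's own implementation), consider:

```cpp
#include <cstdio>
#include <queue>
#include <vector>

int main() {
  // dependents[d] lists nodes that depend on d and must therefore come after
  // it. Node 3 (the "split") feeds node 1 (an old consumer placed earlier in
  // the block), so a correct order must move 3 before 1.
  std::vector<std::vector<int>> dependents = {{3}, {}, {}, {1, 2}};
  std::vector<int> indegree(dependents.size(), 0);
  for (const auto &users : dependents)
    for (int u : users)
      ++indegree[u];

  std::queue<int> ready;
  for (int n = 0; n < (int)dependents.size(); ++n)
    if (indegree[n] == 0)
      ready.push(n);

  std::printf("topological order:");
  while (!ready.empty()) {
    int n = ready.front();
    ready.pop();
    std::printf(" %d", n);
    for (int u : dependents[n])
      if (--indegree[u] == 0)
        ready.push(u);
  }
  std::printf("\n"); // prints: topological order: 0 3 1 2
  return 0;
}
```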
@@ -1273,8 +1319,7 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
   patterns.insert<RecomposeDepthToSpaceDCR>(context);
   // AMD Disabled as downstream has no special support for it
   // patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
-  // AMD Temporary disabled as this pattern is buggy.
-  // patterns.insert<CombineParallelConv2DPattern>(context);
+  patterns.insert<CombineParallelConv2DPattern>(context);
 }
 
 /*!