@@ -1086,7 +1086,8 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
10861086
10871087 SmallVector<ONNXConvOp> parallelConvs = candidateConvs;
10881088
1089- SmallVector<Value> weightValues;
1089+ llvm::MapVector<ONNXConvOp, Value>
1090+ weightValues; // MapVector to keep the iteration order stable
10901091 int64_t totalOutputChannels = 0 ;
10911092 for (auto conv : parallelConvs) {
10921093 auto weightType = mlir::cast<ShapedType>(conv.getW ().getType ());
@@ -1096,7 +1097,7 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
10961097 if (!cast<ShapedType>(conv.getType ()).hasStaticShape ())
10971098 return rewriter.notifyMatchFailure (
10981099 conv, " output type must be a ranked tensor with static shape" );
1099- weightValues. push_back ( conv.getW () );
1100+ weightValues[ conv] = conv .getW ();
11001101 totalOutputChannels += weightType.getShape ()[0 ];
11011102 }
11021103
@@ -1148,7 +1149,38 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
11481149 OpBuilder::InsertionGuard guard (rewriter);
11491150 rewriter.setInsertionPointAfter (*latestConv);
11501151
1151- int64_t concatAxis = 1 ;
1152+ ONNXConcatOp commonConcatOp = nullptr ;
1153+ bool allOutputsUsedInCommonConcat = true ;
1154+
1155+ std::map<size_t , ONNXConvOp>
1156+ convOrder; // Key: Operand index in the common concat
1157+
1158+ for (auto conv : parallelConvs) {
1159+ if (!conv.getResult ().hasOneUse ()) {
1160+ allOutputsUsedInCommonConcat = false ;
1161+ break ;
1162+ }
1163+ for (auto &use : conv.getResult ().getUses ()) {
1164+ if (auto concatOp = dyn_cast<ONNXConcatOp>(use.getOwner ())) {
1165+ if (!commonConcatOp) {
1166+ commonConcatOp = concatOp;
1167+ }
1168+ if (concatOp != commonConcatOp) {
1169+ allOutputsUsedInCommonConcat = false ;
1170+ break ;
1171+ }
1172+ convOrder[use.getOperandNumber ()] = conv;
1173+ } else {
1174+ allOutputsUsedInCommonConcat = false ;
1175+ break ;
1176+ }
1177+ }
1178+ if (!allOutputsUsedInCommonConcat) {
1179+ break ;
1180+ }
1181+ }
1182+
1183+ const int64_t concatAxis = 1 ;
11521184
11531185 auto firstWeightType =
11541186 mlir::cast<ShapedType>(parallelConvs[0 ].getW ().getType ());
@@ -1157,17 +1189,37 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
11571189 newWeightShape[0 ] = totalOutputChannels;
11581190 Type newWeightType =
11591191 RankedTensorType::get (newWeightShape, firstWeightType.getElementType ());
1160- Value newWeight = create.onnx .concat (newWeightType, weightValues, 0 );
1192+ SmallVector<Value> orderedWeightValues;
1193+ if (allOutputsUsedInCommonConcat) {
1194+ for (auto [_, v] : convOrder) {
1195+ orderedWeightValues.push_back (weightValues[v]);
1196+ }
1197+ } else {
1198+ for (auto [_, v] : weightValues) {
1199+ orderedWeightValues.push_back (v);
1200+ }
1201+ }
1202+ Value newWeight = create.onnx .concat (newWeightType, orderedWeightValues, 0 );
11611203
11621204 Value newBias;
11631205 if (allHaveBias) {
1164- SmallVector< Value> biasValues;
1206+ llvm::MapVector<ONNXConvOp, Value> biasValues;
11651207 for (auto conv : parallelConvs) {
1166- biasValues. push_back ( conv.getB () );
1208+ biasValues[ conv] = conv .getB ();
11671209 }
11681210 SmallVector<int64_t > newBiasShape = {totalOutputChannels};
11691211 Type newBiasType = RankedTensorType::get (newBiasShape, elementType);
1170- newBias = create.onnx .concat (newBiasType, biasValues, 0 );
1212+ SmallVector<Value> orderedBiasValues;
1213+ if (allOutputsUsedInCommonConcat) {
1214+ for (auto [_, v] : convOrder) {
1215+ orderedBiasValues.push_back (biasValues[v]);
1216+ }
1217+ } else {
1218+ for (auto [_, v] : biasValues) {
1219+ orderedBiasValues.push_back (v);
1220+ }
1221+ }
1222+ newBias = create.onnx .concat (newBiasType, orderedBiasValues, 0 );
11711223 } else {
11721224 newBias = parallelConvs[0 ].getB ();
11731225 }
@@ -1186,32 +1238,6 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
11861238 convOp1.getGroupAttr (), convOp1.getKernelShapeAttr (),
11871239 convOp1.getPadsAttr (), convOp1.getStridesAttr ());
11881240
1189- ONNXConcatOp commonConcatOp = nullptr ;
1190- bool allOutputsUsedInCommonConcat = true ;
1191-
1192- for (auto conv : parallelConvs) {
1193- bool usedInCommonConcat = false ;
1194- for (auto user : conv.getResult ().getUsers ()) {
1195- if (auto concatOp = dyn_cast<ONNXConcatOp>(user)) {
1196- if (!commonConcatOp) {
1197- commonConcatOp = concatOp;
1198- }
1199- if (concatOp != commonConcatOp) {
1200- allOutputsUsedInCommonConcat = false ;
1201- break ;
1202- }
1203- usedInCommonConcat = true ;
1204- } else {
1205- allOutputsUsedInCommonConcat = false ;
1206- break ;
1207- }
1208- }
1209- if (!usedInCommonConcat || !allOutputsUsedInCommonConcat) {
1210- allOutputsUsedInCommonConcat = false ;
1211- break ;
1212- }
1213- }
1214-
12151241 if (allOutputsUsedInCommonConcat && commonConcatOp &&
12161242 commonConcatOp.getAxis () == 1 ) {
12171243 rewriter.replaceOp (commonConcatOp, newConv);
@@ -1315,7 +1341,7 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
13151341 patterns.insert <RecomposeDepthToSpaceDCR>(context);
13161342 // AMD Disabled as downstream has no special support for it
13171343 // patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
1318- // patterns.insert<CombineParallelConv2DPattern>(context);
1344+ patterns.insert <CombineParallelConv2DPattern>(context);
13191345}
13201346
13211347/* !
0 commit comments