@@ -1086,8 +1086,7 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
10861086
10871087 SmallVector<ONNXConvOp> parallelConvs = candidateConvs;
10881088
1089- llvm::MapVector<ONNXConvOp, Value>
1090- weightValues; // MapVector to keep the iteration order stable
1089+ SmallVector<Value> weightValues;
10911090 int64_t totalOutputChannels = 0 ;
10921091 for (auto conv : parallelConvs) {
10931092 auto weightType = mlir::cast<ShapedType>(conv.getW ().getType ());
@@ -1097,7 +1096,7 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
10971096 if (!cast<ShapedType>(conv.getType ()).hasStaticShape ())
10981097 return rewriter.notifyMatchFailure (
10991098 conv, " output type must be a ranked tensor with static shape" );
1100- weightValues[ conv] = conv .getW ();
1099+ weightValues. push_back ( conv.getW () );
11011100 totalOutputChannels += weightType.getShape ()[0 ];
11021101 }
11031102
@@ -1149,38 +1148,7 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
11491148 OpBuilder::InsertionGuard guard (rewriter);
11501149 rewriter.setInsertionPointAfter (*latestConv);
11511150
1152- ONNXConcatOp commonConcatOp = nullptr ;
1153- bool allOutputsUsedInCommonConcat = true ;
1154-
1155- std::map<size_t , ONNXConvOp>
1156- convOrder; // Key: Operand index in the common concat
1157-
1158- for (auto conv : parallelConvs) {
1159- if (!conv.getResult ().hasOneUse ()) {
1160- allOutputsUsedInCommonConcat = false ;
1161- break ;
1162- }
1163- for (auto &use : conv.getResult ().getUses ()) {
1164- if (auto concatOp = dyn_cast<ONNXConcatOp>(use.getOwner ())) {
1165- if (!commonConcatOp) {
1166- commonConcatOp = concatOp;
1167- }
1168- if (concatOp != commonConcatOp) {
1169- allOutputsUsedInCommonConcat = false ;
1170- break ;
1171- }
1172- convOrder[use.getOperandNumber ()] = conv;
1173- } else {
1174- allOutputsUsedInCommonConcat = false ;
1175- break ;
1176- }
1177- }
1178- if (!allOutputsUsedInCommonConcat) {
1179- break ;
1180- }
1181- }
1182-
1183- const int64_t concatAxis = 1 ;
1151+ int64_t concatAxis = 1 ;
11841152
11851153 auto firstWeightType =
11861154 mlir::cast<ShapedType>(parallelConvs[0 ].getW ().getType ());
@@ -1189,37 +1157,17 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
11891157 newWeightShape[0 ] = totalOutputChannels;
11901158 Type newWeightType =
11911159 RankedTensorType::get (newWeightShape, firstWeightType.getElementType ());
1192- SmallVector<Value> orderedWeightValues;
1193- if (allOutputsUsedInCommonConcat) {
1194- for (auto [_, v] : convOrder) {
1195- orderedWeightValues.push_back (weightValues[v]);
1196- }
1197- } else {
1198- for (auto [_, v] : weightValues) {
1199- orderedWeightValues.push_back (v);
1200- }
1201- }
1202- Value newWeight = create.onnx .concat (newWeightType, orderedWeightValues, 0 );
1160+ Value newWeight = create.onnx .concat (newWeightType, weightValues, 0 );
12031161
12041162 Value newBias;
12051163 if (allHaveBias) {
1206- llvm::MapVector<ONNXConvOp, Value> biasValues;
1164+ SmallVector< Value> biasValues;
12071165 for (auto conv : parallelConvs) {
1208- biasValues[ conv] = conv .getB ();
1166+ biasValues. push_back ( conv.getB () );
12091167 }
12101168 SmallVector<int64_t > newBiasShape = {totalOutputChannels};
12111169 Type newBiasType = RankedTensorType::get (newBiasShape, elementType);
1212- SmallVector<Value> orderedBiasValues;
1213- if (allOutputsUsedInCommonConcat) {
1214- for (auto [_, v] : convOrder) {
1215- orderedBiasValues.push_back (biasValues[v]);
1216- }
1217- } else {
1218- for (auto [_, v] : biasValues) {
1219- orderedBiasValues.push_back (v);
1220- }
1221- }
1222- newBias = create.onnx .concat (newBiasType, orderedBiasValues, 0 );
1170+ newBias = create.onnx .concat (newBiasType, biasValues, 0 );
12231171 } else {
12241172 newBias = parallelConvs[0 ].getB ();
12251173 }
@@ -1238,6 +1186,32 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
12381186 convOp1.getGroupAttr (), convOp1.getKernelShapeAttr (),
12391187 convOp1.getPadsAttr (), convOp1.getStridesAttr ());
12401188
1189+ ONNXConcatOp commonConcatOp = nullptr ;
1190+ bool allOutputsUsedInCommonConcat = true ;
1191+
1192+ for (auto conv : parallelConvs) {
1193+ bool usedInCommonConcat = false ;
1194+ for (auto user : conv.getResult ().getUsers ()) {
1195+ if (auto concatOp = dyn_cast<ONNXConcatOp>(user)) {
1196+ if (!commonConcatOp) {
1197+ commonConcatOp = concatOp;
1198+ }
1199+ if (concatOp != commonConcatOp) {
1200+ allOutputsUsedInCommonConcat = false ;
1201+ break ;
1202+ }
1203+ usedInCommonConcat = true ;
1204+ } else {
1205+ allOutputsUsedInCommonConcat = false ;
1206+ break ;
1207+ }
1208+ }
1209+ if (!usedInCommonConcat || !allOutputsUsedInCommonConcat) {
1210+ allOutputsUsedInCommonConcat = false ;
1211+ break ;
1212+ }
1213+ }
1214+
12411215 if (allOutputsUsedInCommonConcat && commonConcatOp &&
12421216 commonConcatOp.getAxis () == 1 ) {
12431217 rewriter.replaceOp (commonConcatOp, newConv);
@@ -1341,7 +1315,7 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
13411315 patterns.insert <RecomposeDepthToSpaceDCR>(context);
13421316 // AMD Disabled as downstream has no special support for it
13431317 // patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
1344- patterns.insert <CombineParallelConv2DPattern>(context);
1318+ // patterns.insert<CombineParallelConv2DPattern>(context);
13451319}
13461320
13471321/* !
0 commit comments