Skip to content

Commit 0a665d9

Browse files
authored
Revert "Take concat order in account when combining parallel convs"
1 parent 033a463 commit 0a665d9

File tree

3 files changed

+155
-228
lines changed

3 files changed

+155
-228
lines changed

src/Dialect/ONNX/Transforms/Recompose.cpp

Lines changed: 34 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,8 +1086,7 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
10861086

10871087
SmallVector<ONNXConvOp> parallelConvs = candidateConvs;
10881088

1089-
llvm::MapVector<ONNXConvOp, Value>
1090-
weightValues; // MapVector to keep the iteration order stable
1089+
SmallVector<Value> weightValues;
10911090
int64_t totalOutputChannels = 0;
10921091
for (auto conv : parallelConvs) {
10931092
auto weightType = mlir::cast<ShapedType>(conv.getW().getType());
@@ -1097,7 +1096,7 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
10971096
if (!cast<ShapedType>(conv.getType()).hasStaticShape())
10981097
return rewriter.notifyMatchFailure(
10991098
conv, "output type must be a ranked tensor with static shape");
1100-
weightValues[conv] = conv.getW();
1099+
weightValues.push_back(conv.getW());
11011100
totalOutputChannels += weightType.getShape()[0];
11021101
}
11031102

@@ -1149,38 +1148,7 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
11491148
OpBuilder::InsertionGuard guard(rewriter);
11501149
rewriter.setInsertionPointAfter(*latestConv);
11511150

1152-
ONNXConcatOp commonConcatOp = nullptr;
1153-
bool allOutputsUsedInCommonConcat = true;
1154-
1155-
std::map<size_t, ONNXConvOp>
1156-
convOrder; // Key: Operand index in the common concat
1157-
1158-
for (auto conv : parallelConvs) {
1159-
if (!conv.getResult().hasOneUse()) {
1160-
allOutputsUsedInCommonConcat = false;
1161-
break;
1162-
}
1163-
for (auto &use : conv.getResult().getUses()) {
1164-
if (auto concatOp = dyn_cast<ONNXConcatOp>(use.getOwner())) {
1165-
if (!commonConcatOp) {
1166-
commonConcatOp = concatOp;
1167-
}
1168-
if (concatOp != commonConcatOp) {
1169-
allOutputsUsedInCommonConcat = false;
1170-
break;
1171-
}
1172-
convOrder[use.getOperandNumber()] = conv;
1173-
} else {
1174-
allOutputsUsedInCommonConcat = false;
1175-
break;
1176-
}
1177-
}
1178-
if (!allOutputsUsedInCommonConcat) {
1179-
break;
1180-
}
1181-
}
1182-
1183-
const int64_t concatAxis = 1;
1151+
int64_t concatAxis = 1;
11841152

11851153
auto firstWeightType =
11861154
mlir::cast<ShapedType>(parallelConvs[0].getW().getType());
@@ -1189,37 +1157,17 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
11891157
newWeightShape[0] = totalOutputChannels;
11901158
Type newWeightType =
11911159
RankedTensorType::get(newWeightShape, firstWeightType.getElementType());
1192-
SmallVector<Value> orderedWeightValues;
1193-
if (allOutputsUsedInCommonConcat) {
1194-
for (auto [_, v] : convOrder) {
1195-
orderedWeightValues.push_back(weightValues[v]);
1196-
}
1197-
} else {
1198-
for (auto [_, v] : weightValues) {
1199-
orderedWeightValues.push_back(v);
1200-
}
1201-
}
1202-
Value newWeight = create.onnx.concat(newWeightType, orderedWeightValues, 0);
1160+
Value newWeight = create.onnx.concat(newWeightType, weightValues, 0);
12031161

12041162
Value newBias;
12051163
if (allHaveBias) {
1206-
llvm::MapVector<ONNXConvOp, Value> biasValues;
1164+
SmallVector<Value> biasValues;
12071165
for (auto conv : parallelConvs) {
1208-
biasValues[conv] = conv.getB();
1166+
biasValues.push_back(conv.getB());
12091167
}
12101168
SmallVector<int64_t> newBiasShape = {totalOutputChannels};
12111169
Type newBiasType = RankedTensorType::get(newBiasShape, elementType);
1212-
SmallVector<Value> orderedBiasValues;
1213-
if (allOutputsUsedInCommonConcat) {
1214-
for (auto [_, v] : convOrder) {
1215-
orderedBiasValues.push_back(biasValues[v]);
1216-
}
1217-
} else {
1218-
for (auto [_, v] : biasValues) {
1219-
orderedBiasValues.push_back(v);
1220-
}
1221-
}
1222-
newBias = create.onnx.concat(newBiasType, orderedBiasValues, 0);
1170+
newBias = create.onnx.concat(newBiasType, biasValues, 0);
12231171
} else {
12241172
newBias = parallelConvs[0].getB();
12251173
}
@@ -1238,6 +1186,32 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
12381186
convOp1.getGroupAttr(), convOp1.getKernelShapeAttr(),
12391187
convOp1.getPadsAttr(), convOp1.getStridesAttr());
12401188

1189+
ONNXConcatOp commonConcatOp = nullptr;
1190+
bool allOutputsUsedInCommonConcat = true;
1191+
1192+
for (auto conv : parallelConvs) {
1193+
bool usedInCommonConcat = false;
1194+
for (auto user : conv.getResult().getUsers()) {
1195+
if (auto concatOp = dyn_cast<ONNXConcatOp>(user)) {
1196+
if (!commonConcatOp) {
1197+
commonConcatOp = concatOp;
1198+
}
1199+
if (concatOp != commonConcatOp) {
1200+
allOutputsUsedInCommonConcat = false;
1201+
break;
1202+
}
1203+
usedInCommonConcat = true;
1204+
} else {
1205+
allOutputsUsedInCommonConcat = false;
1206+
break;
1207+
}
1208+
}
1209+
if (!usedInCommonConcat || !allOutputsUsedInCommonConcat) {
1210+
allOutputsUsedInCommonConcat = false;
1211+
break;
1212+
}
1213+
}
1214+
12411215
if (allOutputsUsedInCommonConcat && commonConcatOp &&
12421216
commonConcatOp.getAxis() == 1) {
12431217
rewriter.replaceOp(commonConcatOp, newConv);
@@ -1341,7 +1315,7 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
13411315
patterns.insert<RecomposeDepthToSpaceDCR>(context);
13421316
// AMD Disabled as downstream has no special support for it
13431317
// patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
1344-
patterns.insert<CombineParallelConv2DPattern>(context);
1318+
// patterns.insert<CombineParallelConv2DPattern>(context);
13451319
}
13461320

13471321
/*!

0 commit comments

Comments (0)