Commit 9120641

Take concat order into account when combining parallel convs

Signed-off-by: Rickert, Jonas <[email protected]>

1 parent a2be298
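In short: when every one of the parallel convolutions feeds a single common Concat, the fused weight (and bias) tensors must be concatenated in the Concat's operand order, not in whatever order the convolutions were discovered during matching. A minimal standalone sketch of that ordering logic, using hypothetical names rather than code from this commit:

    #include <cstddef>
    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    int main() {
      // Hypothetical setup: conv "B" is operand 0 of the common concat and
      // conv "A" is operand 1, even though "A" was matched first.
      std::map<std::size_t, std::string> convOrder = {{1, "A"}, {0, "B"}};
      std::map<std::string, std::string> weightValues = {
          {"A", "W_A"}, {"B", "W_B"}};

      // std::map iterates keys in ascending order, so the fused weight list
      // follows the concat's operand order: W_B first, then W_A.
      std::vector<std::string> orderedWeights;
      for (const auto &[idx, conv] : convOrder)
        orderedWeights.push_back(weightValues[conv]);

      for (const auto &w : orderedWeights)
        std::cout << w << '\n';
    }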

File tree

3 files changed: +228, -155 lines


src/Dialect/ONNX/Transforms/Recompose.cpp

Lines changed: 60 additions & 34 deletions
@@ -1086,7 +1086,8 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
 
     SmallVector<ONNXConvOp> parallelConvs = candidateConvs;
 
-    SmallVector<Value> weightValues;
+    llvm::MapVector<ONNXConvOp, Value>
+        weightValues; // MapVector to keep the iteration order stable
     int64_t totalOutputChannels = 0;
     for (auto conv : parallelConvs) {
       auto weightType = mlir::cast<ShapedType>(conv.getW().getType());
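Why llvm::MapVector here (editor's note): it pairs a map with a vector, so lookups keyed by conv stay cheap while iteration visits entries in insertion order; the fallback path further down relies on exactly that when the convolutions do not share a common Concat. A small standalone sketch of the property, assuming an LLVM development setup:

    #include "llvm/ADT/MapVector.h"
    #include <cstdio>

    int main() {
      llvm::MapVector<int, const char *> weights;
      weights[7] = "W7"; // inserted first
      weights[2] = "W2"; // inserted second
      // Unlike a plain hash map, MapVector iterates in insertion order,
      // so this prints W7 before W2.
      for (const auto &[key, name] : weights)
        std::printf("%d -> %s\n", key, name);
    }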
@@ -1096,7 +1097,7 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
       if (!cast<ShapedType>(conv.getType()).hasStaticShape())
         return rewriter.notifyMatchFailure(
             conv, "output type must be a ranked tensor with static shape");
-      weightValues.push_back(conv.getW());
+      weightValues[conv] = conv.getW();
       totalOutputChannels += weightType.getShape()[0];
     }
 
@@ -1148,7 +1149,38 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointAfter(*latestConv);
 
-    int64_t concatAxis = 1;
+    ONNXConcatOp commonConcatOp = nullptr;
+    bool allOutputsUsedInCommonConcat = true;
+
+    std::map<size_t, ONNXConvOp>
+        convOrder; // Key: Operand index in the common concat
+
+    for (auto conv : parallelConvs) {
+      if (!conv.getResult().hasOneUse()) {
+        allOutputsUsedInCommonConcat = false;
+        break;
+      }
+      for (auto &use : conv.getResult().getUses()) {
+        if (auto concatOp = dyn_cast<ONNXConcatOp>(use.getOwner())) {
+          if (!commonConcatOp) {
+            commonConcatOp = concatOp;
+          }
+          if (concatOp != commonConcatOp) {
+            allOutputsUsedInCommonConcat = false;
+            break;
+          }
+          convOrder[use.getOperandNumber()] = conv;
+        } else {
+          allOutputsUsedInCommonConcat = false;
+          break;
+        }
+      }
+      if (!allOutputsUsedInCommonConcat) {
+        break;
+      }
+    }
+
+    const int64_t concatAxis = 1;
 
     auto firstWeightType =
         mlir::cast<ShapedType>(parallelConvs[0].getW().getType());
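Editor's gloss on the added loop: it walks conv.getResult().getUses() instead of getUsers() (which the removed block further down used), because each use is an OpOperand whose getOperandNumber() gives the conv's position among the Concat's operands, and that index is exactly what convOrder keys on. The hasOneUse() guard up front also replaces the old usedInCommonConcat bookkeeping.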
@@ -1157,17 +1189,37 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
     newWeightShape[0] = totalOutputChannels;
     Type newWeightType =
         RankedTensorType::get(newWeightShape, firstWeightType.getElementType());
-    Value newWeight = create.onnx.concat(newWeightType, weightValues, 0);
+    SmallVector<Value> orderedWeightValues;
+    if (allOutputsUsedInCommonConcat) {
+      for (auto [_, v] : convOrder) {
+        orderedWeightValues.push_back(weightValues[v]);
+      }
+    } else {
+      for (auto [_, v] : weightValues) {
+        orderedWeightValues.push_back(v);
+      }
+    }
+    Value newWeight = create.onnx.concat(newWeightType, orderedWeightValues, 0);
 
     Value newBias;
     if (allHaveBias) {
-      SmallVector<Value> biasValues;
+      llvm::MapVector<ONNXConvOp, Value> biasValues;
       for (auto conv : parallelConvs) {
-        biasValues.push_back(conv.getB());
+        biasValues[conv] = conv.getB();
       }
       SmallVector<int64_t> newBiasShape = {totalOutputChannels};
       Type newBiasType = RankedTensorType::get(newBiasShape, elementType);
-      newBias = create.onnx.concat(newBiasType, biasValues, 0);
+      SmallVector<Value> orderedBiasValues;
+      if (allOutputsUsedInCommonConcat) {
+        for (auto [_, v] : convOrder) {
+          orderedBiasValues.push_back(biasValues[v]);
+        }
+      } else {
+        for (auto [_, v] : biasValues) {
+          orderedBiasValues.push_back(v);
+        }
+      }
+      newBias = create.onnx.concat(newBiasType, orderedBiasValues, 0);
     } else {
       newBias = parallelConvs[0].getB();
     }
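Design note (editor's summary): in the else branches, where no common Concat was found, iterating the MapVector directly yields the convs in insertion order, so the fused operand order matches what the previous SmallVector-based code produced.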
@@ -1186,32 +1238,6 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
         convOp1.getGroupAttr(), convOp1.getKernelShapeAttr(),
         convOp1.getPadsAttr(), convOp1.getStridesAttr());
 
-    ONNXConcatOp commonConcatOp = nullptr;
-    bool allOutputsUsedInCommonConcat = true;
-
-    for (auto conv : parallelConvs) {
-      bool usedInCommonConcat = false;
-      for (auto user : conv.getResult().getUsers()) {
-        if (auto concatOp = dyn_cast<ONNXConcatOp>(user)) {
-          if (!commonConcatOp) {
-            commonConcatOp = concatOp;
-          }
-          if (concatOp != commonConcatOp) {
-            allOutputsUsedInCommonConcat = false;
-            break;
-          }
-          usedInCommonConcat = true;
-        } else {
-          allOutputsUsedInCommonConcat = false;
-          break;
-        }
-      }
-      if (!usedInCommonConcat || !allOutputsUsedInCommonConcat) {
-        allOutputsUsedInCommonConcat = false;
-        break;
-      }
-    }
-
     if (allOutputsUsedInCommonConcat && commonConcatOp &&
         commonConcatOp.getAxis() == 1) {
       rewriter.replaceOp(commonConcatOp, newConv);
@@ -1315,7 +1341,7 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
   patterns.insert<RecomposeDepthToSpaceDCR>(context);
   // AMD Disabled as downstream has no special support for it
   // patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
-  // patterns.insert<CombineParallelConv2DPattern>(context);
+  patterns.insert<CombineParallelConv2DPattern>(context);
 }
 
 /*!
