
Commit 7901074

Merge pull request #382 from Xilinx/jrickert.fix_recompose
Fix various bugs in parallel conv recomposition
2 parents 898b13d + 8dd03fe commit 7901074

3 files changed: +133 additions, -45 deletions


src/Dialect/ONNX/Transforms/Recompose.cpp

Lines changed: 62 additions & 17 deletions
@@ -22,6 +22,7 @@
 
 #include <numeric>
 
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -75,7 +76,6 @@ ValueRange emitSplitByChannels(PatternRewriter &rewriter, Location loc,
     splitShape[axis] = size;
     resultTypes.push_back(RankedTensorType::get(splitShape, elementType));
   }
-  rewriter.setInsertionPointAfter(input.getDefiningOp());
   // Perform Split Operation
   ValueRange results =
       create.onnx.split(ArrayRef(resultTypes), input, splitConstant, axis);
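
With the premature setInsertionPointAfter call removed, emitSplitByChannels no longer moves the rewriter's insertion point itself; it only builds one result type per requested chunk (copying the input shape and overwriting the split axis with that chunk's size) and emits the Split wherever the caller currently inserts. That shape bookkeeping can be sketched standalone; the helper below is illustrative only and not part of onnx-mlir.

```cpp
// Hypothetical, MLIR-free analogue of the shape bookkeeping in
// emitSplitByChannels: compute the result shapes of splitting `shape`
// along `axis` into the given chunk sizes.
#include <cassert>
#include <cstdio>
#include <vector>

static std::vector<std::vector<long>> splitShapes(
    std::vector<long> shape, int axis, const std::vector<long> &sizes) {
  long total = 0;
  for (long s : sizes)
    total += s;
  assert(total == shape[axis] && "split sizes must cover the axis exactly");
  std::vector<std::vector<long>> results;
  for (long s : sizes) {
    std::vector<long> r = shape;
    r[axis] = s; // same idea as splitShape[axis] = size in the pass
    results.push_back(r);
  }
  return results;
}

int main() {
  // Splitting the fused 1x96x512x512 output back into three 1x32x512x512 parts.
  for (const auto &r : splitShapes({1, 96, 512, 512}, /*axis=*/1, {32, 32, 32})) {
    for (long d : r)
      std::printf("%ld ", d);
    std::printf("\n");
  }
  return 0;
}
```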
@@ -1057,6 +1057,10 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
       return rewriter.notifyMatchFailure(
           convOp1, "input must be a ranked tensor with static shape");
 
+    if (!cast<ShapedType>(convOp1.getType()).hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          convOp1, "output type must be a ranked tensor with static shape");
+
     // Collect all ONNXConvOps using this input.
     SmallVector<ONNXConvOp> candidateConvs;
     for (auto user : input.getUsers()) {
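
This new early exit mirrors the existing check on the input: the pattern now also bails out when the convolution's result type is not fully static, which is what the added test_conv_concat_not_static_shape test exercises. As a rough, MLIR-free picture of what hasStaticShape() decides, assuming unknown dimensions are modelled by a sentinel value (the real ShapedType API is not reproduced here):

```cpp
// Toy analogue of the hasStaticShape() checks: a shape is "static" only when
// every dimension is a known extent. kDynamic is a stand-in sentinel, not the
// MLIR encoding.
#include <cstdio>
#include <vector>

constexpr long kDynamic = -1;

static bool hasStaticShape(const std::vector<long> &shape) {
  for (long d : shape)
    if (d == kDynamic)
      return false;
  return true;
}

int main() {
  std::printf("%d\n", hasStaticShape({1, 32, 512, 512}));       // 1: fully static
  std::printf("%d\n", hasStaticShape({1, kDynamic, 512, 512})); // 0: like tensor<1x?x512x512xf32>
  return 0;
}
```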
@@ -1084,6 +1088,55 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
 
     SmallVector<ONNXConvOp> parallelConvs = candidateConvs;
 
+    SmallVector<Value> weightValues;
+    int64_t totalOutputChannels = 0;
+    for (auto conv : parallelConvs) {
+      auto weightType = mlir::cast<ShapedType>(conv.getW().getType());
+      if (!weightType.hasStaticShape())
+        return rewriter.notifyMatchFailure(
+            conv, "weight must be a ranked tensor with static shape");
+      if (!cast<ShapedType>(conv.getType()).hasStaticShape())
+        return rewriter.notifyMatchFailure(
+            conv, "output type must be a ranked tensor with static shape");
+      weightValues.push_back(conv.getW());
+      totalOutputChannels += weightType.getShape()[0];
+    }
+
+    auto *latestConv =
+        llvm::max_element(parallelConvs, [](ONNXConvOp a, ONNXConvOp b) {
+          return a->isBeforeInBlock(b.getOperation());
+        });
+
+    const auto checkIfOtherConvsReachable = [&](ONNXConvOp conv) {
+      SmallVector<Operation *> worklist;
+      DenseSet<Operation *> visited;
+      worklist.push_back(conv.getOperation());
+      while (!worklist.empty()) {
+        Operation *current = worklist.back();
+        worklist.pop_back();
+
+        for (auto *user : current->getUsers()) {
+          if (auto otherConv = dyn_cast<ONNXConvOp>(user)) {
+            if (llvm::is_contained(parallelConvs, otherConv)) {
+              // Found another conv that is part of the parallel convs.
+              return true;
+            }
+          }
+          if (visited.insert(user).second &&
+              user->isBeforeInBlock(*latestConv)) {
+            worklist.push_back(user);
+          }
+        };
+      }
+      return false;
+    };
+    // Ensure all convolutions are really parallel; none of them may be part of
+    // the input of another convolution.
+    if (llvm::any_of(parallelConvs, checkIfOtherConvsReachable)) {
+      return rewriter.notifyMatchFailure(
+          convOp1, "conv ops are not parallel (reachable from each other)");
+    }
+
     bool allHaveBias = !mlir::isa<NoneType>(parallelConvs[0].getB().getType());
 
     Location loc = convOp1.getLoc();
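
The added guard works in two steps: llvm::max_element with isBeforeInBlock as the ordering picks the conv that appears last in the block, and checkIfOtherConvsReachable then walks each candidate's transitive users with a worklist, continuing only through operations that still come before that latest conv. Reaching another candidate means one conv's result feeds another conv, possibly indirectly (e.g. through a ReduceMean, as in the new test_conv_concat_dependency test), so the convs are not truly parallel and the rewrite is rejected. A minimal standalone sketch of the same bounded worklist reachability, using a toy user graph of integers instead of MLIR operations (all names here are illustrative), could look like:

```cpp
// Toy version of the bounded reachability check: follow "user" edges from one
// candidate and report whether another candidate is transitively reachable,
// never walking past the latest candidate's position in the block.
#include <cstdio>
#include <set>
#include <vector>

struct Node {
  int id;                    // stand-in for the op's position in the block
  std::vector<Node *> users; // stand-in for Operation::getUsers()
};

static bool reachesAnotherCandidate(
    Node *start, const std::set<Node *> &candidates, int latestId) {
  std::vector<Node *> worklist{start};
  std::set<Node *> visited;
  while (!worklist.empty()) {
    Node *current = worklist.back();
    worklist.pop_back();
    for (Node *user : current->users) {
      if (candidates.count(user))
        return true; // another "parallel" conv is a transitive user
      if (visited.insert(user).second && user->id < latestId)
        worklist.push_back(user); // keep exploring, but only before `latest`
    }
  }
  return false;
}

int main() {
  // conv1 -> reduce -> conv2, mirroring test_conv_concat_dependency.
  Node conv1{0, {}}, reduce{1, {}}, conv2{2, {}};
  conv1.users = {&reduce};
  reduce.users = {&conv2};
  std::set<Node *> candidates{&conv1, &conv2};
  std::printf("conv1 reaches another parallel conv: %d\n",
              reachesAnotherCandidate(&conv1, candidates, /*latestId=*/2));
  return 0;
}
```

On the conv1 -> ReduceMean -> conv2 chain, the sketch reports that conv1 reaches conv2, which is exactly the dependency case the pattern now refuses to rewrite.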
@@ -1097,14 +1150,6 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
 
     int64_t concatAxis = 1;
 
-    SmallVector<Value> weightValues;
-    int64_t totalOutputChannels = 0;
-    for (auto conv : parallelConvs) {
-      auto weightType = mlir::cast<ShapedType>(conv.getW().getType());
-      weightValues.push_back(conv.getW());
-      totalOutputChannels += weightType.getShape()[0];
-    }
-
     auto firstWeightType =
         mlir::cast<ShapedType>(parallelConvs[0].getW().getType());
     SmallVector<int64_t> newWeightShape(
@@ -1137,6 +1182,8 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
     newOutputShape[concatAxis] = totalOutputChannels;
     auto newOutputType = RankedTensorType::get(newOutputShape, elementType);
 
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointAfter(*latestConv);
    auto newConv =
         rewriter.create<ONNXConvOp>(loc, newOutputType, input, newWeight,
             newBias, convOp1.getAutoPadAttr(), convOp1.getDilationsAttr(),
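
OpBuilder::InsertionGuard is an RAII helper: it remembers the rewriter's insertion point when constructed and restores it when the guard goes out of scope, so pointing the rewriter after the latest conv for the fused op cannot leak into later rewrites. A rough standalone analogue of that save/restore idiom (FakeBuilder and its integer insertion point are stand-ins, not MLIR types) is:

```cpp
// Standalone sketch of the RAII save/restore idiom behind
// OpBuilder::InsertionGuard: record builder state in the constructor,
// restore it in the destructor.
#include <cstdio>

struct FakeBuilder {
  int insertionPoint = 0; // stand-in for the rewriter's insertion point
};

class InsertionGuard {
public:
  explicit InsertionGuard(FakeBuilder &b) : builder(b), saved(b.insertionPoint) {}
  ~InsertionGuard() { builder.insertionPoint = saved; } // restore on scope exit
private:
  FakeBuilder &builder;
  int saved;
};

int main() {
  FakeBuilder builder;
  builder.insertionPoint = 10;
  {
    InsertionGuard guard(builder);
    builder.insertionPoint = 42; // temporarily move, e.g. "after latestConv"
    std::printf("inside scope: %d\n", builder.insertionPoint);
  }
  std::printf("after scope:  %d\n", builder.insertionPoint); // back to 10
  return 0;
}
```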
@@ -1171,8 +1218,7 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
 
     if (allOutputsUsedInCommonConcat && commonConcatOp &&
         commonConcatOp.getAxis() == 1) {
-      commonConcatOp.getResult().replaceAllUsesWith(newConv.getResult());
-      rewriter.eraseOp(commonConcatOp);
+      rewriter.replaceOp(commonConcatOp, newConv);
     } else {
       SmallVector<int64_t> splitSizesVec;
       for (auto conv : parallelConvs) {
@@ -1181,15 +1227,15 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
         splitSizesVec.push_back(channels);
       }
 
-      rewriter.setInsertionPointAfter(newConv);
       ValueRange splitResults = onnx_mlir::emitSplitByChannels(
           rewriter, loc, newConv.getResult(), splitSizesVec, concatAxis);
-
       for (size_t i = 0; i < parallelConvs.size(); ++i) {
-        parallelConvs[i].getResult().replaceAllUsesWith(splitResults[i]);
+        rewriter.replaceAllOpUsesWith(parallelConvs[i], splitResults[i]);
       }
+      // Sort the block topologically, as the operations after the split may
+      // otherwise be in the wrong place.
+      mlir::sortTopologically(newConv->getBlock());
     }
-
     for (auto conv : parallelConvs) {
       rewriter.eraseOp(conv);
     }
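
Because the fused conv and its Split are now created after the latest of the original convs, operations that consumed an earlier conv's result can end up above their new producer once their uses are rewired, which is why the whole block is re-sorted topologically afterwards. The sketch below shows the general producers-before-consumers idea with Kahn's algorithm over a toy dependency graph; it is not the mlir::sortTopologically implementation, just the same ordering property.

```cpp
// Generic Kahn's-algorithm sketch of topological ordering: emit a node only
// once all of its producers have been emitted.
#include <cstdio>
#include <queue>
#include <vector>

int main() {
  // edges[i] lists the nodes that depend on i (i must be emitted before them).
  std::vector<std::vector<int>> edges = {{2}, {2}, {3}, {}};
  std::vector<int> indegree(edges.size(), 0);
  for (const auto &succs : edges)
    for (int s : succs)
      ++indegree[s];

  std::queue<int> ready;
  for (int i = 0; i < (int)edges.size(); ++i)
    if (indegree[i] == 0)
      ready.push(i);

  while (!ready.empty()) {
    int n = ready.front();
    ready.pop();
    std::printf("%d ", n); // emit n: all of its producers are already emitted
    for (int s : edges[n])
      if (--indegree[s] == 0)
        ready.push(s);
  }
  std::printf("\n"); // prints: 0 1 2 3
  return 0;
}
```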
@@ -1273,8 +1319,7 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
   patterns.insert<RecomposeDepthToSpaceDCR>(context);
   // AMD Disabled as downstream has no special support for it
   // patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
-  // AMD Temporary disabled as this pattern is buggy.
-  // patterns.insert<CombineParallelConv2DPattern>(context);
+  patterns.insert<CombineParallelConv2DPattern>(context);
 }
 
 /*!

test/mlir/onnx/onnx_recompose_combine_parallel_conv.mlir

Lines changed: 47 additions & 4 deletions
@@ -1,8 +1,5 @@
 // RUN: onnx-mlir --useOnnxModelTypes=false --EmitONNXIR --printIR %s | FileCheck %s
 
-// Temporary disabled
-// XFAIL: *
-
 func.func @test_conv_concat_simple(%arg0: tensor<1x1x512x512xf32>) -> tensor<1x64x512x512xf32> {
   %0 = onnx.Constant dense<0.00999999977> : tensor<32x1x3x3xf32>
   %1 = onnx.Constant dense<0.00999999977> : tensor<32xf32>
@@ -118,4 +115,50 @@ func.func @test_combine_conv_split(%arg0: tensor<1x1x512x512xf32>) -> tensor<1x9
 // CHECK: [[FINAL_OUT:%.+]] = "onnx.Concat"([[VAR_3_]], [[VAR_4_]], [[VAR_5_]]) {axis = 1 : si64, onnx_node_name = "onnx.Concat_4_7"} : (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) -> tensor<1x96x512x512xf32>
 // CHECK: return [[FINAL_OUT]] : tensor<1x96x512x512xf32>
 
-}
+}
+
+func.func @test_conv_concat_dependency(%arg0: tensor<1x1x512x512xf32>) -> tensor<1x64x512x512xf32> {
+  %0 = onnx.Constant dense<0.00999999977> : tensor<32x1x3x3xf32>
+  %1 = onnx.Constant dense<0.00999999977> : tensor<32xf32>
+  %2 = onnx.Constant dense<0.00999999977> : tensor<32x1x3x3xf32>
+  %3 = onnx.Constant dense<0.00999999977> : tensor<32xf32>
+  %4 = "onnx.Conv"(%arg0, %0, %1) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>) -> tensor<1x32x512x512xf32>
+  %reduceAxes = onnx.Constant dense<[0, 2, 3]> : tensor<3xi64>
+  %reduced = "onnx.ReduceMean"(%4, %reduceAxes) {keepdims = 0 : si64} : (tensor<1x32x512x512xf32>, tensor<3xi64>) -> tensor<32xf32>
+  %5 = "onnx.Conv"(%arg0, %2, %reduced) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>) -> tensor<1x32x512x512xf32>
+  %6 = "onnx.Concat"(%4, %5) {axis = 1 : si64} : (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) -> tensor<1x64x512x512xf32>
+  return %6 : tensor<1x64x512x512xf32>
+
+// COM: Can not be rewritten as there is a def-use chain between the Convs
+// CHECK-LABEL: func.func @test_conv_concat_dependency
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x512x512xf32>) -> tensor<1x64x512x512xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[0, 2, 3]> : tensor<3xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32x1x3x3xf32>
+// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32xf32>
+// CHECK: [[VAR_3_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_1_]], [[VAR_2_]]) {auto_pad = "NOTSET", group = 1 : si64, onnx_node_name = "onnx.Conv_8", pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>) -> tensor<1x32x512x512xf32>
+// CHECK: [[VAR_4_:%.+]] = "onnx.ReduceMean"([[VAR_3_]], [[VAR_0_]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64, onnx_node_name = "onnx.ReduceMean_9"} : (tensor<1x32x512x512xf32>, tensor<3xi64>) -> tensor<32xf32>
+// CHECK: [[VAR_5_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_1_]], [[VAR_4_]]) {auto_pad = "NOTSET", group = 1 : si64, onnx_node_name = "onnx.Conv_10", pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>) -> tensor<1x32x512x512xf32>
+// CHECK: [[VAR_6_:%.+]] = "onnx.Concat"([[VAR_3_]], [[VAR_5_]]) {axis = 1 : si64, onnx_node_name = "onnx.Concat_11"} : (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) -> tensor<1x64x512x512xf32>
+// CHECK: return [[VAR_6_]] : tensor<1x64x512x512xf32>
+// CHECK: }
+}
+
+func.func @test_conv_concat_not_static_shape(%arg0: tensor<1x1x512x512xf32>, %0: tensor<*xf32>) -> tensor<1x64x512x512xf32> {
+  %1 = onnx.Constant dense<0.00999999977> : tensor<32xf32>
+  %2 = onnx.Constant dense<0.00999999977> : tensor<32x1x3x3xf32>
+  %3 = onnx.Constant dense<0.00999999977> : tensor<32xf32>
+  %4 = "onnx.Conv"(%arg0, %0, %1) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<*xf32>, tensor<32xf32>) -> tensor<1x32x512x512xf32>
+  %5 = "onnx.Conv"(%arg0, %2, %3) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>) -> tensor<1x?x512x512xf32>
+  %6 = "onnx.Concat"(%4, %5) {axis = 1 : si64} : (tensor<1x32x512x512xf32>, tensor<1x?x512x512xf32>) -> tensor<1x64x512x512xf32>
+  return %6 : tensor<1x64x512x512xf32>
+
+// CHECK-LABEL: func.func @test_conv_concat_not_static_shape
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x512x512xf32>, [[PARAM_1_:%.+]]: tensor<*xf32>) -> tensor<1x64x512x512xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32x1x3x3xf32>
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {auto_pad = "NOTSET", group = 1 : si64, onnx_node_name = "onnx.Conv_12", pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<*xf32>, tensor<32xf32>) -> tensor<1x32x512x512xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_1_]], [[VAR_0_]]) {auto_pad = "NOTSET", group = 1 : si64, onnx_node_name = "onnx.Conv_13", pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>) -> tensor<1x32x512x512xf32>
+// CHECK: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_3_]]) {axis = 1 : si64, onnx_node_name = "onnx.Concat_14"} : (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) -> tensor<1x64x512x512xf32>
+// CHECK: return [[VAR_4_]] : tensor<1x64x512x512xf32>
+// CHECK: }
+}

test/mlir/onnx/onnx_recompose_locations.mlir

Lines changed: 24 additions & 24 deletions
@@ -46,28 +46,28 @@ func.func @test_combine_conv_split(%arg0: tensor<1x1x512x512xf32>) -> tensor<1x9
   return %12 : tensor<1x96x512x512xf32>
 
 // CHECK-LABEL: func.func @test_combine_conv_split
-// XFAIL-CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x512x512xf32>
-// XFAIL-CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<32> : tensor<3xi64>
-// XFAIL-CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32x1x3x3xf32>
-// XFAIL-CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32xf32>
-// XFAIL-CHECK-NOT: separator of consecutive DAGs
-// XFAIL-CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_1_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<32x1x3x3xf32>, tensor<32x1x3x3xf32>, tensor<32x1x3x3xf32>) -> tensor<96x1x3x3xf32> loc([[LOC_FUSED:#.+]])
-// XFAIL-CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_2_]], [[VAR_2_]]) {axis = 0 : si64} : (tensor<32xf32>, tensor<32xf32>, tensor<32xf32>) -> tensor<96xf32> loc([[LOC_FUSED:#.+]])
-// XFAIL-CHECK: [[VAR_5_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_3_]], [[VAR_4_]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<96x1x3x3xf32>, tensor<96xf32>) -> tensor<1x96x512x512xf32> loc([[LOC_FUSED:#.+]])
-// XFAIL-CHECK: [[VAR_6_:%.+]]:3 = "onnx.Split"([[VAR_5_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x96x512x512xf32>, tensor<3xi64>) -> (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) loc([[LOC_FUSED:#.+]])
-// XFAIL-CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Relu"([[VAR_6_]]#2) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_RELU:#.+]])
-// XFAIL-CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Sigmoid"([[VAR_6_]]#1) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_SIGMOID:#.+]])
-// XFAIL-CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Tanh"([[VAR_6_]]#0) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_TANH:#.+]])
-// XFAIL-CHECK: [[VAR_10_:%.+]] = "onnx.Concat"([[VAR_7_]], [[VAR_8_]], [[VAR_9_]]) {axis = 1 : si64} : (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) -> tensor<1x96x512x512xf32> loc([[LOC_ORIGINAL_CONCAT:#.+]])
-// XFAIL-CHECK: return [[VAR_10_]] : tensor<1x96x512x512xf32>
-// XFAIL-CHECK: }
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x512x512xf32>
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<32> : tensor<3xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32x1x3x3xf32>
+// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32xf32>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_1_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<32x1x3x3xf32>, tensor<32x1x3x3xf32>, tensor<32x1x3x3xf32>) -> tensor<96x1x3x3xf32> loc([[LOC_FUSED:#.+]])
+// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_2_]], [[VAR_2_]]) {axis = 0 : si64} : (tensor<32xf32>, tensor<32xf32>, tensor<32xf32>) -> tensor<96xf32> loc([[LOC_FUSED:#.+]])
+// CHECK: [[VAR_5_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_3_]], [[VAR_4_]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<96x1x3x3xf32>, tensor<96xf32>) -> tensor<1x96x512x512xf32> loc([[LOC_FUSED:#.+]])
+// CHECK: [[VAR_6_:%.+]]:3 = "onnx.Split"([[VAR_5_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x96x512x512xf32>, tensor<3xi64>) -> (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) loc([[LOC_FUSED:#.+]])
+// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Relu"([[VAR_6_]]#2) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_RELU:#.+]])
+// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Sigmoid"([[VAR_6_]]#1) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_SIGMOID:#.+]])
+// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Tanh"([[VAR_6_]]#0) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_TANH:#.+]])
+// CHECK: [[VAR_10_:%.+]] = "onnx.Concat"([[VAR_7_]], [[VAR_8_]], [[VAR_9_]]) {axis = 1 : si64} : (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) -> tensor<1x96x512x512xf32> loc([[LOC_ORIGINAL_CONCAT:#.+]])
+// CHECK: return [[VAR_10_]] : tensor<1x96x512x512xf32>
+// CHECK: }
 
-// XFAIL-CHECK-DAG: [[LOC_RELU:#.+]] = loc("relu")
-// XFAIL-CHECK-DAG: [[LOC_SIGMOID:#.+]] = loc("sigmoid")
-// XFAIL-CHECK-DAG: [[LOC_TANH:#.+]] = loc("tanh")
-// XFAIL-CHECK-DAG: [[LOC_ORIGINAL_CONCAT:#.+]] = loc("concat")
-// XFAIL-CHECK-DAG: [[LOC_CONV1:#.+]] = loc("conv1")
-// XFAIL-CHECK-DAG: [[LOC_CONV2:#.+]] = loc("conv2")
-// XFAIL-CHECK-DAG: [[LOC_CONV3:#.+]] = loc("conv3")
-// XFAIL-CHECK-DAG: [[LOC_FUSED]] = loc(fused[[[LOC_CONV1]], [[LOC_CONV3]], [[LOC_CONV2]]])
-}
+// CHECK-DAG: [[LOC_RELU:#.+]] = loc("relu")
+// CHECK-DAG: [[LOC_SIGMOID:#.+]] = loc("sigmoid")
+// CHECK-DAG: [[LOC_TANH:#.+]] = loc("tanh")
+// CHECK-DAG: [[LOC_ORIGINAL_CONCAT:#.+]] = loc("concat")
+// CHECK-DAG: [[LOC_CONV1:#.+]] = loc("conv1")
+// CHECK-DAG: [[LOC_CONV2:#.+]] = loc("conv2")
+// CHECK-DAG: [[LOC_CONV3:#.+]] = loc("conv3")
+// CHECK-DAG: [[LOC_FUSED]] = loc(fused[[[LOC_CONV1]], [[LOC_CONV3]], [[LOC_CONV2]]])
+}
