Commit 4bb5549

Merge branch 'feature/onnx-to-tosa' into jrickert.bump_integration_1
2 parents: 5fca4b1 + 66f19d8

9 files changed (+199, -26 lines)

.github/workflows/ubuntu-build.yml

Lines changed: 1 addition & 2 deletions

@@ -1,9 +1,8 @@
 name: Out-of-tree build
 
 on:
-  pull_request:
   push:
-    branches: [ main, feature/onnx-to-tosa ]
+    branches: [ main ]
 
 concurrency:
   # Build every push to main

src/Dialect/ONNX/ONNXOps.td.inc

Lines changed: 1 addition & 0 deletions

@@ -9233,6 +9233,7 @@ def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth",
 
 def ONNXSplitOp:ONNX_Op<"Split",
   [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
+  let hasCanonicalizer = 1;
   let summary = "ONNX Split operation";
   let description = [{
   Split a tensor into a list of tensors, along the specified 'axis'.

src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

Lines changed: 76 additions & 0 deletions

@@ -2096,6 +2096,75 @@ struct RemoveBatchNormPattern
   }
 };
 
+// "Pulls" Relu-like operations up through a SplitOp
+struct PullReluLikeOpsThroughSplitPattern
+    : public OpRewritePattern<ONNXSplitOp> {
+  using OpRewritePattern<ONNXSplitOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      ONNXSplitOp splitOp, PatternRewriter &rewriter) const final {
+
+    Operation *firstUser = nullptr;
+    SmallVector<Operation *> reluLikeOps;
+    Location newLoc = rewriter.getUnknownLoc();
+
+    const auto areFilteredAttrsEqual = [](Operation *op1, Operation *op2) {
+      DenseMap<StringRef, Attribute> filteredAttrs1;
+      DenseMap<StringRef, Attribute> filteredAttrs2;
+      for (const auto &attr : op1->getAttrs()) {
+        if (attr.getName() != "onnx_node_name") {
+          filteredAttrs1[attr.getName()] = attr.getValue();
+        }
+      }
+      for (const auto &attr : op2->getAttrs()) {
+        if (attr.getName() != "onnx_node_name") {
+          filteredAttrs2[attr.getName()] = attr.getValue();
+        }
+      }
+      return filteredAttrs1 == filteredAttrs2;
+    };
+
+    for (Operation *op : splitOp->getUsers()) {
+      // TODO: This pattern could be more generic, for all unary, elementwise
+      // ops. Having a trait for them would make this easier.
+      if (!isa<ONNXReluOp, ONNXLeakyReluOp>(op)) {
+        return rewriter.notifyMatchFailure(
+            splitOp, "SplitOp must be used by a Relu-like op");
+      }
+      if (op->getOperand(0).getType() != op->getResult(0).getType()) {
+        // This could happen if shape inference did not run
+        return rewriter.notifyMatchFailure(
+            splitOp, "Relu-like op must have same input and output type");
+      }
+      if (!firstUser) {
+        firstUser = op;
+      } else {
+        if (firstUser->getName() != op->getName() ||
+            !areFilteredAttrsEqual(firstUser, op)) {
+          return rewriter.notifyMatchFailure(splitOp,
+              "SplitOp must be used by Relu-like ops of the same type "
+              "and attributes");
+        }
+      }
+      reluLikeOps.push_back(op);
+      newLoc = rewriter.getFusedLoc({newLoc, op->getLoc()});
+    }
+    rewriter.setInsertionPoint(splitOp);
+    auto *newRelu = rewriter.clone(*reluLikeOps.front());
+    rewriter.modifyOpInPlace(newRelu, [&]() {
+      newRelu->setOperand(0, splitOp.getOperand(0));
+      newRelu->getResult(0).setType(splitOp.getOperand(0).getType());
+      newRelu->setLoc(newLoc);
+    });
+    rewriter.modifyOpInPlace(
+        splitOp, [&]() { splitOp->setOperand(0, newRelu->getResult(0)); });
+    for (Operation *op : reluLikeOps) {
+      rewriter.replaceOp(op, op->getOperands());
+    }
+    return success();
+  }
+};
+
 // =============================================================================
 /// Register optimization patterns as "canonicalization" patterns.
 /// Add op to OpsWithCanonicalizer in gen_onnx_mlir.py to activate.
@@ -2369,6 +2438,13 @@ void ONNXSpaceToDepthOp::getCanonicalizationPatterns(
   results.insert<RemoveSpaceToDepthDepthToSpacePattern>(context);
 }
 
+/// on the ONNXSplitOp
+void ONNXSplitOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.insert<PullReluLikeOpsThroughSplitPattern>(context);
+}
+
 /// on the ONNXSqueezeOp.
 void ONNXSqueezeOp::getCanonicalizationPatterns(
     RewritePatternSet &result, MLIRContext *context) {
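
In effect, the new PullReluLikeOpsThroughSplitPattern hoists a Relu (or LeakyRelu) that is applied to every result of an onnx.Split above the Split, so the activation runs once on the unsplit tensor and the Split results are used directly. A minimal before/after sketch of the rewrite; the function name and tensor shapes here are illustrative, and the behavior mirrors the test_split_relu_movement test added in this commit:

// Before the canonicalization pass (every Split result feeds an identical Relu):
func.func @relu_through_split_sketch(%arg0: tensor<1x4x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x2x2xf32>) {
  %sizes = onnx.Constant dense<2> : tensor<2xi64>
  %0:2 = "onnx.Split"(%arg0, %sizes) {axis = 1 : si64} : (tensor<1x4x2xf32>, tensor<2xi64>) -> (tensor<1x2x2xf32>, tensor<1x2x2xf32>)
  %1 = "onnx.Relu"(%0#0) : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
  %2 = "onnx.Relu"(%0#1) : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
  onnx.Return %1, %2 : tensor<1x2x2xf32>, tensor<1x2x2xf32>
}

// After the pattern fires (a single Relu on the Split input, results returned directly):
func.func @relu_through_split_sketch(%arg0: tensor<1x4x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x2x2xf32>) {
  %sizes = onnx.Constant dense<2> : tensor<2xi64>
  %relu = "onnx.Relu"(%arg0) : (tensor<1x4x2xf32>) -> tensor<1x4x2xf32>
  %0:2 = "onnx.Split"(%relu, %sizes) {axis = 1 : si64} : (tensor<1x4x2xf32>, tensor<2xi64>) -> (tensor<1x2x2xf32>, tensor<1x2x2xf32>)
  onnx.Return %0#0, %0#1 : tensor<1x2x2xf32>, tensor<1x2x2xf32>
}

The users may carry different onnx_node_name attributes, but any other attribute mismatch (for example, differing LeakyRelu alpha values) or a missing output shape makes the pattern bail out, as exercised by the other tests added below.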

src/Dialect/ONNX/Transforms/Recompose.cpp

Lines changed: 2 additions & 1 deletion

@@ -1273,7 +1273,8 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
   patterns.insert<RecomposeDepthToSpaceDCR>(context);
   // AMD Disabled as downstream has no special support for it
   // patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
-  patterns.insert<CombineParallelConv2DPattern>(context);
+  // AMD Temporarily disabled as this pattern is buggy.
+  // patterns.insert<CombineParallelConv2DPattern>(context);
 }
 
 /*!

test/mlir/onnx/onnx_canonicalization.mlir

Lines changed: 72 additions & 0 deletions

@@ -1952,6 +1952,78 @@ return %2 : tensor<1x12x4xf32>
 
 }
 
+// -----
+func.func @test_split_relu_movement(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+  %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+  %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+  %1 = "onnx.Relu"(%0#0) {onnx_node_name = "onnx.Relu_1"} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+  %2 = "onnx.Relu"(%0#1) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  %3 = "onnx.Relu"(%0#2) {onnx_node_name = "onnx.Relu_3"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  onnx.Return %1, %2, %3 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+}
+// CHECK-LABEL: func.func @test_split_relu_movement
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Relu"([[PARAM_0_]]) {onnx_node_name = "onnx.Relu_1"} : (tensor<1x8x2xf32>) -> tensor<1x8x2xf32>
+// CHECK: [[VAR_2_:%.+]]:3 = "onnx.Split"([[VAR_1_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+// CHECK: onnx.Return [[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+// CHECK: }
+
+// -----
+func.func @test_split_relu_movement_not_all_equal(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+  %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+  %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+  %1 = "onnx.Relu"(%0#0) {onnx_node_name = "onnx.Relu_1"} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+  %2 = "onnx.LeakyRelu"(%0#1) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  %3 = "onnx.Relu"(%0#2) {onnx_node_name = "onnx.Relu_3"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  onnx.Return %1, %2, %3 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+}
+// CHECK-LABEL: func.func @test_split_relu_movement_not_all_equal
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+// CHECK: [[VAR_1_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Relu"([[VAR_1_]]#0) {onnx_node_name = "onnx.Relu_1"} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.LeakyRelu"([[VAR_1_]]#1) {alpha = 0.00999999977 : f32, onnx_node_name = "onnx.Relu_2"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Relu"([[VAR_1_]]#2) {onnx_node_name = "onnx.Relu_3"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+// CHECK: onnx.Return [[VAR_2_]], [[VAR_3_]], [[VAR_4_]] : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+// CHECK: }
+
+// -----
+func.func @test_split_leakyrelu_movement(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+  %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+  %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+  %1 = "onnx.LeakyRelu"(%0#0) {onnx_node_name = "onnx.LRelu_1", alpha = 0.2 : f32} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+  %2 = "onnx.LeakyRelu"(%0#1) {onnx_node_name = "onnx.LRelu_2", alpha = 0.2 : f32} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  %3 = "onnx.LeakyRelu"(%0#2) {onnx_node_name = "onnx.LRelu_3", alpha = 0.2 : f32} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  onnx.Return %1, %2, %3 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+}
+// CHECK-LABEL: func.func @test_split_leakyrelu_movement
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.LeakyRelu"([[PARAM_0_]]) {alpha = 2.000000e-01 : f32, onnx_node_name = "onnx.LRelu_1"} : (tensor<1x8x2xf32>) -> tensor<1x8x2xf32>
+// CHECK: [[VAR_2_:%.+]]:3 = "onnx.Split"([[VAR_1_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+// CHECK: onnx.Return [[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+// CHECK: }
+
+// -----
+func.func @test_split_leakyrelu_movement_different_alpha(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+  %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+  %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+  %1 = "onnx.LeakyRelu"(%0#0) {onnx_node_name = "onnx.LRelu_1", alpha = 0.2 : f32} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+  %2 = "onnx.LeakyRelu"(%0#1) {onnx_node_name = "onnx.LRelu_2", alpha = 0.2 : f32} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  %3 = "onnx.LeakyRelu"(%0#2) {onnx_node_name = "onnx.LRelu_3", alpha = 0.3 : f32} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  onnx.Return %1, %2, %3 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+}
+// CHECK-LABEL: func.func @test_split_leakyrelu_movement_different_alpha
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+// CHECK: [[VAR_1_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.LeakyRelu"([[VAR_1_]]#0) {alpha = 2.000000e-01 : f32, onnx_node_name = "onnx.LRelu_1"} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.LeakyRelu"([[VAR_1_]]#1) {alpha = 2.000000e-01 : f32, onnx_node_name = "onnx.LRelu_2"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.LeakyRelu"([[VAR_1_]]#2) {alpha = 3.000000e-01 : f32, onnx_node_name = "onnx.LRelu_3"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+// CHECK: onnx.Return [[VAR_2_]], [[VAR_3_]], [[VAR_4_]] : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+// CHECK: }
+
 // -----
 
 // Not rewriting since the operand in ConcatOp is neither DimOp nor ConstantOp.

test/mlir/onnx/onnx_canonicalization_without_shape_inference.mlir

Lines changed: 20 additions & 0 deletions

@@ -230,3 +230,23 @@ func.func @test_batchnormv9_f16_dynamic(%arg0: tensor<100x3x?x?xf16>) -> (tensor
 // CHECK: [[Y_:%.+]], [[VAR_running_mean_:%.+]], [[VAR_running_var_:%.+]] = "onnx.BatchNormalization"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]], [[VAR_2_]], [[VAR_3_]]) {epsilon = 1.00000007E-5 : f32, momentum = 1.000000e-03 : f32, training_mode = 0 : si64} : (tensor<100x3x?x?xf16>, tensor<3xf16>, tensor<3xf16>, tensor<3xf16>, tensor<3xf16>) -> (tensor<*xf16>, tensor<*xf16>, tensor<*xf16>)
 // CHECK: return [[Y_]], [[VAR_running_mean_]], [[VAR_running_var_]] : tensor<*xf16>, tensor<*xf16>, tensor<*xf16>
 // CHECK: }
+
+// -----
+func.func @test_split_relu_movement_missing_shape(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<*xf32>, tensor<1x3x2xf32>) {
+  %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+  %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+  %1 = "onnx.Relu"(%0#0) {onnx_node_name = "onnx.Relu_1"} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+  %2 = "onnx.Relu"(%0#1) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x3x2xf32>) -> tensor<*xf32>
+  %3 = "onnx.Relu"(%0#2) {onnx_node_name = "onnx.Relu_3"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  onnx.Return %1, %2, %3 : tensor<1x2x2xf32>, tensor<*xf32>, tensor<1x3x2xf32>
+}
+
+// CHECK-LABEL: func.func @test_split_relu_movement_missing_shape
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<*xf32>, tensor<1x3x2xf32>) {
+// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+// CHECK: [[VAR_1_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Relu"([[VAR_1_]]#0) {onnx_node_name = "onnx.Relu_1"} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Relu"([[VAR_1_]]#1) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x3x2xf32>) -> tensor<*xf32>
+// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Relu"([[VAR_1_]]#2) {onnx_node_name = "onnx.Relu_3"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+// CHECK: onnx.Return [[VAR_2_]], [[VAR_3_]], [[VAR_4_]] : tensor<1x2x2xf32>, tensor<*xf32>, tensor<1x3x2xf32>
+// CHECK: }

test/mlir/onnx/onnx_recompose_combine_parallel_conv.mlir

Lines changed: 3 additions & 0 deletions

@@ -1,5 +1,8 @@
 // RUN: onnx-mlir --useOnnxModelTypes=false --EmitONNXIR --printIR %s | FileCheck %s
 
+// Temporarily disabled
+// XFAIL: *
+
 func.func @test_conv_concat_simple(%arg0: tensor<1x1x512x512xf32>) -> tensor<1x64x512x512xf32> {
   %0 = onnx.Constant dense<0.00999999977> : tensor<32x1x3x3xf32>
   %1 = onnx.Constant dense<0.00999999977> : tensor<32xf32>

test/mlir/onnx/onnx_recompose_locations.mlir

Lines changed: 23 additions & 23 deletions

@@ -46,28 +46,28 @@ func.func @test_combine_conv_split(%arg0: tensor<1x1x512x512xf32>) -> tensor<1x9
   return %12 : tensor<1x96x512x512xf32>
 
 // CHECK-LABEL: func.func @test_combine_conv_split
-// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x512x512xf32>
-// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<32> : tensor<3xi64>
-// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32x1x3x3xf32>
-// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32xf32>
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_1_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<32x1x3x3xf32>, tensor<32x1x3x3xf32>, tensor<32x1x3x3xf32>) -> tensor<96x1x3x3xf32> loc([[LOC_FUSED:#.+]])
-// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_2_]], [[VAR_2_]]) {axis = 0 : si64} : (tensor<32xf32>, tensor<32xf32>, tensor<32xf32>) -> tensor<96xf32> loc([[LOC_FUSED:#.+]])
-// CHECK: [[VAR_5_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_3_]], [[VAR_4_]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<96x1x3x3xf32>, tensor<96xf32>) -> tensor<1x96x512x512xf32> loc([[LOC_FUSED:#.+]])
-// CHECK: [[VAR_6_:%.+]]:3 = "onnx.Split"([[VAR_5_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x96x512x512xf32>, tensor<3xi64>) -> (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) loc([[LOC_FUSED:#.+]])
-// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Relu"([[VAR_6_]]#2) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_RELU:#.+]])
-// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Sigmoid"([[VAR_6_]]#1) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_SIGMOID:#.+]])
-// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Tanh"([[VAR_6_]]#0) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_TANH:#.+]])
-// CHECK: [[VAR_10_:%.+]] = "onnx.Concat"([[VAR_7_]], [[VAR_8_]], [[VAR_9_]]) {axis = 1 : si64} : (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) -> tensor<1x96x512x512xf32> loc([[LOC_ORIGINAL_CONCAT:#.+]])
-// CHECK: return [[VAR_10_]] : tensor<1x96x512x512xf32>
-// CHECK: }
+// XFAIL-CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x512x512xf32>
+// XFAIL-CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<32> : tensor<3xi64>
+// XFAIL-CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32x1x3x3xf32>
+// XFAIL-CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32xf32>
+// XFAIL-CHECK-NOT: separator of consecutive DAGs
+// XFAIL-CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_1_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<32x1x3x3xf32>, tensor<32x1x3x3xf32>, tensor<32x1x3x3xf32>) -> tensor<96x1x3x3xf32> loc([[LOC_FUSED:#.+]])
+// XFAIL-CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_2_]], [[VAR_2_]]) {axis = 0 : si64} : (tensor<32xf32>, tensor<32xf32>, tensor<32xf32>) -> tensor<96xf32> loc([[LOC_FUSED:#.+]])
+// XFAIL-CHECK: [[VAR_5_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_3_]], [[VAR_4_]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<96x1x3x3xf32>, tensor<96xf32>) -> tensor<1x96x512x512xf32> loc([[LOC_FUSED:#.+]])
+// XFAIL-CHECK: [[VAR_6_:%.+]]:3 = "onnx.Split"([[VAR_5_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x96x512x512xf32>, tensor<3xi64>) -> (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) loc([[LOC_FUSED:#.+]])
+// XFAIL-CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Relu"([[VAR_6_]]#2) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_RELU:#.+]])
+// XFAIL-CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Sigmoid"([[VAR_6_]]#1) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_SIGMOID:#.+]])
+// XFAIL-CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Tanh"([[VAR_6_]]#0) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_TANH:#.+]])
+// XFAIL-CHECK: [[VAR_10_:%.+]] = "onnx.Concat"([[VAR_7_]], [[VAR_8_]], [[VAR_9_]]) {axis = 1 : si64} : (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) -> tensor<1x96x512x512xf32> loc([[LOC_ORIGINAL_CONCAT:#.+]])
+// XFAIL-CHECK: return [[VAR_10_]] : tensor<1x96x512x512xf32>
+// XFAIL-CHECK: }
 
-// CHECK-DAG: [[LOC_RELU:#.+]] = loc("relu")
-// CHECK-DAG: [[LOC_SIGMOID:#.+]] = loc("sigmoid")
-// CHECK-DAG: [[LOC_TANH:#.+]] = loc("tanh")
-// CHECK-DAG: [[LOC_ORIGINAL_CONCAT:#.+]] = loc("concat")
-// CHECK-DAG: [[LOC_CONV1:#.+]] = loc("conv1")
-// CHECK-DAG: [[LOC_CONV2:#.+]] = loc("conv2")
-// CHECK-DAG: [[LOC_CONV3:#.+]] = loc("conv3")
-// CHECK-DAG: [[LOC_FUSED]] = loc(fused[[[LOC_CONV1]], [[LOC_CONV3]], [[LOC_CONV2]]])
+// XFAIL-CHECK-DAG: [[LOC_RELU:#.+]] = loc("relu")
+// XFAIL-CHECK-DAG: [[LOC_SIGMOID:#.+]] = loc("sigmoid")
+// XFAIL-CHECK-DAG: [[LOC_TANH:#.+]] = loc("tanh")
+// XFAIL-CHECK-DAG: [[LOC_ORIGINAL_CONCAT:#.+]] = loc("concat")
+// XFAIL-CHECK-DAG: [[LOC_CONV1:#.+]] = loc("conv1")
+// XFAIL-CHECK-DAG: [[LOC_CONV2:#.+]] = loc("conv2")
+// XFAIL-CHECK-DAG: [[LOC_CONV3:#.+]] = loc("conv3")
+// XFAIL-CHECK-DAG: [[LOC_FUSED]] = loc(fused[[[LOC_CONV1]], [[LOC_CONV3]], [[LOC_CONV2]]])
 }

utils/gen_onnx_mlir.py

Lines changed: 1 addition & 0 deletions

@@ -358,6 +358,7 @@
     "Resize",
     "RNN",
     "Shape",
+    "Split",
     "Size",
     "SoftmaxV11",
     "SpaceToDepth",
