
Commit 66f19d8

Merge pull request #383 from Xilinx/jrickert.split-relu-movement
Add an optimization pattern to move Relu and LeakyRelu before Split operations.
2 parents e8ccd0a + 3fe791a commit 66f19d8
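
The gist of the rewrite, in ONNX dialect terms: when every result of an onnx.Split feeds an identical Relu (or LeakyRelu with identical attributes), the activation is hoisted above the Split so it runs once on the unsplit tensor. A minimal before/after sketch, distilled from the tests added below (the value names %x and %sizes are illustrative, not taken from the commit):

// Before: one Relu per Split output.
%0:3 = "onnx.Split"(%x, %sizes) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
%1 = "onnx.Relu"(%0#0) : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
%2 = "onnx.Relu"(%0#1) : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
%3 = "onnx.Relu"(%0#2) : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>

// After: a single Relu on the Split input; the Split results replace %1, %2, %3.
%r = "onnx.Relu"(%x) : (tensor<1x8x2xf32>) -> tensor<1x8x2xf32>
%0:3 = "onnx.Split"(%r, %sizes) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)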

File tree

5 files changed (+170, -0 lines changed)

src/Dialect/ONNX/ONNXOps.td.inc

Lines changed: 1 addition & 0 deletions
@@ -9233,6 +9233,7 @@ def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth",
 
 def ONNXSplitOp:ONNX_Op<"Split",
   [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
+  let hasCanonicalizer = 1;
   let summary = "ONNX Split operation";
   let description = [{
   Split a tensor into a list of tensors, along the specified 'axis'.
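
For context, "let hasCanonicalizer = 1;" is the ODS switch that makes mlir-tblgen declare the canonicalization hook on the generated op class; the definition is the one added to Canonicalize.cpp below. Roughly, a simplified sketch of the generated declaration (not the verbatim tblgen output):

#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"

// Sketch: with hasCanonicalizer set, the generated ONNXSplitOp class gains this
// static hook, which MLIR's canonicalizer pass calls to collect the op's
// canonicalization patterns.
class ONNXSplitOp : public ::mlir::Op<ONNXSplitOp> {
public:
  static void getCanonicalizationPatterns(
      ::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context);
};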

src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

Lines changed: 76 additions & 0 deletions
@@ -2096,6 +2096,75 @@ struct RemoveBatchNormPattern
   }
 };
 
+// "Pulls" Relu-like operations up through a SplitOp
+struct PullReluLikeOpsThroughSplitPattern
+    : public OpRewritePattern<ONNXSplitOp> {
+  using OpRewritePattern<ONNXSplitOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      ONNXSplitOp splitOp, PatternRewriter &rewriter) const final {
+
+    Operation *firstUser = nullptr;
+    SmallVector<Operation *> reluLikeOps;
+    Location newLoc = rewriter.getUnknownLoc();
+
+    const auto areFilteredAttrsEqual = [](Operation *op1, Operation *op2) {
+      DenseMap<StringRef, Attribute> filteredAttrs1;
+      DenseMap<StringRef, Attribute> filteredAttrs2;
+      for (const auto &attr : op1->getAttrs()) {
+        if (attr.getName() != "onnx_node_name") {
+          filteredAttrs1[attr.getName()] = attr.getValue();
+        }
+      }
+      for (const auto &attr : op2->getAttrs()) {
+        if (attr.getName() != "onnx_node_name") {
+          filteredAttrs2[attr.getName()] = attr.getValue();
+        }
+      }
+      return filteredAttrs1 == filteredAttrs2;
+    };
+
+    for (Operation *op : splitOp->getUsers()) {
+      // TODO: This pattern could be more generic, for all unary, elementwise
+      // ops. Having a trait for them would make this easier.
+      if (!isa<ONNXReluOp, ONNXLeakyReluOp>(op)) {
+        return rewriter.notifyMatchFailure(
+            splitOp, "SplitOp must be used by a Relu-like op");
+      }
+      if (op->getOperand(0).getType() != op->getResult(0).getType()) {
+        // This could happen if shape inference did not run
+        return rewriter.notifyMatchFailure(
+            splitOp, "Relu-like op must have same input and output type");
+      }
+      if (!firstUser) {
+        firstUser = op;
+      } else {
+        if (firstUser->getName() != op->getName() ||
+            !areFilteredAttrsEqual(firstUser, op)) {
+          return rewriter.notifyMatchFailure(splitOp,
+              "SplitOp must be used by Relu-like ops of the same type "
+              "and attributes");
+        }
+      }
+      reluLikeOps.push_back(op);
+      newLoc = rewriter.getFusedLoc({newLoc, op->getLoc()});
+    }
+    rewriter.setInsertionPoint(splitOp);
+    auto *newRelu = rewriter.clone(*reluLikeOps.front());
+    rewriter.modifyOpInPlace(newRelu, [&]() {
+      newRelu->setOperand(0, splitOp.getOperand(0));
+      newRelu->getResult(0).setType(splitOp.getOperand(0).getType());
+      newRelu->setLoc(newLoc);
+    });
+    rewriter.modifyOpInPlace(
+        splitOp, [&]() { splitOp->setOperand(0, newRelu->getResult(0)); });
+    for (Operation *op : reluLikeOps) {
+      rewriter.replaceOp(op, op->getOperands());
+    }
+    return success();
+  }
+};
+
 // =============================================================================
 /// Register optimization patterns as "canonicalization" patterns.
 /// Add op to OpsWithCanonicalizer in gen_onnx_mlir.py to activate.
@@ -2369,6 +2438,13 @@ void ONNXSpaceToDepthOp::getCanonicalizationPatterns(
   results.insert<RemoveSpaceToDepthDepthToSpacePattern>(context);
 }
 
+/// on the ONNXSplitOp
+void ONNXSplitOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.insert<PullReluLikeOpsThroughSplitPattern>(context);
+  ;
+}
+
 /// on the ONNXSqueezeOp.
 void ONNXSqueezeOp::getCanonicalizationPatterns(
     RewritePatternSet &result, MLIRContext *context) {
test/mlir/onnx/onnx_canonicalization.mlir

Lines changed: 72 additions & 0 deletions
@@ -1952,6 +1952,78 @@ return %2 : tensor<1x12x4xf32>
 
 }
 
+// -----
+func.func @test_split_relu_movement(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+  %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+  %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+  %1 = "onnx.Relu"(%0#0) {onnx_node_name = "onnx.Relu_1"} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+  %2 = "onnx.Relu"(%0#1) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  %3 = "onnx.Relu"(%0#2) {onnx_node_name = "onnx.Relu_3"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  onnx.Return %1, %2, %3 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+}
+// CHECK-LABEL: func.func @test_split_relu_movement
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Relu"([[PARAM_0_]]) {onnx_node_name = "onnx.Relu_1"} : (tensor<1x8x2xf32>) -> tensor<1x8x2xf32>
+// CHECK: [[VAR_2_:%.+]]:3 = "onnx.Split"([[VAR_1_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+// CHECK: onnx.Return [[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+// CHECK: }
+
+// -----
+func.func @test_split_relu_movement_not_all_equal(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+  %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+  %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+  %1 = "onnx.Relu"(%0#0) {onnx_node_name = "onnx.Relu_1"} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+  %2 = "onnx.LeakyRelu"(%0#1) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  %3 = "onnx.Relu"(%0#2) {onnx_node_name = "onnx.Relu_3"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  onnx.Return %1, %2, %3 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+}
+// CHECK-LABEL: func.func @test_split_relu_movement_not_all_equal
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+// CHECK: [[VAR_1_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Relu"([[VAR_1_]]#0) {onnx_node_name = "onnx.Relu_1"} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.LeakyRelu"([[VAR_1_]]#1) {alpha = 0.00999999977 : f32, onnx_node_name = "onnx.Relu_2"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Relu"([[VAR_1_]]#2) {onnx_node_name = "onnx.Relu_3"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+// CHECK: onnx.Return [[VAR_2_]], [[VAR_3_]], [[VAR_4_]] : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+// CHECK: }
+
+// -----
+func.func @test_split_leakyrelu_movement(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+  %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+  %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+  %1 = "onnx.LeakyRelu"(%0#0) {onnx_node_name = "onnx.LRelu_1", alpha = 0.2 : f32} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+  %2 = "onnx.LeakyRelu"(%0#1) {onnx_node_name = "onnx.LRelu_2", alpha = 0.2 : f32} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  %3 = "onnx.LeakyRelu"(%0#2) {onnx_node_name = "onnx.LRelu_3", alpha = 0.2 : f32} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  onnx.Return %1, %2, %3 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+}
+// CHECK-LABEL: func.func @test_split_leakyrelu_movement
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.LeakyRelu"([[PARAM_0_]]) {alpha = 2.000000e-01 : f32, onnx_node_name = "onnx.LRelu_1"} : (tensor<1x8x2xf32>) -> tensor<1x8x2xf32>
+// CHECK: [[VAR_2_:%.+]]:3 = "onnx.Split"([[VAR_1_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+// CHECK: onnx.Return [[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+// CHECK: }
+
+// -----
+func.func @test_split_leakyrelu_movement_different_alpha(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+  %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+  %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+  %1 = "onnx.LeakyRelu"(%0#0) {onnx_node_name = "onnx.LRelu_1", alpha = 0.2 : f32} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+  %2 = "onnx.LeakyRelu"(%0#1) {onnx_node_name = "onnx.LRelu_2", alpha = 0.2 : f32} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  %3 = "onnx.LeakyRelu"(%0#2) {onnx_node_name = "onnx.LRelu_3", alpha = 0.3 : f32} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  onnx.Return %1, %2, %3 : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+}
+// CHECK-LABEL: func.func @test_split_leakyrelu_movement_different_alpha
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
+// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+// CHECK: [[VAR_1_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.LeakyRelu"([[VAR_1_]]#0) {alpha = 2.000000e-01 : f32, onnx_node_name = "onnx.LRelu_1"} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.LeakyRelu"([[VAR_1_]]#1) {alpha = 2.000000e-01 : f32, onnx_node_name = "onnx.LRelu_2"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.LeakyRelu"([[VAR_1_]]#2) {alpha = 3.000000e-01 : f32, onnx_node_name = "onnx.LRelu_3"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+// CHECK: onnx.Return [[VAR_2_]], [[VAR_3_]], [[VAR_4_]] : tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>
+// CHECK: }
+
 // -----
 
 // Not rewriting since the operand in ConcatOp is neither DimOp nor ConstantOp.

test/mlir/onnx/onnx_canonicalization_without_shape_inference.mlir

Lines changed: 20 additions & 0 deletions
@@ -230,3 +230,23 @@ func.func @test_batchnormv9_f16_dynamic(%arg0: tensor<100x3x?x?xf16>) -> (tensor
 // CHECK: [[Y_:%.+]], [[VAR_running_mean_:%.+]], [[VAR_running_var_:%.+]] = "onnx.BatchNormalization"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]], [[VAR_2_]], [[VAR_3_]]) {epsilon = 1.00000007E-5 : f32, momentum = 1.000000e-03 : f32, training_mode = 0 : si64} : (tensor<100x3x?x?xf16>, tensor<3xf16>, tensor<3xf16>, tensor<3xf16>, tensor<3xf16>) -> (tensor<*xf16>, tensor<*xf16>, tensor<*xf16>)
 // CHECK: return [[Y_]], [[VAR_running_mean_]], [[VAR_running_var_]] : tensor<*xf16>, tensor<*xf16>, tensor<*xf16>
 // CHECK: }
+
+// -----
+func.func @test_split_relu_movement_missing_shape(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<*xf32>, tensor<1x3x2xf32>) {
+  %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+  %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+  %1 = "onnx.Relu"(%0#0) {onnx_node_name = "onnx.Relu_1"} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+  %2 = "onnx.Relu"(%0#1) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x3x2xf32>) -> tensor<*xf32>
+  %3 = "onnx.Relu"(%0#2) {onnx_node_name = "onnx.Relu_3"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+  onnx.Return %1, %2, %3 : tensor<1x2x2xf32>, tensor<*xf32>, tensor<1x3x2xf32>
+}
+
+// CHECK-LABEL: func.func @test_split_relu_movement_missing_shape
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<*xf32>, tensor<1x3x2xf32>) {
+// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
+// CHECK: [[VAR_1_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Relu"([[VAR_1_]]#0) {onnx_node_name = "onnx.Relu_1"} : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Relu"([[VAR_1_]]#1) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x3x2xf32>) -> tensor<*xf32>
+// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Relu"([[VAR_1_]]#2) {onnx_node_name = "onnx.Relu_3"} : (tensor<1x3x2xf32>) -> tensor<1x3x2xf32>
+// CHECK: onnx.Return [[VAR_2_]], [[VAR_3_]], [[VAR_4_]] : tensor<1x2x2xf32>, tensor<*xf32>, tensor<1x3x2xf32>
+// CHECK: }

utils/gen_onnx_mlir.py

Lines changed: 1 addition & 0 deletions
@@ -358,6 +358,7 @@
     "Resize",
    "RNN",
    "Shape",
+   "Split",
    "Size",
    "SoftmaxV11",
    "SpaceToDepth",
