Skip to content

Commit fb69354

Browse files
committed
fix: address comments and add more tests
1 parent 4555efd commit fb69354

File tree

4 files changed

+107
-12
lines changed

4 files changed

+107
-12
lines changed

src/Compiler/OnnxToMlirPasses.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
5151
/*enableQuarkQuantizedOpsLegalization=*/false,
5252
opts.enableConvTransposeDecompose,
5353
opts.enableConvTransposeDecomposeToPhasedConv,
54-
opts.enableConvTranspose1dDecomposeToPhasedConv));
54+
opts.enableConvTranspose1dDecomposeToPhasedConv,
55+
opts.enableRecomposeLayernormByTranspose));
5556
}
5657
} else {
5758
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
@@ -107,7 +108,8 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
107108
!opts.disableRecomposeOption, opts.enableQuarkQuantizedLegalization,
108109
opts.enableConvTransposeDecompose,
109110
opts.enableConvTransposeDecomposeToPhasedConv,
110-
opts.enableConvTranspose1dDecomposeToPhasedConv));
111+
opts.enableConvTranspose1dDecomposeToPhasedConv,
112+
opts.enableRecomposeLayernormByTranspose));
111113
} else {
112114
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
113115
pm.addPass(mlir::createCanonicalizerPass());

src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -587,22 +587,23 @@ class PropagateReshapeThroughBinaryOpPattern
587587

588588
// This pattern bubbles up AddOp through transpose to keep the bias Add
589589
// operation right after LN_type op. This will help the other patterns fold the
590-
// add into the operands of LayerNorm.
590+
// add into the operands of a Norm operator.
591591
//
592-
// From: LayerNorm
592+
// From:
593+
// Norm operator
593594
// |
594595
// Transpose
595596
// |
596597
// Add
597598
//
598599
// To:
599-
// LayerNorm
600+
// Norm operator
600601
// |
601602
// Add
602603
// |
603604
// Transpose
604605
template <typename LN_TYPE>
605-
class BubbleUpBiasForLayerNormPattern : public OpRewritePattern<ONNXAddOp> {
606+
class BubbleUpBiasForNormOpPattern : public OpRewritePattern<ONNXAddOp> {
606607
public:
607608
using OpRewritePattern<ONNXAddOp>::OpRewritePattern;
608609

@@ -2487,9 +2488,9 @@ void ONNXAddOp::getCanonicalizationPatterns(
24872488
PropagateBiasIntoLayerNormRewritePattern<ONNXRMSLayerNormalizationOp>>(
24882489
context);
24892490
results.insert<PropagateReshapeThroughBinaryOpPattern<ONNXAddOp>>(context);
2490-
results.insert<BubbleUpBiasForLayerNormPattern<ONNXLayerNormalizationOp>>(
2491+
results.insert<BubbleUpBiasForNormOpPattern<ONNXLayerNormalizationOp>>(
24912492
context);
2492-
results.insert<BubbleUpBiasForLayerNormPattern<ONNXRMSLayerNormalizationOp>>(
2493+
results.insert<BubbleUpBiasForNormOpPattern<ONNXRMSLayerNormalizationOp>>(
24932494
context);
24942495
}
24952496

src/Dialect/ONNX/Transforms/Recompose.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,9 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern<ONNXMulOp> {
176176
Value res;
177177

178178
if constexpr (RecomposeLayernormByTranspose) {
179+
if (!hasShapeAndRank(scale))
180+
return rewriter.notifyMatchFailure(
181+
mulOp, "the scale doesn't have shape or rank");
179182
// if the permutation is empty, nothing needs to be permuted.
180183
// Otherwise, both input and scale must be transposed.
181184
if (!permutation.empty()) {

test/mlir/onnx/onnx_recompose_layernorm_with_transpose.mlir

Lines changed: 93 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
// RUN: onnx-mlir-opt --recompose-onnx="recompose-layernorm-by-transpose" --canonicalize %s -split-input-file | FileCheck %s
1+
// RUN: onnx-mlir-opt -split-input-file --recompose-onnx="recompose-layernorm-by-transpose" --canonicalize %s | FileCheck %s
22

3-
func.func @main_graph(%arg0: tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32> {
3+
func.func @decomposition_to_layernorm_axis_1(%arg0: tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32> {
44
%2 = onnx.Constant dense<1.000000e+00> : tensor<f32>
55
%4 = onnx.Constant dense<9.99999997E-7> : tensor<f32>
66
%5 = onnx.Constant dense<[[[0.0970484465]], [[0.0882187337]], [[0.135120019]], [[0.14907673]]]> : tensor<4x1x1xf32>
@@ -16,7 +16,7 @@ func.func @main_graph(%arg0: tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32>
1616
%17 = "onnx.Add"(%16, %6) {onnx_node_name = "/mask_downscaling/mask_downscaling.1/Add_1"} : (tensor<1x4x128x128xf32>, tensor<4x1x1xf32>) -> tensor<1x4x128x128xf32>
1717
return %17 : tensor<1x4x128x128xf32>
1818
}
19-
// CHECK-LABEL: func.func @main_graph
19+
// CHECK-LABEL: func.func @decomposition_to_layernorm_axis_1
2020
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32> {
2121
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<{{.}}{{.}}[0.0970484465]{{.}}, {{.}}[0.0882187337]{{.}}, {{.}}[0.135120019]{{.}}, {{.}}[0.14907673]{{.}}{{.}}> : tensor<4x1x1xf32>
2222
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<{{.}}{{.}}[0.191972837]{{.}}, {{.}}[0.286264896]{{.}}, {{.}}[0.180535644]{{.}}, {{.}}[0.166878015]{{.}}{{.}}> : tensor<4x1x1xf32>
@@ -33,4 +33,93 @@ func.func @main_graph(%arg0: tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32>
3333

3434
// -----
3535

36-
// TODO: ADD more lit tests here
36+
func.func @decomposition_to_layernorm_axis_2(%arg0: tensor<1x128x4x128xf32> {onnx.name = "in0"}) -> (tensor<1x128x4x128xf32> {onnx.name = "out"}) {
37+
%0 = onnx.Constant dense<[[[0.976699769], [0.380195737], [0.923246204], [0.261692435]]]> : tensor<1x4x1xf32>
38+
%1 = onnx.Constant dense<9.99999997E-7> : tensor<f32>
39+
%2 = onnx.Constant dense<2> : tensor<1xi64>
40+
%3 = "onnx.ReduceMean"(%arg0, %2) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64, onnx_node_name = "ReduceMean1"} : (tensor<1x128x4x128xf32>, tensor<1xi64>) -> tensor<1x128x1x128xf32>
41+
%4 = "onnx.Sub"(%arg0, %3) {onnx_node_name = "Sub2"} : (tensor<1x128x4x128xf32>, tensor<1x128x1x128xf32>) -> tensor<1x128x4x128xf32>
42+
%5 = "onnx.Mul"(%4, %4) {onnx_node_name = "Mul3"} : (tensor<1x128x4x128xf32>, tensor<1x128x4x128xf32>) -> tensor<1x128x4x128xf32>
43+
%6 = "onnx.ReduceMean"(%5, %2) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64, onnx_node_name = "ReduceMean4"} : (tensor<1x128x4x128xf32>, tensor<1xi64>) -> tensor<1x128x1x128xf32>
44+
%7 = "onnx.Add"(%6, %1) {onnx_node_name = "Add6"} : (tensor<1x128x1x128xf32>, tensor<f32>) -> tensor<1x128x1x128xf32>
45+
%8 = "onnx.Sqrt"(%7) {onnx_node_name = "Sqrt7"} : (tensor<1x128x1x128xf32>) -> tensor<1x128x1x128xf32>
46+
%9 = "onnx.Div"(%4, %8) {onnx_node_name = "Div8"} : (tensor<1x128x4x128xf32>, tensor<1x128x1x128xf32>) -> tensor<1x128x4x128xf32>
47+
%10 = "onnx.Mul"(%0, %9) {onnx_node_name = "Mul10"} : (tensor<1x4x1xf32>, tensor<1x128x4x128xf32>) -> tensor<1x128x4x128xf32>
48+
%11 = "onnx.Add"(%10, %0) {onnx_node_name = "Add12"} : (tensor<1x128x4x128xf32>, tensor<1x4x1xf32>) -> tensor<1x128x4x128xf32>
49+
onnx.Return %11 : tensor<1x128x4x128xf32>
50+
}
51+
// CHECK-LABEL: func.func @decomposition_to_layernorm_axis_2
52+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x128x4x128xf32> {onnx.name = "in0"}) -> (tensor<1x128x4x128xf32> {onnx.name = "out"}) {
53+
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<{{.}}{{.}}[0.976699769], [0.380195737], [0.923246204], [0.261692435]{{.}}{{.}}> : tensor<1x4x1xf32>
54+
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1, 1, 4, 1]> : tensor<4xi64>
55+
// CHECK: [[VAR_2_:%.+]] = "onnx.Reshape"([[VAR_0_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<1x4x1xf32>, tensor<4xi64>) -> tensor<1x1x4x1xf32>
56+
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Transpose"([[VAR_2_]]) {perm = [0, 1, 3, 2]} : (tensor<1x1x4x1xf32>) -> tensor<1x1x1x4xf32>
57+
// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [0, 1, 3, 2]} : (tensor<1x128x4x128xf32>) -> tensor<1x128x128x4xf32>
58+
// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Reshape"([[VAR_0_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<1x4x1xf32>, tensor<4xi64>) -> tensor<1x1x4x1xf32>
59+
// CHECK: [[VAR_6_:%.+]] = "onnx.Transpose"([[VAR_5_]]) {perm = [0, 1, 3, 2]} : (tensor<1x1x4x1xf32>) -> tensor<1x1x1x4xf32>
60+
// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_4_]], [[VAR_3_]], [[VAR_6_]]) {axis = 3 : si64, epsilon = 9.99999997E-7 : f32, stash_type = 1 : si64} : (tensor<1x128x128x4xf32>, tensor<1x1x1x4xf32>, tensor<1x1x1x4xf32>) -> (tensor<1x128x128x4xf32>, none, none)
61+
// CHECK: [[VAR_7_:%.+]] = "onnx.Transpose"([[VAR_Y_]]) {perm = [0, 1, 3, 2]} : (tensor<1x128x128x4xf32>) -> tensor<1x128x4x128xf32>
62+
// CHECK: onnx.Return [[VAR_7_]] : tensor<1x128x4x128xf32>
63+
// CHECK: }
64+
65+
// -----
66+
67+
func.func @decomposition_to_layernorm_axis_1_and_2(%arg0: tensor<1x4x128x128xf32> {onnx.name = "in0"}) -> (tensor<1x4x128x128xf32> {onnx.name = "out"}) {
68+
%0 = onnx.Constant dense<[[[0.976699769]], [[0.380195737]], [[0.923246204]], [[0.261692435]]]> : tensor<4x1x1xf32>
69+
%1 = onnx.Constant dense<9.99999997E-7> : tensor<f32>
70+
%2 = onnx.Constant dense<[1, 2]> : tensor<2xi64>
71+
%3 = "onnx.ReduceMean"(%arg0, %2) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64, onnx_node_name = "ReduceMean1"} : (tensor<1x4x128x128xf32>, tensor<2xi64>) -> tensor<1x1x1x128xf32>
72+
%4 = "onnx.Sub"(%arg0, %3) {onnx_node_name = "Sub2"} : (tensor<1x4x128x128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x4x128x128xf32>
73+
%5 = "onnx.Mul"(%4, %4) {onnx_node_name = "Mul3"} : (tensor<1x4x128x128xf32>, tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32>
74+
%6 = "onnx.ReduceMean"(%5, %2) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64, onnx_node_name = "ReduceMean4"} : (tensor<1x4x128x128xf32>, tensor<2xi64>) -> tensor<1x1x1x128xf32>
75+
%7 = "onnx.Add"(%6, %1) {onnx_node_name = "Add6"} : (tensor<1x1x1x128xf32>, tensor<f32>) -> tensor<1x1x1x128xf32>
76+
%8 = "onnx.Sqrt"(%7) {onnx_node_name = "Sqrt7"} : (tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32>
77+
%9 = "onnx.Div"(%4, %8) {onnx_node_name = "Div8"} : (tensor<1x4x128x128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x4x128x128xf32>
78+
%10 = "onnx.Mul"(%0, %9) {onnx_node_name = "Mul10"} : (tensor<4x1x1xf32>, tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32>
79+
%11 = "onnx.Add"(%10, %0) {onnx_node_name = "Add12"} : (tensor<1x4x128x128xf32>, tensor<4x1x1xf32>) -> tensor<1x4x128x128xf32>
80+
onnx.Return %11 : tensor<1x4x128x128xf32>
81+
}
82+
// CHECK-LABEL: func.func @decomposition_to_layernorm_axis_1_and_2
83+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4x128x128xf32> {onnx.name = "in0"}) -> (tensor<1x4x128x128xf32> {onnx.name = "out"}) {
84+
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<{{.}}{{.}}[0.976699769]{{.}}, {{.}}[0.380195737]{{.}}, {{.}}[0.923246204]{{.}}, {{.}}[0.261692435]{{.}}{{.}}> : tensor<4x1x1xf32>
85+
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1, 4, 1, 1]> : tensor<4xi64>
86+
// CHECK: [[VAR_2_:%.+]] = "onnx.Reshape"([[VAR_0_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<4x1x1xf32>, tensor<4xi64>) -> tensor<1x4x1x1xf32>
87+
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Transpose"([[VAR_2_]]) {perm = [0, 3, 1, 2]} : (tensor<1x4x1x1xf32>) -> tensor<1x1x4x1xf32>
88+
// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [0, 3, 1, 2]} : (tensor<1x4x128x128xf32>) -> tensor<1x128x4x128xf32>
89+
// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Reshape"([[VAR_0_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<4x1x1xf32>, tensor<4xi64>) -> tensor<1x4x1x1xf32>
90+
// CHECK: [[VAR_6_:%.+]] = "onnx.Transpose"([[VAR_5_]]) {perm = [0, 3, 1, 2]} : (tensor<1x4x1x1xf32>) -> tensor<1x1x4x1xf32>
91+
// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_4_]], [[VAR_3_]], [[VAR_6_]]) {axis = 2 : si64, epsilon = 9.99999997E-7 : f32, stash_type = 1 : si64} : (tensor<1x128x4x128xf32>, tensor<1x1x4x1xf32>, tensor<1x1x4x1xf32>) -> (tensor<1x128x4x128xf32>, none, none)
92+
// CHECK: [[VAR_7_:%.+]] = "onnx.Transpose"([[VAR_Y_]]) {perm = [0, 2, 3, 1]} : (tensor<1x128x4x128xf32>) -> tensor<1x4x128x128xf32>
93+
// CHECK: onnx.Return [[VAR_7_]] : tensor<1x4x128x128xf32>
94+
// CHECK: }
95+
96+
// -----
97+
98+
func.func @decomposition_to_layernorm_axis_1_and_3(%arg0: tensor<1x4x128x128xf32> {onnx.name = "in0"}) -> (tensor<1x4x128x128xf32> {onnx.name = "out"}) {
99+
%0 = onnx.Constant dense<[[[0.976699769]], [[0.380195737]], [[0.923246204]], [[0.261692435]]]> : tensor<4x1x1xf32>
100+
%1 = onnx.Constant dense<9.99999997E-7> : tensor<f32>
101+
%2 = onnx.Constant dense<[1, 3]> : tensor<2xi64>
102+
%3 = "onnx.ReduceMean"(%arg0, %2) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64, onnx_node_name = "ReduceMean1"} : (tensor<1x4x128x128xf32>, tensor<2xi64>) -> tensor<1x1x128x1xf32>
103+
%4 = "onnx.Sub"(%arg0, %3) {onnx_node_name = "Sub2"} : (tensor<1x4x128x128xf32>, tensor<1x1x128x1xf32>) -> tensor<1x4x128x128xf32>
104+
%5 = "onnx.Mul"(%4, %4) {onnx_node_name = "Mul3"} : (tensor<1x4x128x128xf32>, tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32>
105+
%6 = "onnx.ReduceMean"(%5, %2) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64, onnx_node_name = "ReduceMean4"} : (tensor<1x4x128x128xf32>, tensor<2xi64>) -> tensor<1x1x128x1xf32>
106+
%7 = "onnx.Add"(%6, %1) {onnx_node_name = "Add6"} : (tensor<1x1x128x1xf32>, tensor<f32>) -> tensor<1x1x128x1xf32>
107+
%8 = "onnx.Sqrt"(%7) {onnx_node_name = "Sqrt7"} : (tensor<1x1x128x1xf32>) -> tensor<1x1x128x1xf32>
108+
%9 = "onnx.Div"(%4, %8) {onnx_node_name = "Div8"} : (tensor<1x4x128x128xf32>, tensor<1x1x128x1xf32>) -> tensor<1x4x128x128xf32>
109+
%10 = "onnx.Mul"(%0, %9) {onnx_node_name = "Mul10"} : (tensor<4x1x1xf32>, tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32>
110+
%11 = "onnx.Add"(%10, %0) {onnx_node_name = "Add12"} : (tensor<1x4x128x128xf32>, tensor<4x1x1xf32>) -> tensor<1x4x128x128xf32>
111+
onnx.Return %11 : tensor<1x4x128x128xf32>
112+
}
113+
// CHECK-LABEL: func.func @decomposition_to_layernorm_axis_1_and_3
114+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4x128x128xf32> {onnx.name = "in0"}) -> (tensor<1x4x128x128xf32> {onnx.name = "out"}) {
115+
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<{{.}}{{.}}[0.976699769]{{.}}, {{.}}[0.380195737]{{.}}, {{.}}[0.923246204]{{.}}, {{.}}[0.261692435]{{.}}{{.}}> : tensor<4x1x1xf32>
116+
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1, 4, 1, 1]> : tensor<4xi64>
117+
// CHECK: [[VAR_2_:%.+]] = "onnx.Reshape"([[VAR_0_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<4x1x1xf32>, tensor<4xi64>) -> tensor<1x4x1x1xf32>
118+
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Transpose"([[VAR_2_]]) {perm = [0, 2, 1, 3]} : (tensor<1x4x1x1xf32>) -> tensor<1x1x4x1xf32>
119+
// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [0, 2, 1, 3]} : (tensor<1x4x128x128xf32>) -> tensor<1x128x4x128xf32>
120+
// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Reshape"([[VAR_0_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<4x1x1xf32>, tensor<4xi64>) -> tensor<1x4x1x1xf32>
121+
// CHECK: [[VAR_6_:%.+]] = "onnx.Transpose"([[VAR_5_]]) {perm = [0, 2, 1, 3]} : (tensor<1x4x1x1xf32>) -> tensor<1x1x4x1xf32>
122+
// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_4_]], [[VAR_3_]], [[VAR_6_]]) {axis = 2 : si64, epsilon = 9.99999997E-7 : f32, stash_type = 1 : si64} : (tensor<1x128x4x128xf32>, tensor<1x1x4x1xf32>, tensor<1x1x4x1xf32>) -> (tensor<1x128x4x128xf32>, none, none)
123+
// CHECK: [[VAR_7_:%.+]] = "onnx.Transpose"([[VAR_Y_]]) {perm = [0, 2, 1, 3]} : (tensor<1x128x4x128xf32>) -> tensor<1x4x128x128xf32>
124+
// CHECK: onnx.Return [[VAR_7_]] : tensor<1x4x128x128xf32>
125+
// CHECK: }

0 commit comments

Comments
 (0)