-// RUN: onnx-mlir-opt --recompose-onnx="recompose-layernorm-by-transpose" --canonicalize %s -split-input-file | FileCheck %s
+// RUN: onnx-mlir-opt -split-input-file --recompose-onnx="recompose-layernorm-by-transpose" --canonicalize %s | FileCheck %s
 
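+// Normalization over axis 1 of a 1x4x128x128 input, recomposed into onnx.LayerNormalization.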
-func.func @main_graph(%arg0: tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32> {
+func.func @decomposition_to_layernorm_axis_1(%arg0: tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32> {
   %2 = onnx.Constant dense<1.000000e+00> : tensor<f32>
   %4 = onnx.Constant dense<9.99999997E-7> : tensor<f32>
   %5 = onnx.Constant dense<[[[0.0970484465]], [[0.0882187337]], [[0.135120019]], [[0.14907673]]]> : tensor<4x1x1xf32>
@@ -16,7 +16,7 @@ func.func @main_graph(%arg0: tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32>
   %17 = "onnx.Add"(%16, %6) {onnx_node_name = "/mask_downscaling/mask_downscaling.1/Add_1"} : (tensor<1x4x128x128xf32>, tensor<4x1x1xf32>) -> tensor<1x4x128x128xf32>
   return %17 : tensor<1x4x128x128xf32>
 }
-// CHECK-LABEL: func.func @main_graph
+// CHECK-LABEL: func.func @decomposition_to_layernorm_axis_1
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32> {
 // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<{{.}}{{.}}[0.0970484465]{{.}}, {{.}}[0.0882187337]{{.}}, {{.}}[0.135120019]{{.}}, {{.}}[0.14907673]{{.}}{{.}}> : tensor<4x1x1xf32>
 // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<{{.}}{{.}}[0.191972837]{{.}}, {{.}}[0.286264896]{{.}}, {{.}}[0.180535644]{{.}}, {{.}}[0.166878015]{{.}}{{.}}> : tensor<4x1x1xf32>
@@ -33,4 +33,93 @@ func.func @main_graph(%arg0: tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32>
 
 // -----
 
-// TODO: ADD more lit tests here
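+// The reduction axis (2) is not the innermost axis: the recomposition transposes
+// dim 2 to the last position (perm [0, 1, 3, 2]), applies onnx.LayerNormalization
+// with axis = 3, and transposes the result back.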
+func.func @decomposition_to_layernorm_axis_2(%arg0: tensor<1x128x4x128xf32> {onnx.name = "in0"}) -> (tensor<1x128x4x128xf32> {onnx.name = "out"}) {
+  %0 = onnx.Constant dense<[[[0.976699769], [0.380195737], [0.923246204], [0.261692435]]]> : tensor<1x4x1xf32>
+  %1 = onnx.Constant dense<9.99999997E-7> : tensor<f32>
+  %2 = onnx.Constant dense<2> : tensor<1xi64>
+  %3 = "onnx.ReduceMean"(%arg0, %2) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64, onnx_node_name = "ReduceMean1"} : (tensor<1x128x4x128xf32>, tensor<1xi64>) -> tensor<1x128x1x128xf32>
+  %4 = "onnx.Sub"(%arg0, %3) {onnx_node_name = "Sub2"} : (tensor<1x128x4x128xf32>, tensor<1x128x1x128xf32>) -> tensor<1x128x4x128xf32>
+  %5 = "onnx.Mul"(%4, %4) {onnx_node_name = "Mul3"} : (tensor<1x128x4x128xf32>, tensor<1x128x4x128xf32>) -> tensor<1x128x4x128xf32>
+  %6 = "onnx.ReduceMean"(%5, %2) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64, onnx_node_name = "ReduceMean4"} : (tensor<1x128x4x128xf32>, tensor<1xi64>) -> tensor<1x128x1x128xf32>
+  %7 = "onnx.Add"(%6, %1) {onnx_node_name = "Add6"} : (tensor<1x128x1x128xf32>, tensor<f32>) -> tensor<1x128x1x128xf32>
+  %8 = "onnx.Sqrt"(%7) {onnx_node_name = "Sqrt7"} : (tensor<1x128x1x128xf32>) -> tensor<1x128x1x128xf32>
+  %9 = "onnx.Div"(%4, %8) {onnx_node_name = "Div8"} : (tensor<1x128x4x128xf32>, tensor<1x128x1x128xf32>) -> tensor<1x128x4x128xf32>
+  %10 = "onnx.Mul"(%0, %9) {onnx_node_name = "Mul10"} : (tensor<1x4x1xf32>, tensor<1x128x4x128xf32>) -> tensor<1x128x4x128xf32>
+  %11 = "onnx.Add"(%10, %0) {onnx_node_name = "Add12"} : (tensor<1x128x4x128xf32>, tensor<1x4x1xf32>) -> tensor<1x128x4x128xf32>
+  onnx.Return %11 : tensor<1x128x4x128xf32>
+}
+// CHECK-LABEL: func.func @decomposition_to_layernorm_axis_2
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x128x4x128xf32> {onnx.name = "in0"}) -> (tensor<1x128x4x128xf32> {onnx.name = "out"}) {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<{{.}}{{.}}[0.976699769], [0.380195737], [0.923246204], [0.261692435]{{.}}{{.}}> : tensor<1x4x1xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1, 1, 4, 1]> : tensor<4xi64>
+// CHECK: [[VAR_2_:%.+]] = "onnx.Reshape"([[VAR_0_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<1x4x1xf32>, tensor<4xi64>) -> tensor<1x1x4x1xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Transpose"([[VAR_2_]]) {perm = [0, 1, 3, 2]} : (tensor<1x1x4x1xf32>) -> tensor<1x1x1x4xf32>
+// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [0, 1, 3, 2]} : (tensor<1x128x4x128xf32>) -> tensor<1x128x128x4xf32>
+// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Reshape"([[VAR_0_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<1x4x1xf32>, tensor<4xi64>) -> tensor<1x1x4x1xf32>
+// CHECK: [[VAR_6_:%.+]] = "onnx.Transpose"([[VAR_5_]]) {perm = [0, 1, 3, 2]} : (tensor<1x1x4x1xf32>) -> tensor<1x1x1x4xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_4_]], [[VAR_3_]], [[VAR_6_]]) {axis = 3 : si64, epsilon = 9.99999997E-7 : f32, stash_type = 1 : si64} : (tensor<1x128x128x4xf32>, tensor<1x1x1x4xf32>, tensor<1x1x1x4xf32>) -> (tensor<1x128x128x4xf32>, none, none)
+// CHECK: [[VAR_7_:%.+]] = "onnx.Transpose"([[VAR_Y_]]) {perm = [0, 1, 3, 2]} : (tensor<1x128x128x4xf32>) -> tensor<1x128x4x128xf32>
+// CHECK: onnx.Return [[VAR_7_]] : tensor<1x128x4x128xf32>
+// CHECK: }
+
+// -----
+
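+// The reduction covers axes 1 and 2: transposing with perm [0, 3, 1, 2] makes them
+// the two innermost dims, so a single onnx.LayerNormalization with axis = 2 applies,
+// followed by the inverse transpose (perm [0, 2, 3, 1]).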
+func.func @decomposition_to_layernorm_axis_1_and_2(%arg0: tensor<1x4x128x128xf32> {onnx.name = "in0"}) -> (tensor<1x4x128x128xf32> {onnx.name = "out"}) {
+  %0 = onnx.Constant dense<[[[0.976699769]], [[0.380195737]], [[0.923246204]], [[0.261692435]]]> : tensor<4x1x1xf32>
+  %1 = onnx.Constant dense<9.99999997E-7> : tensor<f32>
+  %2 = onnx.Constant dense<[1, 2]> : tensor<2xi64>
+  %3 = "onnx.ReduceMean"(%arg0, %2) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64, onnx_node_name = "ReduceMean1"} : (tensor<1x4x128x128xf32>, tensor<2xi64>) -> tensor<1x1x1x128xf32>
+  %4 = "onnx.Sub"(%arg0, %3) {onnx_node_name = "Sub2"} : (tensor<1x4x128x128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x4x128x128xf32>
+  %5 = "onnx.Mul"(%4, %4) {onnx_node_name = "Mul3"} : (tensor<1x4x128x128xf32>, tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32>
+  %6 = "onnx.ReduceMean"(%5, %2) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64, onnx_node_name = "ReduceMean4"} : (tensor<1x4x128x128xf32>, tensor<2xi64>) -> tensor<1x1x1x128xf32>
+  %7 = "onnx.Add"(%6, %1) {onnx_node_name = "Add6"} : (tensor<1x1x1x128xf32>, tensor<f32>) -> tensor<1x1x1x128xf32>
+  %8 = "onnx.Sqrt"(%7) {onnx_node_name = "Sqrt7"} : (tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32>
+  %9 = "onnx.Div"(%4, %8) {onnx_node_name = "Div8"} : (tensor<1x4x128x128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x4x128x128xf32>
+  %10 = "onnx.Mul"(%0, %9) {onnx_node_name = "Mul10"} : (tensor<4x1x1xf32>, tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32>
+  %11 = "onnx.Add"(%10, %0) {onnx_node_name = "Add12"} : (tensor<1x4x128x128xf32>, tensor<4x1x1xf32>) -> tensor<1x4x128x128xf32>
+  onnx.Return %11 : tensor<1x4x128x128xf32>
+}
+// CHECK-LABEL: func.func @decomposition_to_layernorm_axis_1_and_2
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4x128x128xf32> {onnx.name = "in0"}) -> (tensor<1x4x128x128xf32> {onnx.name = "out"}) {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<{{.}}{{.}}[0.976699769]{{.}}, {{.}}[0.380195737]{{.}}, {{.}}[0.923246204]{{.}}, {{.}}[0.261692435]{{.}}{{.}}> : tensor<4x1x1xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1, 4, 1, 1]> : tensor<4xi64>
+// CHECK: [[VAR_2_:%.+]] = "onnx.Reshape"([[VAR_0_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<4x1x1xf32>, tensor<4xi64>) -> tensor<1x4x1x1xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Transpose"([[VAR_2_]]) {perm = [0, 3, 1, 2]} : (tensor<1x4x1x1xf32>) -> tensor<1x1x4x1xf32>
+// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [0, 3, 1, 2]} : (tensor<1x4x128x128xf32>) -> tensor<1x128x4x128xf32>
+// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Reshape"([[VAR_0_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<4x1x1xf32>, tensor<4xi64>) -> tensor<1x4x1x1xf32>
+// CHECK: [[VAR_6_:%.+]] = "onnx.Transpose"([[VAR_5_]]) {perm = [0, 3, 1, 2]} : (tensor<1x4x1x1xf32>) -> tensor<1x1x4x1xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_4_]], [[VAR_3_]], [[VAR_6_]]) {axis = 2 : si64, epsilon = 9.99999997E-7 : f32, stash_type = 1 : si64} : (tensor<1x128x4x128xf32>, tensor<1x1x4x1xf32>, tensor<1x1x4x1xf32>) -> (tensor<1x128x4x128xf32>, none, none)
+// CHECK: [[VAR_7_:%.+]] = "onnx.Transpose"([[VAR_Y_]]) {perm = [0, 2, 3, 1]} : (tensor<1x128x4x128xf32>) -> tensor<1x4x128x128xf32>
+// CHECK: onnx.Return [[VAR_7_]] : tensor<1x4x128x128xf32>
+// CHECK: }
+
+// -----
+
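+// The reduction covers the non-adjacent axes 1 and 3: transposing with perm
+// [0, 2, 1, 3] brings them together as the two innermost dims for a single
+// onnx.LayerNormalization with axis = 2; the same (self-inverse) permutation
+// restores the original layout afterwards.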
+func.func @decomposition_to_layernorm_axis_1_and_3(%arg0: tensor<1x4x128x128xf32> {onnx.name = "in0"}) -> (tensor<1x4x128x128xf32> {onnx.name = "out"}) {
+  %0 = onnx.Constant dense<[[[0.976699769]], [[0.380195737]], [[0.923246204]], [[0.261692435]]]> : tensor<4x1x1xf32>
+  %1 = onnx.Constant dense<9.99999997E-7> : tensor<f32>
+  %2 = onnx.Constant dense<[1, 3]> : tensor<2xi64>
+  %3 = "onnx.ReduceMean"(%arg0, %2) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64, onnx_node_name = "ReduceMean1"} : (tensor<1x4x128x128xf32>, tensor<2xi64>) -> tensor<1x1x128x1xf32>
+  %4 = "onnx.Sub"(%arg0, %3) {onnx_node_name = "Sub2"} : (tensor<1x4x128x128xf32>, tensor<1x1x128x1xf32>) -> tensor<1x4x128x128xf32>
+  %5 = "onnx.Mul"(%4, %4) {onnx_node_name = "Mul3"} : (tensor<1x4x128x128xf32>, tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32>
+  %6 = "onnx.ReduceMean"(%5, %2) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64, onnx_node_name = "ReduceMean4"} : (tensor<1x4x128x128xf32>, tensor<2xi64>) -> tensor<1x1x128x1xf32>
+  %7 = "onnx.Add"(%6, %1) {onnx_node_name = "Add6"} : (tensor<1x1x128x1xf32>, tensor<f32>) -> tensor<1x1x128x1xf32>
+  %8 = "onnx.Sqrt"(%7) {onnx_node_name = "Sqrt7"} : (tensor<1x1x128x1xf32>) -> tensor<1x1x128x1xf32>
+  %9 = "onnx.Div"(%4, %8) {onnx_node_name = "Div8"} : (tensor<1x4x128x128xf32>, tensor<1x1x128x1xf32>) -> tensor<1x4x128x128xf32>
+  %10 = "onnx.Mul"(%0, %9) {onnx_node_name = "Mul10"} : (tensor<4x1x1xf32>, tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32>
+  %11 = "onnx.Add"(%10, %0) {onnx_node_name = "Add12"} : (tensor<1x4x128x128xf32>, tensor<4x1x1xf32>) -> tensor<1x4x128x128xf32>
+  onnx.Return %11 : tensor<1x4x128x128xf32>
+}
+// CHECK-LABEL: func.func @decomposition_to_layernorm_axis_1_and_3
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4x128x128xf32> {onnx.name = "in0"}) -> (tensor<1x4x128x128xf32> {onnx.name = "out"}) {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<{{.}}{{.}}[0.976699769]{{.}}, {{.}}[0.380195737]{{.}}, {{.}}[0.923246204]{{.}}, {{.}}[0.261692435]{{.}}{{.}}> : tensor<4x1x1xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1, 4, 1, 1]> : tensor<4xi64>
+// CHECK: [[VAR_2_:%.+]] = "onnx.Reshape"([[VAR_0_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<4x1x1xf32>, tensor<4xi64>) -> tensor<1x4x1x1xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Transpose"([[VAR_2_]]) {perm = [0, 2, 1, 3]} : (tensor<1x4x1x1xf32>) -> tensor<1x1x4x1xf32>
+// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [0, 2, 1, 3]} : (tensor<1x4x128x128xf32>) -> tensor<1x128x4x128xf32>
+// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Reshape"([[VAR_0_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<4x1x1xf32>, tensor<4xi64>) -> tensor<1x4x1x1xf32>
+// CHECK: [[VAR_6_:%.+]] = "onnx.Transpose"([[VAR_5_]]) {perm = [0, 2, 1, 3]} : (tensor<1x4x1x1xf32>) -> tensor<1x1x4x1xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_4_]], [[VAR_3_]], [[VAR_6_]]) {axis = 2 : si64, epsilon = 9.99999997E-7 : f32, stash_type = 1 : si64} : (tensor<1x128x4x128xf32>, tensor<1x1x4x1xf32>, tensor<1x1x4x1xf32>) -> (tensor<1x128x4x128xf32>, none, none)
+// CHECK: [[VAR_7_:%.+]] = "onnx.Transpose"([[VAR_Y_]]) {perm = [0, 2, 1, 3]} : (tensor<1x128x4x128xf32>) -> tensor<1x4x128x128xf32>
+// CHECK: onnx.Return [[VAR_7_]] : tensor<1x4x128x128xf32>
+// CHECK: }