
Commit 2a89e14

feat: add single lit test for recompose-onnx pass
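
The test feeds a graph that hand-computes a channel-wise layer normalization over axis 1 of a 1x4x128x128 input:

    y = (x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps) * scale + bias,   with the means taken over axis 1 and eps ≈ 1e-6,

and the CHECK lines expect recompose-onnx (with recompose-layernorm-by-transpose enabled, followed by --canonicalize) to rewrite this into a single onnx.LayerNormalization on the innermost axis: the input is transposed with perm = [0, 2, 3, 1], the scale and bias are reshaped to 1x4x1x1 and transposed the same way, the normalization runs with axis = 3, and the result is transposed back with perm = [0, 3, 1, 2].
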
1 parent 2656a25 commit 2a89e14

File tree

1 file changed
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
// RUN: onnx-mlir-opt --recompose-onnx="recompose-layernorm-by-transpose" --canonicalize %s -split-input-file | FileCheck %s

func.func @main_graph(%arg0: tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32> {
  %2 = onnx.Constant dense<1.000000e+00> : tensor<f32>
  %4 = onnx.Constant dense<9.99999997E-7> : tensor<f32>
  %5 = onnx.Constant dense<[[[0.0970484465]], [[0.0882187337]], [[0.135120019]], [[0.14907673]]]> : tensor<4x1x1xf32>
  %6 = onnx.Constant dense<[[[0.191972837]], [[0.286264896]], [[0.180535644]], [[0.166878015]]]> : tensor<4x1x1xf32>
  %9 = "onnx.ReduceMeanV13"(%arg0) {axes = [1], keepdims = 1 : si64, onnx_node_name = "/mask_downscaling/mask_downscaling.1/ReduceMean"} : (tensor<1x4x128x128xf32>) -> tensor<1x1x128x128xf32>
  %10 = "onnx.Sub"(%arg0, %9) {onnx_node_name = "/mask_downscaling/mask_downscaling.1/Sub"} : (tensor<1x4x128x128xf32>, tensor<1x1x128x128xf32>) -> tensor<1x4x128x128xf32>
  %11 = "onnx.Mul"(%10, %10) {onnx_node_name = "/mask_downscaling/mask_downscaling.1/Pow_1"} : (tensor<1x4x128x128xf32>, tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32>
  %12 = "onnx.ReduceMeanV13"(%11) {axes = [1], keepdims = 1 : si64, onnx_node_name = "/mask_downscaling/mask_downscaling.1/ReduceMean_1"} : (tensor<1x4x128x128xf32>) -> tensor<1x1x128x128xf32>
  %13 = "onnx.Add"(%12, %4) {onnx_node_name = "/mask_downscaling/mask_downscaling.1/Add"} : (tensor<1x1x128x128xf32>, tensor<f32>) -> tensor<1x1x128x128xf32>
  %14 = "onnx.Sqrt"(%13) {onnx_node_name = "/mask_downscaling/mask_downscaling.1/Sqrt"} : (tensor<1x1x128x128xf32>) -> tensor<1x1x128x128xf32>
  %15 = "onnx.Div"(%10, %14) {onnx_node_name = "/mask_downscaling/mask_downscaling.1/Div"} : (tensor<1x4x128x128xf32>, tensor<1x1x128x128xf32>) -> tensor<1x4x128x128xf32>
  %16 = "onnx.Mul"(%15, %5) {onnx_node_name = "/mask_downscaling/mask_downscaling.1/Mul_2"} : (tensor<1x4x128x128xf32>, tensor<4x1x1xf32>) -> tensor<1x4x128x128xf32>
  %17 = "onnx.Add"(%16, %6) {onnx_node_name = "/mask_downscaling/mask_downscaling.1/Add_1"} : (tensor<1x4x128x128xf32>, tensor<4x1x1xf32>) -> tensor<1x4x128x128xf32>
  return %17 : tensor<1x4x128x128xf32>
}
// CHECK-LABEL: func.func @main_graph
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4x128x128xf32>) -> tensor<1x4x128x128xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<{{.}}{{.}}[0.0970484465]{{.}}, {{.}}[0.0882187337]{{.}}, {{.}}[0.135120019]{{.}}, {{.}}[0.14907673]{{.}}{{.}}> : tensor<4x1x1xf32>
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<{{.}}{{.}}[0.191972837]{{.}}, {{.}}[0.286264896]{{.}}, {{.}}[0.180535644]{{.}}, {{.}}[0.166878015]{{.}}{{.}}> : tensor<4x1x1xf32>
// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[1, 4, 1, 1]> : tensor<4xi64>
// CHECK: [[VAR_3_:%.+]] = "onnx.Reshape"([[VAR_0_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<4x1x1xf32>, tensor<4xi64>) -> tensor<1x4x1x1xf32>
// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Transpose"([[VAR_3_]]) {perm = [0, 2, 3, 1]} : (tensor<1x4x1x1xf32>) -> tensor<1x1x1x4xf32>
// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [0, 2, 3, 1]} : (tensor<1x4x128x128xf32>) -> tensor<1x128x128x4xf32>
// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Reshape"([[VAR_1_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<4x1x1xf32>, tensor<4xi64>) -> tensor<1x4x1x1xf32>
// CHECK: [[VAR_7_:%.+]] = "onnx.Transpose"([[VAR_6_]]) {perm = [0, 2, 3, 1]} : (tensor<1x4x1x1xf32>) -> tensor<1x1x1x4xf32>
// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_5_]], [[VAR_4_]], [[VAR_7_]]) {axis = 3 : si64, epsilon = 9.99999997E-7 : f32, stash_type = 1 : si64} : (tensor<1x128x128x4xf32>, tensor<1x1x1x4xf32>, tensor<1x1x1x4xf32>) -> (tensor<1x128x128x4xf32>, none, none)
// CHECK: [[VAR_8_:%.+]] = "onnx.Transpose"([[VAR_Y_]]) {perm = [0, 3, 1, 2]} : (tensor<1x128x128x4xf32>) -> tensor<1x4x128x128xf32>
// CHECK: return [[VAR_8_]] : tensor<1x4x128x128xf32>
// CHECK: }

// -----

// TODO: ADD more lit tests here
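
To run the check by hand outside of lit, expand %s in the RUN line to the test file. The file name below is a hypothetical placeholder, since the commit page does not show the path the test was added under:

    # Hypothetical test file name; substitute the actual path of the new test.
    onnx-mlir-opt --recompose-onnx="recompose-layernorm-by-transpose" --canonicalize \
        recompose_layernorm_by_transpose.mlir -split-input-file \
      | FileCheck recompose_layernorm_by_transpose.mlir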
