@@ -2405,4 +2405,92 @@ func.func @rmslayernorm_without_bias(%arg0: tensor<1x384x768xf32>, %arg1: tensor
 // CHECK: }
 }
 
+// -----
+
+// Recognize that the scale is the neutral constant 1.0 (with no bias), and fold the Mul into the LayerNorm scale.
+func.func @layernorm_with_neutral_scale(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %mulVal: tensor<768xf32>) -> tensor<1x384x768xf32> {
+  %0 = "onnx.NoValue"() {value} : () -> none
+  %1 = onnx.Constant dense<1.000000e+00> : tensor<768xf32>
+  %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
+  %Y = "onnx.Mul"(%mulVal, %NormScaled) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
+  return %Y : tensor<1x384x768xf32>
+// CHECK-LABEL: func.func @layernorm_with_neutral_scale
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_2_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
+// CHECK: return [[VAR_Y_]] : tensor<1x384x768xf32>
+// CHECK: }
+}
+
+// -----
+
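+// The constant scale is 1.1, not the neutral 1.0, so the Mul cannot be folded and is kept.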
+func.func @layernorm_scale_not_one(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %mulVal: tensor<768xf32>) -> tensor<1x384x768xf32> {
+  %0 = "onnx.NoValue"() {value} : () -> none
+  %1 = onnx.Constant dense<1.100000e+00> : tensor<768xf32>
+  %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
+  %Y = "onnx.Mul"(%mulVal, %NormScaled) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
+  return %Y : tensor<1x384x768xf32>
+// CHECK-LABEL: func.func @layernorm_scale_not_one
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1.100000e+00> : tensor<768xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[VAR_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
+// CHECK: [[VAR_2_:%.+]] = "onnx.Mul"([[PARAM_2_]], [[VAR_Y_]]) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
+// CHECK: return [[VAR_2_]] : tensor<1x384x768xf32>
+// CHECK: }
+}
+
+// -----
+
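+// The bias is a non-zero constant (2.0); folding the Mul into the scale would also scale the bias and change the result, so the Mul is kept.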
+func.func @layernorm_bias_not_zero(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %mulVal: tensor<768xf32>) -> tensor<1x384x768xf32> {
+  %0 = onnx.Constant dense<2.000000e+00> : tensor<768xf32>
+  %1 = onnx.Constant dense<1.000000e+00> : tensor<768xf32>
+  %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none)
+  %Y = "onnx.Mul"(%mulVal, %NormScaled) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
+  return %Y : tensor<1x384x768xf32>
+// CHECK-LABEL: func.func @layernorm_bias_not_zero
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<2.000000e+00> : tensor<768xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<768xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[VAR_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none)
+// CHECK: [[VAR_2_:%.+]] = "onnx.Mul"([[PARAM_2_]], [[VAR_Y_]]) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
+// CHECK: return [[VAR_2_]] : tensor<1x384x768xf32>
+// CHECK: }
+}
+
+// -----
+
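+// The Mul broadcasts to 10x384x768, a larger shape than the 1x384x768 LayerNorm output, so the multiplier cannot replace the tensor<768xf32> scale and the Mul is kept.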
+func.func @layernorm_broadcast(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %mulVal: tensor<10x384x768xf32>) -> tensor<10x384x768xf32> {
+  %0 = "onnx.NoValue"() {value} : () -> none
+  %1 = onnx.Constant dense<1.000000e+00> : tensor<768xf32>
+  %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
+  %Y = "onnx.Mul"(%mulVal, %NormScaled) : (tensor<10x384x768xf32>, tensor<1x384x768xf32>) -> tensor<10x384x768xf32>
+  return %Y : tensor<10x384x768xf32>
+// CHECK-LABEL: func.func @layernorm_broadcast
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<10x384x768xf32>) -> tensor<10x384x768xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<768xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[VAR_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
+// CHECK: [[VAR_2_:%.+]] = "onnx.Mul"([[PARAM_2_]], [[VAR_Y_]]) : (tensor<10x384x768xf32>, tensor<1x384x768xf32>) -> tensor<10x384x768xf32>
+// CHECK: return [[VAR_2_]] : tensor<10x384x768xf32>
+// CHECK: }
+}
+
+// -----
+
+// Recognize that the scale is the neutral constant 1.0 (with no bias), and fold the Mul into the RMSLayerNormalization scale.
+func.func @rmslayernorm_with_neutral_scale(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %mulVal: tensor<768xf32>) -> tensor<1x384x768xf32> {
+  %0 = "onnx.NoValue"() {value} : () -> none
+  %1 = onnx.Constant dense<1.000000e+00> : tensor<768xf32>
+  %NormScaled, %InvStdDev = "onnx.RMSLayerNormalization"(%arg0, %1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none)
+  %Y = "onnx.Mul"(%mulVal, %NormScaled) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
+  return %Y : tensor<1x384x768xf32>
+// CHECK-LABEL: func.func @rmslayernorm_with_neutral_scale
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[VAR_Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[PARAM_0_]], [[PARAM_2_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none)
+// CHECK: return [[VAR_Y_]] : tensor<1x384x768xf32>
+// CHECK: }
+}
 