@@ -2383,12 +2383,26 @@ func.func @layernorm_without_bias(%arg0: tensor<1x384x768xf32>, %arg1: tensor<76
23832383 %NormScaled , %Mean , %InvStdDev = " onnx.LayerNormalization" (%arg0 , %arg1 , %0 ) {axis = 2 : si64 , epsilon = 1.200000e+00 : f32 , stash_type = 1 : si64 } : (tensor <1 x384 x768 xf32 >, tensor <768 xf32 >, none ) -> (tensor <1 x384 x768 xf32 >, none , none )
23842384 %Y = " onnx.Add" (%bias , %NormScaled ) : (tensor <768 xf32 >, tensor <1 x384 x768 xf32 >) -> tensor <1 x384 x768 xf32 >
23852385 return %Y : tensor <1 x384 x768 xf32 >
2386- // mlir2FileCheck.py
23872386// CHECK-LABEL: func.func @layernorm_without_bias
23882387// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> {
23892388// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none)
23902389// CHECK: return [[Y_]] : tensor<1x384x768xf32>
23912390// CHECK: }
23922391}
23932392
// -----

// Recognize the bias and fold into RMSLayerNorm.
func.func @rmslayernorm_without_bias(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %bias: tensor<768xf32>) -> tensor<1x384x768xf32> {
  // The norm op itself is built with no bias (none); the bias is applied by a
  // separate Add afterwards. The pattern under test should recognize this and
  // fold %bias into the RMSLayerNormalization's third (B) operand.
  %0 = "onnx.NoValue"() {value} : () -> none
  %NormScaled, %InvStdDev = "onnx.RMSLayerNormalization"(%arg0, %arg1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none)
  %Y = "onnx.Add"(%bias, %NormScaled) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
  return %Y : tensor<1x384x768xf32>
// CHECK-LABEL:  func.func @rmslayernorm_without_bias
// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> {
// CHECK:           [[Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none)
// CHECK:           return [[Y_]] : tensor<1x384x768xf32>
// CHECK:         }
}
2407+
23942408