Skip to content

Commit a364c43

Browse files
committed
Add test for folding an add into RMSNorm
Signed-off-by: Rickert, Jonas <[email protected]>
1 parent 4867e28 commit a364c43

File tree

1 file changed

+15
-1
lines changed

test/mlir/onnx/onnx_canonicalization.mlir

Lines changed: 15 additions & 1 deletion

Original file line number | Diff line number | Diff line change
@@ -2383,12 +2383,26 @@ func.func @layernorm_without_bias(%arg0: tensor<1x384x768xf32>, %arg1: tensor<76
23832383
%NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %arg1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
23842384
%Y = "onnx.Add"(%bias, %NormScaled) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
23852385
return %Y : tensor<1x384x768xf32>
2386-
// mlir2FileCheck.py
23872386
// CHECK-LABEL: func.func @layernorm_without_bias
23882387
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> {
23892388
// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none)
23902389
// CHECK: return [[Y_]] : tensor<1x384x768xf32>
23912390
// CHECK: }
23922391
}
23932392

2393+
// -----
2394+
2395+
// Recognize that the onnx.Add of a 1-D bias to the output of a bias-less
// onnx.RMSLayerNormalization can be folded into the normalization's third
// (bias) operand, eliminating the Add.
func.func @rmslayernorm_without_bias(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %bias: tensor<768xf32>) -> tensor<1x384x768xf32> {
  // No bias is supplied to the normalization itself; it is added afterwards.
  %0 = "onnx.NoValue"() {value} : () -> none
  %NormScaled, %InvStdDev = "onnx.RMSLayerNormalization"(%arg0, %arg1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none)
  %Y = "onnx.Add"(%bias, %NormScaled) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
  return %Y : tensor<1x384x768xf32>
// After canonicalization the bias parameter feeds the op directly and the
// standalone onnx.Add is gone.
// CHECK-LABEL:  func.func @rmslayernorm_without_bias
// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> {
// CHECK:           [[Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none)
// CHECK:           return [[Y_]] : tensor<1x384x768xf32>
// CHECK:         }
}
2407+
23942408

0 commit comments

Comments
 (0)