@@ -511,6 +511,30 @@ func.func @rms_layer_norm_v3_dyn_shape(%x: tensor<1x?x768xf32>) -> (tensor<1x?x7

// -----

// RMS Layer norm with multiple uses of the scale multiplication

516+ func.func @rms_layer_norm_multi_use (%x: tensor <1 x384 x768 xf32 >, %scale: tensor <768 xf32 >, %bias: tensor <768 xf32 >) -> (tensor <1 x384 x768 xf32 >) {
517+ %eps = onnx.Constant dense <1.2E+0 > : tensor <f32 >
518+ %dd = " onnx.Mul" (%x , %x ) : (tensor <1 x384 x768 xf32 >, tensor <1 x384 x768 xf32 >) -> tensor <1 x384 x768 xf32 >
519+ %var = " onnx.ReduceMeanV13" (%dd ) {axes = [-1 ], keepdims = 1 : si64 } : (tensor <1 x384 x768 xf32 >) -> tensor <1 x384 x1 xf32 >
520+ %varEps = " onnx.Add" (%eps , %var ) : (tensor <f32 >, tensor <1 x384 x1 xf32 >) -> tensor <1 x384 x1 xf32 >
521+ %StdDev = " onnx.Sqrt" (%varEps ) : (tensor <1 x384 x1 xf32 >) -> tensor <1 x384 x1 xf32 >
522+ %Norm = " onnx.Div" (%x , %StdDev ) : (tensor <1 x384 x768 xf32 >, tensor <1 x384 x1 xf32 >) -> tensor <1 x384 x768 xf32 >
523+ %NormScaled = " onnx.Mul" (%scale , %Norm ) : (tensor <768 xf32 >, tensor <1 x384 x768 xf32 >) -> tensor <1 x384 x768 xf32 >
524+ %MultiUse = " onnx.Add" (%NormScaled , %NormScaled ) : (tensor <1 x384 x768 xf32 >, tensor <1 x384 x768 xf32 >) -> tensor <1 x384 x768 xf32 >
525+ return %MultiUse : tensor <1 x384 x768 xf32 >
526+ // mlir2FileCheck.py
527+ // CHECK-LABEL: func.func @rms_layer_norm_multi_use
528+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> {
529+ // CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
530+ // CHECK: [[VAR_Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none)
531+ // CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_Y_]], [[VAR_Y_]]) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
532+ // CHECK: return [[VAR_1_]] : tensor<1x384x768xf32>
533+ // CHECK: }
534+ }

// -----

// COM: QLinearMatMul
515539func.func @qlinear_matmul (%arg0: tensor <?x?x768 xi8 >, %arg1: tensor <f32 >, %arg2: tensor <i8 >, %arg3: tensor <768 x768 xi8 >, %arg4: tensor <f32 >, %arg5: tensor <i8 >, %arg6: tensor <f32 >, %arg7: tensor <i8 >) -> (tensor <?x?x768 xi8 >) {
516540 %0 = " onnx.DequantizeLinear" (%arg0 , %arg1 , %arg2 ) {axis = 1 : si64 } : (tensor <?x?x768 xi8 >, tensor <f32 >, tensor <i8 >) -> tensor <?x?x768 xf32 >
0 commit comments