@@ -514,4 +514,96 @@ func.func @simplified_layernorm_two_outputs_mean_used(%input: tensor<2x4x8xf32>,
 // CHECK: onnx.Return [[VAR_0_]]#0, [[VAR_0_]]#1 : tensor<2x4x8xf32>, tensor<2x4x1xf32>
 }
 
+// -----
+// SkipSimplifiedLayerNormalization: 3 inputs, 1 output
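+// With only the Y output requested, input and skip are combined with onnx.Add and the
+// op lowers directly to onnx.RMSLayerNormalization; the unrequested InvStdDev result
+// is typed 'none'.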
+
+func.func @skip_simplified_layernorm_basic(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %skip, %gamma) {domain_name = "com.microsoft", function_name = "SkipSimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+  onnx.Return %r : tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @skip_simplified_layernorm_basic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[VAR_1_]], [[PARAM_2_]], [[VAR_0_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, none) -> (tensor<2x4x8xf32>, none)
+// CHECK: onnx.Return [[VAR_Y_]] : tensor<2x4x8xf32>
+}
+
+
+// -----
+// SkipSimplifiedLayerNormalization: 4 inputs (bias), 1 output
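+// Same as above, with the optional bias folded in by a second onnx.Add before the
+// normalization.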
+
+func.func @skip_simplified_layernorm_bias(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %bias: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %skip, %gamma, %bias) {domain_name = "com.microsoft", function_name = "SkipSimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+  onnx.Return %r : tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @skip_simplified_layernorm_bias
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_2_:%.+]] = "onnx.Add"([[VAR_1_]], [[PARAM_3_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[VAR_2_]], [[PARAM_2_]], [[VAR_0_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, none) -> (tensor<2x4x8xf32>, none)
+// CHECK: onnx.Return [[VAR_Y_]] : tensor<2x4x8xf32>
+}
+
+
+
+// -----
+// SkipSimplifiedLayerNormalization: 4 inputs, 2 outputs (output, mean)
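+// When the mean output is live, the op is not rewritten to onnx.RMSLayerNormalization
+// (which yields only Y and InvStdDev, not the mean); the Adds are peeled off and a
+// SimplifiedLayerNormalization Custom op is left behind.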
+
+func.func @skip_simplified_layernorm_two_outputs(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %bias: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>) {
+  %r0, %r1 = "onnx.Custom"(%input, %skip, %gamma, %bias) {domain_name = "com.microsoft", function_name = "SkipSimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>)
+  onnx.Return %r0, %r1 : tensor<2x4x8xf32>, tensor<2x4x1xf32>
+// CHECK-LABEL: func.func @skip_simplified_layernorm_two_outputs
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>) {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_3_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_2_:%.+]]:3 = "onnx.Custom"([[VAR_1_]], [[PARAM_2_]]) {axis = -1 : si64, domain_name = "", epsilon = 9.99999974E-6 : f32, function_name = "SimplifiedLayerNormalization", stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, none)
+// CHECK: onnx.Return [[VAR_2_]]#0, [[VAR_2_]]#1 : tensor<2x4x8xf32>, tensor<2x4x1xf32>
+// CHECK: }
+}
+
+
+// -----
+// SkipSimplifiedLayerNormalization: 4 inputs, 3 outputs (output, mean, inv_std_var)
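+// Same shape as the previous case, with inv_std_var also requested, so all three
+// results of the Custom op are tensors.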
+
+func.func @skip_simplified_layernorm_three_outputs(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %bias: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>) {
+  %r0, %r1, %r2 = "onnx.Custom"(%input, %skip, %gamma, %bias) {domain_name = "com.microsoft", function_name = "SkipSimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>)
+  onnx.Return %r0, %r1, %r2 : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>
+// CHECK-LABEL: func.func @skip_simplified_layernorm_three_outputs
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>) {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_3_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_2_:%.+]]:3 = "onnx.Custom"([[VAR_1_]], [[PARAM_2_]]) {axis = -1 : si64, domain_name = "", epsilon = 9.99999974E-6 : f32, function_name = "SimplifiedLayerNormalization", stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>)
+// CHECK: onnx.Return [[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2 : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>
+}
+
+
+// -----
+// SkipSimplifiedLayerNormalization: 4 inputs, 4 outputs (output, mean, inv_std_var, sum)
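+// The fourth output is the input+skip+bias sum, which is served directly by the result
+// of the second onnx.Add rather than recomputed.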
+
+func.func @skip_simplified_layernorm_four_outputs(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %bias: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>) {
+  %r0, %r1, %r2, %r3 = "onnx.Custom"(%input, %skip, %gamma, %bias) {domain_name = "com.microsoft", function_name = "SkipSimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>)
+  onnx.Return %r0, %r1, %r2, %r3 : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @skip_simplified_layernorm_four_outputs
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>) {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_3_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_2_:%.+]]:3 = "onnx.Custom"([[VAR_1_]], [[PARAM_2_]]) {axis = -1 : si64, domain_name = "", epsilon = 9.99999974E-6 : f32, function_name = "SimplifiedLayerNormalization", stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>)
+// CHECK: onnx.Return [[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2, [[VAR_1_]] : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>
+// CHECK: }
+}
+
+// -----
+// SkipSimplifiedLayerNormalization: 4 inputs, 4 outputs (output, mean, inv_std_var, sum), mean unused
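+// With the mean result typed 'none', the RMSLayerNormalization form applies again;
+// here its InvStdDev result is live and materialized as a tensor.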
+
+func.func @skip_simplified_layernorm_four_outputs_mean_unused(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %bias: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>) {
+  %r0, %r1, %r2, %r3 = "onnx.Custom"(%input, %skip, %gamma, %bias) {domain_name = "com.microsoft", function_name = "SkipSimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, none, tensor<2x4x1xf32>, tensor<2x4x8xf32>)
+  onnx.Return %r0, %r2, %r3 : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @skip_simplified_layernorm_four_outputs_mean_unused
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>) {
+// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_2_:%.+]] = "onnx.Add"([[VAR_1_]], [[PARAM_3_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[VAR_2_]], [[PARAM_2_]], [[VAR_0_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, none) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>)
+// CHECK: onnx.Return [[VAR_Y_]], [[VAR_InvStdDev_]], [[VAR_2_]] : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>
+}
+
 