@@ -374,13 +374,15 @@ func.func @skip_layernorm_basic(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf
 }


+
+
 // -----
 // SkipLayerNormalization: 4 inputs (beta), 1 output

-func.func @skip_layernorm_beta(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %beta: tensor<8xf32>) -> tensor<2x4x8xf32> {
-  %r = "onnx.Custom"(%input, %skip, %gamma, %beta) {domain_name = "com.microsoft", function_name = "SkipLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+func.func @skip_layernorm_beta_no_eps(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %beta: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %skip, %gamma, %beta) {domain_name = "com.microsoft", function_name = "SkipLayerNormalization"} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
   onnx.Return %r : tensor<2x4x8xf32>
-// CHECK-LABEL: func.func @skip_layernorm_beta
+// CHECK-LABEL: func.func @skip_layernorm_beta_no_eps
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
 // CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
 // CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_0_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, none, none)
@@ -461,6 +463,19 @@ func.func @simplified_layernorm_basic(%input: tensor<2x4x8xf32>, %scale: tensor<
 // CHECK: onnx.Return [[VAR_Y_]] : tensor<2x4x8xf32>
 }

+// -----
+
+func.func @simplified_layernorm_no_attrs(%input: tensor<2x4x8xf32>, %scale: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %scale) {domain_name = "", function_name = "SimplifiedLayerNormalization"} : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+  onnx.Return %r : tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @simplified_layernorm_no_attrs
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[VAR_Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, none) -> (tensor<2x4x8xf32>, none)
+// CHECK: onnx.Return [[VAR_Y_]] : tensor<2x4x8xf32>
+}
+
+

 // -----
 // SimplifiedLayerNormalization: 3 inputs (with bias), 1 output
@@ -517,10 +532,10 @@ func.func @simplified_layernorm_two_outputs_mean_used(%input: tensor<2x4x8xf32>,
 // -----
 // SkipSimplifiedLayerNormalization: 3 inputs, 1 output

-func.func @skip_simplified_layernorm_basic(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>) -> tensor<2x4x8xf32> {
-  %r = "onnx.Custom"(%input, %skip, %gamma) {domain_name = "com.microsoft", function_name = "SkipSimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+func.func @skip_simplified_layernorm_basic_no_attr(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %skip, %gamma) {domain_name = "com.microsoft", function_name = "SkipSimplifiedLayerNormalization"} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
   onnx.Return %r : tensor<2x4x8xf32>
-// CHECK-LABEL: func.func @skip_simplified_layernorm_basic
+// CHECK-LABEL: func.func @skip_simplified_layernorm_basic_no_attr
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
 // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
 // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>