@@ -357,4 +357,93 @@ func.func @fusedconv_too_many_operands(%x: tensor<1x3x8x8xf32>, %w: tensor<4x3x3
 // CHECK: onnx.Return [[VAR_0_]] : tensor<1x4x8x8xf32>
 // CHECK: }
 
-}
+}
+
+// -----
+// SkipLayerNormalization: 3 inputs, 1 output
+
+func.func @skip_layernorm_basic(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %skip, %gamma) {domain_name = "com.microsoft", function_name = "SkipLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+  onnx.Return %r : tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @skip_layernorm_basic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_1_]], [[PARAM_2_]], [[VAR_0_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, none) -> (tensor<2x4x8xf32>, none, none)
+// CHECK: onnx.Return [[VAR_Y_]] : tensor<2x4x8xf32>
+}
+
+
+// -----
+// SkipLayerNormalization: 4 inputs (beta), 1 output
+
+func.func @skip_layernorm_beta(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %beta: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %skip, %gamma, %beta) {domain_name = "com.microsoft", function_name = "SkipLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+  onnx.Return %r : tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @skip_layernorm_beta
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_0_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, none, none)
+// CHECK: onnx.Return [[VAR_Y_]] : tensor<2x4x8xf32>
+}
+
+
+// -----
+// SkipLayerNormalization: 5 inputs (beta + bias), 1 output
+
+func.func @skip_layernorm_beta_bias(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %beta: tensor<8xf32>, %bias: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %skip, %gamma, %beta, %bias) {domain_name = "com.microsoft", function_name = "SkipLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+  onnx.Return %r : tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @skip_layernorm_beta_bias
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>, [[PARAM_4_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_4_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_1_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, none, none)
+// CHECK: onnx.Return [[VAR_Y_]] : tensor<2x4x8xf32>
+}
+
+
+// -----
+// SkipLayerNormalization: 5 inputs, 2 outputs (output, mean)
+
+func.func @skip_layernorm_two_outputs(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %beta: tensor<8xf32>, %bias: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>) {
+  %r0, %r1 = "onnx.Custom"(%input, %skip, %gamma, %beta, %bias) {domain_name = "com.microsoft", function_name = "SkipLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>)
+  onnx.Return %r0, %r1 : tensor<2x4x8xf32>, tensor<2x4x1xf32>
+// CHECK-LABEL: func.func @skip_layernorm_two_outputs
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>, [[PARAM_4_:%.+]]: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>) {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_4_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_1_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, none)
+// CHECK: onnx.Return [[VAR_Y_]], [[VAR_Mean_]] : tensor<2x4x8xf32>, tensor<2x4x1xf32>
+}
+
+
+// -----
+// SkipLayerNormalization: 5 inputs, 3 outputs (output, mean, inv_std_var)
+
+func.func @skip_layernorm_three_outputs(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %beta: tensor<8xf32>, %bias: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>) {
+  %r0, %r1, %r2 = "onnx.Custom"(%input, %skip, %gamma, %beta, %bias) {domain_name = "com.microsoft", function_name = "SkipLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>)
+  onnx.Return %r0, %r1, %r2 : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>
+// CHECK-LABEL: func.func @skip_layernorm_three_outputs
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>, [[PARAM_4_:%.+]]: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>) {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_4_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_1_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>)
+// CHECK: onnx.Return [[VAR_Y_]], [[VAR_Mean_]], [[VAR_InvStdDev_]] : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>
+}
+
+
+// -----
+// SkipLayerNormalization: 5 inputs, 4 outputs (output, mean, inv_std_var, sum)
+
+func.func @skip_layernorm_four_outputs(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %beta: tensor<8xf32>, %bias: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>) {
+  %r0, %r1, %r2, %r3 = "onnx.Custom"(%input, %skip, %gamma, %beta, %bias) {domain_name = "com.microsoft", function_name = "SkipLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>)
+  onnx.Return %r0, %r1, %r2, %r3 : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @skip_layernorm_four_outputs
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>, [[PARAM_4_:%.+]]: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>) {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_4_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_1_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>)
+// CHECK: onnx.Return [[VAR_Y_]], [[VAR_Mean_]], [[VAR_InvStdDev_]], [[VAR_1_]] : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>
+}
+