@@ -866,3 +866,48 @@ func.func @gqa_with_scale_softcap_and_qk_output_2(
866866// CHECK-SAME: : (tensor<1x128x3072xf32>, tensor<1x128x1536xf32>, tensor<1x128x1536xf32>, none, tensor<1x16x256x96xf32>, tensor<1x16x256x96xf32>) -> (tensor<1x128x3072xf32>, tensor<1x16x384x96xf32>, tensor<1x16x384x96xf32>, tensor<1x32x128x256xf32>)
867867// CHECK: return %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] : tensor<1x128x3072xf32>, tensor<1x16x384x96xf32>, tensor<1x16x384x96xf32>, tensor<1x32x128x256xf32>
868868// CHECK: }
869+
// -----

// RotaryEmbedding custom op (com.microsoft domain) on a 4-D input of shape
// 1x32x128x96 with interleaved = 1 and rotary_embedding_dim = 16. The cos/sin
// caches have 8 columns, i.e. rotary_embedding_dim / 2.
// The FileCheck expectations below require an "onnx.RotaryEmbedding" whose
// operands are reordered from (data, pos_ids, cos_cache, sin_cache) to
// (data, cos_cache, sin_cache, pos_ids) and which carries both attributes.
func.func @rotary_embedding_4d_interleaved_rotdim_16(%data: tensor<1x32x128x96xf32>, %pos_ids: tensor<1x128xi64>, %cos_cache: tensor<4096x8xf32>, %sin_cache: tensor<4096x8xf32>) -> tensor<1x32x128x96xf32> {
  %0 = "onnx.Custom"(%data, %pos_ids, %cos_cache, %sin_cache) {
    domain_name = "com.microsoft",
    function_name = "RotaryEmbedding",
    interleaved = 1 : si64,
    rotary_embedding_dim = 16 : si64
  } : (tensor<1x32x128x96xf32>, tensor<1x128xi64>, tensor<4096x8xf32>, tensor<4096x8xf32>) -> tensor<1x32x128x96xf32>
  return %0 : tensor<1x32x128x96xf32>
}

// CHECK-LABEL: func.func @rotary_embedding_4d_interleaved_rotdim_16(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x32x128x96xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x128xi64>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<4096x8xf32>,
// CHECK-SAME: %[[VAL_3:.*]]: tensor<4096x8xf32>) -> tensor<1x32x128x96xf32> {
// CHECK: %[[VAL_4:.*]] = "onnx.RotaryEmbedding"(%[[VAL_0]], %[[VAL_2]], %[[VAL_3]], %[[VAL_1]])
// CHECK-SAME: {interleaved = 1 : si64, rotary_embedding_dim = 16 : si64}
// CHECK-SAME: : (tensor<1x32x128x96xf32>, tensor<4096x8xf32>, tensor<4096x8xf32>, tensor<1x128xi64>) -> tensor<1x32x128x96xf32>
// CHECK: return %[[VAL_4]] : tensor<1x32x128x96xf32>
// CHECK: }
892+
// -----

// RotaryEmbedding custom op (com.microsoft domain) on a 3-D input of shape
// 1x128x3072 with only num_heads = 32 set. 3072 / 32 = 96 per head, and the
// cos/sin caches have 48 columns (half of that).
// The FileCheck expectations below require an "onnx.RotaryEmbedding" with
// operands reordered to (data, cos_cache, sin_cache, pos_ids). They also list
// interleaved = 0 and rotary_embedding_dim = 0, which the input op does not
// set -- presumably the target op's default values; confirm against the op
// definition.
func.func @test_rotary_embedding_3d(%data: tensor<1x128x3072xf32>, %pos_ids: tensor<1x128xi64>, %cos_cache: tensor<4096x48xf32>, %sin_cache: tensor<4096x48xf32>) -> tensor<1x128x3072xf32> {
  %0 = "onnx.Custom"(%data, %pos_ids, %cos_cache, %sin_cache) {
    domain_name = "com.microsoft",
    function_name = "RotaryEmbedding",
    num_heads = 32 : si64
  } : (tensor<1x128x3072xf32>, tensor<1x128xi64>, tensor<4096x48xf32>, tensor<4096x48xf32>) -> tensor<1x128x3072xf32>
  return %0 : tensor<1x128x3072xf32>
}

// CHECK-LABEL: func.func @test_rotary_embedding_3d(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x128x3072xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x128xi64>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<4096x48xf32>,
// CHECK-SAME: %[[VAL_3:.*]]: tensor<4096x48xf32>) -> tensor<1x128x3072xf32> {
// CHECK: %[[VAL_4:.*]] = "onnx.RotaryEmbedding"(%[[VAL_0]], %[[VAL_2]], %[[VAL_3]], %[[VAL_1]])
// CHECK-SAME: {interleaved = 0 : si64, num_heads = 32 : si64, rotary_embedding_dim = 0 : si64}
// CHECK-SAME: : (tensor<1x128x3072xf32>, tensor<4096x48xf32>, tensor<4096x48xf32>, tensor<1x128xi64>) -> tensor<1x128x3072xf32>
// CHECK: return %[[VAL_4]] : tensor<1x128x3072xf32>
// CHECK: }
0 commit comments