44
55from onnxscript .rewriter import _fusion_utils , _ir_utils , pattern
66
7- # Add first version of the RotaryEmbeddingFusion rule. This considers only one simple pattern
8- # for full rotation without interleaving.
9- # TODO(rama): Add pattern variations to handle other cases (interleaved, as well as partial rotation).
7+ # Fusions for RotaryEmbedding:
8+ # Fuse computation patterns seen in HF transformer models for RotaryEmbedding
9+ # and map them to ONNX opset 23 RotaryEmbedding op.
10+
11+ # Basic pattern: For example, see
12+ # https://github.com/huggingface/transformers/blob/541bed22d6e4f97946a3a7d74f7e1a353e58643b/src/transformers/models/llama/modeling_llama.py#L104
13+ # def rotate_half(x):
14+ # """Rotates half the hidden dims of the input."""
15+ # x1 = x[..., : x.shape[-1] // 2]
16+ # x2 = x[..., x.shape[-1] // 2 :]
17+ # return torch.cat((-x2, x1), dim=-1)
18+ # and
19+ # q_embed = (q * cos) + (rotate_half(q) * sin)
1020
1121
1222def _rotate_half_pattern (op , x , start1 , end1 , start2 , end2 ):
@@ -29,12 +39,12 @@ def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult:
2939 check_result = pattern .MatchResult ()
3040 # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads)
3141 if x is None or x .shape is None or len (x .shape ) != 4 :
32- return check_result .fail ("Input is not a 4D tensor." , x )
42+ return check_result .fail ("Input is not known to be a 4D tensor." , x )
3343 if not isinstance (x .shape [1 ], int ):
34- return check_result .fail ("Input dimension 1 is not an integer ." , x )
44+ return check_result .fail ("Input dimension 1 (num_heads) is not static ." , x )
3545 head_size = x .shape [3 ]
3646 if not isinstance (head_size , int ):
37- return check_result .fail ("Head size is not an integer ." , x )
47+ return check_result .fail ("Head size is not static ." , x )
3848 half_head_size = head_size // 2
3949
4050 # Check that x is being split into two equal halves of size half_head_size
@@ -60,10 +70,17 @@ def rewrite(self, op, x, cos, sin, **_):
6070 )
6171
6272
73+ # Extensions for partial rotary embedding fusion: with partial rotary embedding, the rotary
74+ # embedding is applied only to the leading slice of the input along axis 3, and the remaining
75+ # slice along that axis is left unchanged, as captured in the pattern below.
76+
77+ MAX_INT64 = 9223372036854775807
78+
79+
6380class PartialRotaryEmbedding23Fusion (pattern .RewriteRuleClassBase ):
6481 def pattern (self , op , x , end1 , start2 ):
6582 x_part_1 = op .Slice (x , [0 ], end1 , [3 ], [1 ])
66- x_part_2 = op .Slice (x , start2 , [9223372036854775807 ], [3 ], [1 ])
83+ x_part_2 = op .Slice (x , start2 , [MAX_INT64 ], [3 ], [1 ])
6784 x_part_1_rope = op .RotaryEmbedding (
6885 x_part_1 ,
6986 _allow_other_inputs = True ,
@@ -77,9 +94,7 @@ def check(self, op, x, end1, start2, x_part_1_rope, **_) -> pattern.MatchResult:
7794 end1_value = _ir_utils .get_singleton_value (end1 )
7895 start2_value = _ir_utils .get_singleton_value (start2 )
7996 if not isinstance (end1_value , int ) or not isinstance (start2_value , int ):
80- return check_result .fail (
81- "The end1 value of first slice and start2 value of second slice are not integers."
82- )
97+ return check_result .fail ("Unable to validate slice start/end values." )
8398 if end1_value != start2_value :
8499 return check_result .fail (
85100 "The end1 value of first slice and start2 value of second slice are not equal."
0 commit comments