Skip to content

Commit a181003

Browse files
committed
Address PR feedback
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent 2b4b473 commit a181003

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

onnxscript/rewriter/onnx_fusions/_rotary_embedding.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,19 @@
44

55
from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
66

7-
# Add first version of the RotaryEmbeddingFusion rule. This considers only one simple pattern
8-
# for full rotation without interleaving.
9-
# TODO(rama): Add pattern variations to handle other cases (interleaved, as well as partial rotation).
7+
# Fusions for RotaryEmbedding:
8+
# Fuse computation patterns seen in HF transformer models for RotaryEmbedding
9+
# and map them to ONNX opset 23 RotaryEmbedding op.
10+
11+
# Basic pattern: For example, see
12+
# https://github.com/huggingface/transformers/blob/541bed22d6e4f97946a3a7d74f7e1a353e58643b/src/transformers/models/llama/modeling_llama.py#L104
13+
# def rotate_half(x):
14+
# """Rotates half the hidden dims of the input."""
15+
# x1 = x[..., : x.shape[-1] // 2]
16+
# x2 = x[..., x.shape[-1] // 2 :]
17+
# return torch.cat((-x2, x1), dim=-1)
18+
# and
19+
# q_embed = (q * cos) + (rotate_half(q) * sin)
1020

1121

1222
def _rotate_half_pattern(op, x, start1, end1, start2, end2):
@@ -29,12 +39,12 @@ def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult:
2939
check_result = pattern.MatchResult()
3040
# x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads)
3141
if x is None or x.shape is None or len(x.shape) != 4:
32-
return check_result.fail("Input is not a 4D tensor.", x)
42+
return check_result.fail("Input is not known to be a 4D tensor.", x)
3343
if not isinstance(x.shape[1], int):
34-
return check_result.fail("Input dimension 1 is not an integer.", x)
44+
return check_result.fail("Input dimension 1 (num_heads) is not static.", x)
3545
head_size = x.shape[3]
3646
if not isinstance(head_size, int):
37-
return check_result.fail("Head size is not an integer.", x)
47+
return check_result.fail("Head size is not static.", x)
3848
half_head_size = head_size // 2
3949

4050
# Check that x is being split into two equal halves of size half_head_size
@@ -60,10 +70,17 @@ def rewrite(self, op, x, cos, sin, **_):
6070
)
6171

6272

73+
# Extensions for partial rotary embedding fusion: with partial rotary embedding,
74+
# embedding is applied only to the first part of the input, and the second part is left unchanged,
75+
# as captured in the pattern below.
76+
77+
MAX_INT64 = 9223372036854775807
78+
79+
6380
class PartialRotaryEmbedding23Fusion(pattern.RewriteRuleClassBase):
6481
def pattern(self, op, x, end1, start2):
6582
x_part_1 = op.Slice(x, [0], end1, [3], [1])
66-
x_part_2 = op.Slice(x, start2, [9223372036854775807], [3], [1])
83+
x_part_2 = op.Slice(x, start2, [MAX_INT64], [3], [1])
6784
x_part_1_rope = op.RotaryEmbedding(
6885
x_part_1,
6986
_allow_other_inputs=True,
@@ -77,9 +94,7 @@ def check(self, op, x, end1, start2, x_part_1_rope, **_) -> pattern.MatchResult:
7794
end1_value = _ir_utils.get_singleton_value(end1)
7895
start2_value = _ir_utils.get_singleton_value(start2)
7996
if not isinstance(end1_value, int) or not isinstance(start2_value, int):
80-
return check_result.fail(
81-
"The end1 value of first slice and start2 value of second slice are not integers."
82-
)
97+
return check_result.fail("Unable to validate slice start/end values.")
8398
if end1_value != start2_value:
8499
return check_result.fail(
85100
"The end1 value of first slice and start2 value of second slice are not equal."

0 commit comments

Comments
 (0)