Skip to content

Commit 9a6178a

Browse files
committed
Formatting
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent 2d8192a commit 9a6178a

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

onnxscript/rewriter/onnx_fusions/_rotary_embedding.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010

1111
# Basic pattern: For example, see
1212
# https://github.com/huggingface/transformers/blob/541bed22d6e4f97946a3a7d74f7e1a353e58643b/src/transformers/models/llama/modeling_llama.py#L104
13-
# def rotate_half(x):
14-
# """Rotates half the hidden dims of the input."""
15-
# x1 = x[..., : x.shape[-1] // 2]
16-
# x2 = x[..., x.shape[-1] // 2 :]
17-
# return torch.cat((-x2, x1), dim=-1)
13+
# def rotate_half(x):
14+
# """Rotates half the hidden dims of the input."""
15+
# x1 = x[..., : x.shape[-1] // 2]
16+
# x2 = x[..., x.shape[-1] // 2 :]
17+
# return torch.cat((-x2, x1), dim=-1)
1818
# and
19-
# q_embed = (q * cos) + (rotate_half(q) * sin)
19+
# q_embed = (q * cos) + (rotate_half(q) * sin)
2020

2121

2222
def _rotate_half_pattern(op, x, start1, end1, start2, end2):

0 commit comments

Comments
 (0)