Commit ebe4540
[ET-VK][ez] Apply rotary embedding as Module

Pull Request resolved: #6391

## Context

As title. Wrap the `apply_rotary_emb` function call in an `nn.Module` to make it easy to perform a source module replacement for the rotary embedding calculation. The Vulkan delegate will use the source module replacement technique to insert a custom op that calculates rotary embeddings.

ghstack-source-id: 249175724
@exported-using-ghexport

Differential Revision: [D64697589](https://our.internmc.facebook.com/intern/diff/D64697589/)
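As a rough illustration of the source module replacement technique the commit message refers to: once the rotary embedding computation is an `nn.Module` rather than a bare function call, a delegate can find every instance by type and swap it for a module that dispatches to a custom op. The sketch below is hypothetical (the names `RotaryEmbeddingStub`, `CustomRotaryEmbedding`, and `replace_rotary_modules` are illustrative, not the actual Vulkan delegate code).

```python
import torch

class RotaryEmbeddingStub(torch.nn.Module):
    """Stand-in for examples.models.llama.rope.RotaryEmbedding."""

    def forward(self, xq, xk, freqs_cos, freqs_sin):
        return xq, xk  # placeholder for the real rotation math


class CustomRotaryEmbedding(torch.nn.Module):
    """What a delegate might swap in; a real one would call a custom op."""

    def forward(self, xq, xk, freqs_cos, freqs_sin):
        return xq, xk  # placeholder for a custom-op dispatch


def replace_rotary_modules(model: torch.nn.Module) -> None:
    # Recursively replace every RotaryEmbeddingStub child by type.
    # This only works because rotary embedding is now a module that
    # appears in named_children(), not an inlined function call.
    for name, child in model.named_children():
        if isinstance(child, RotaryEmbeddingStub):
            setattr(model, name, CustomRotaryEmbedding())
        else:
            replace_rotary_modules(child)


class Attention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.apply_rotary_emb = RotaryEmbeddingStub()


model = Attention()
replace_rotary_modules(model)
```

After the pass runs, `model.apply_rotary_emb` is the custom module; the function-call style of the old code offered no such hook for replacement by type.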
1 parent 46ea1a4 commit ebe4540

File tree

2 files changed: +17 −2 lines changed
examples/models/llama/llama_transformer.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -15,10 +15,10 @@
 import torch.nn.functional as F

 from executorch.examples.models.llama.rope import (
-    apply_rotary_emb,
     hf_apply_rotary_emb,
     hf_precompute_freqs_cis,
     precompute_freqs_cis,
+    RotaryEmbedding,
 )

 from torch import nn
@@ -311,7 +311,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
         if args.use_hf_rope:
             self.apply_rotary_emb = hf_apply_rotary_emb
         else:
-            self.apply_rotary_emb = apply_rotary_emb
+            self.apply_rotary_emb = RotaryEmbedding()

     def forward(
         self,
```
examples/models/llama/rope.py

Lines changed: 15 additions & 0 deletions

```diff
@@ -92,6 +92,21 @@ def apply_rotary_emb(
     return xq_out.type_as(xq), xk_out.type_as(xk)


+class RotaryEmbedding(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        xq: torch.Tensor,
+        xk: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+    ):
+        xq_out, xk_out = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
+        return xq_out, xk_out
+
+
 # ======================= HuggingFace Implementation ========================
```
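The wrapper changes nothing about the computation: the module's `forward` simply delegates to `apply_rotary_emb`, so `RotaryEmbedding()(...)` is call-compatible with the free function and the assignment in `llama_transformer.py` is a drop-in change. A self-contained sketch of that equivalence, using a simplified stand-in for the rotation math (not the exact implementation from `rope.py`):

```python
import torch

def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):
    # Simplified stand-in rotation: treat even/odd feature pairs as
    # real/imaginary parts and rotate by the given cos/sin angles.
    def rotate(x):
        x_r, x_i = x[..., 0::2], x[..., 1::2]
        out_r = x_r * freqs_cos - x_i * freqs_sin
        out_i = x_r * freqs_sin + x_i * freqs_cos
        return torch.stack([out_r, out_i], dim=-1).flatten(-2)

    return rotate(xq), rotate(xk)


class RotaryEmbedding(torch.nn.Module):
    """Module wrapper: same signature, same result, but visible as a
    distinct submodule that replacement passes can match by type."""

    def forward(self, xq, xk, freqs_cos, freqs_sin):
        return apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)


xq, xk = torch.randn(2, 8), torch.randn(2, 8)
cos, sin = torch.randn(2, 4), torch.randn(2, 4)

fn_q, fn_k = apply_rotary_emb(xq, xk, cos, sin)
mod_q, mod_k = RotaryEmbedding()(xq, xk, cos, sin)
```

Both call forms produce identical tensors; the only difference is that the module form registers in the parent's submodule tree.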
0 commit comments
