diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 3c4e3f13e6f..5a96e49ef1b 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -15,10 +15,10 @@
 import torch.nn.functional as F

 from executorch.examples.models.llama.rope import (
-    apply_rotary_emb,
     hf_apply_rotary_emb,
     hf_precompute_freqs_cis,
     precompute_freqs_cis,
+    RotaryEmbedding,
 )

 from torch import nn
@@ -311,7 +311,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
         if args.use_hf_rope:
             self.apply_rotary_emb = hf_apply_rotary_emb
         else:
-            self.apply_rotary_emb = apply_rotary_emb
+            self.apply_rotary_emb = RotaryEmbedding()

     def forward(
         self,
diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py
index 233c7a2f982..0383c798988 100644
--- a/examples/models/llama/rope.py
+++ b/examples/models/llama/rope.py
@@ -92,6 +92,21 @@ def apply_rotary_emb(
     return xq_out.type_as(xq), xk_out.type_as(xk)


+class RotaryEmbedding(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        xq: torch.Tensor,
+        xk: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+    ):
+        xq_out, xk_out = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
+        return xq_out, xk_out
+
+
 # ======================= HuggingFace Implementation ========================
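
Usage sketch (not part of the patch): the new RotaryEmbedding module keeps the same
calling convention as the old apply_rotary_emb function, so existing call sites of
self.apply_rotary_emb(xq, xk, freqs_cos, freqs_sin) are unchanged. The snippet below
uses illustrative tensor shapes and assumes precompute_freqs_cis(dim, end) returns a
(freqs_cos, freqs_sin) pair; neither detail is shown in this diff.

import torch

from executorch.examples.models.llama.rope import (
    precompute_freqs_cis,
    RotaryEmbedding,
)

# Illustrative shapes only: (batch, seq_len, n_heads, head_dim).
batch, seq_len, n_heads, head_dim = 1, 8, 4, 64
xq = torch.randn(batch, seq_len, n_heads, head_dim)
xk = torch.randn(batch, seq_len, n_heads, head_dim)

# Assumed signature: precompute_freqs_cis(head_dim, seq_len) -> (freqs_cos, freqs_sin).
freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, seq_len)

# The nn.Module wrapper simply delegates to apply_rotary_emb and returns rotated q/k.
rope = RotaryEmbedding()
xq_rot, xk_rot = rope(xq, xk, freqs_cos, freqs_sin)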