From ebe4540280287edc01c04325bebec89005dc8d32 Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Mon, 21 Oct 2024 12:50:41 -0700
Subject: [PATCH] [ET-VK][ez] Apply rotary embedding as Module

Pull Request resolved: https://github.com/pytorch/executorch/pull/6391

## Context

As title. Wrap the `apply_rotary_emb` function call in an `nn.Module` to make
it easy to perform a source module replacement for the rotary embedding
calculation. The Vulkan delegate will use the source module replacement
technique to insert a custom op to calculate rotary embeddings.

ghstack-source-id: 249175724
@exported-using-ghexport

Differential Revision: [D64697589](https://our.internmc.facebook.com/intern/diff/D64697589/)
---
 examples/models/llama/llama_transformer.py |  4 ++--
 examples/models/llama/rope.py              | 15 +++++++++++++++
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 3c4e3f13e6f..5a96e49ef1b 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -15,10 +15,10 @@
 import torch.nn.functional as F
 
 from executorch.examples.models.llama.rope import (
-    apply_rotary_emb,
     hf_apply_rotary_emb,
     hf_precompute_freqs_cis,
     precompute_freqs_cis,
+    RotaryEmbedding,
 )
 
 from torch import nn
@@ -311,7 +311,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
         if args.use_hf_rope:
             self.apply_rotary_emb = hf_apply_rotary_emb
         else:
-            self.apply_rotary_emb = apply_rotary_emb
+            self.apply_rotary_emb = RotaryEmbedding()
 
     def forward(
         self,
diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py
index 233c7a2f982..0383c798988 100644
--- a/examples/models/llama/rope.py
+++ b/examples/models/llama/rope.py
@@ -92,6 +92,21 @@ def apply_rotary_emb(
     return xq_out.type_as(xq), xk_out.type_as(xk)
 
 
+class RotaryEmbedding(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        xq: torch.Tensor,
+        xk: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+    ):
+        xq_out, xk_out = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
+        return xq_out, xk_out
+
+
 # ======================= HuggingFace Implementation ========================
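
For illustration only, not part of the patch above: a minimal sketch of the source module replacement technique the commit message refers to. The `VkRotaryEmbedding` wrapper and `replace_rotary_embedding` helper below are hypothetical names; the actual Vulkan delegate pass and custom op are not shown.

```python
# Hypothetical sketch of source module replacement; names are illustrative.
import torch

from executorch.examples.models.llama.rope import RotaryEmbedding


class VkRotaryEmbedding(torch.nn.Module):
    """Stand-in for a module that would call a Vulkan custom op."""

    def forward(self, xq, xk, freqs_cos, freqs_sin):
        # A real delegate would dispatch to a registered custom op here;
        # the inputs are returned unchanged so the sketch stays runnable.
        return xq, xk


def replace_rotary_embedding(model: torch.nn.Module) -> torch.nn.Module:
    # Walk the module tree and swap every RotaryEmbedding submodule for the
    # custom-op wrapper.
    for name, child in model.named_children():
        if isinstance(child, RotaryEmbedding):
            setattr(model, name, VkRotaryEmbedding())
        else:
            replace_rotary_embedding(child)
    return model
```

This kind of `isinstance`-based swap is what wrapping the `apply_rotary_emb` call in an `nn.Module` enables: a bare function call inside `forward` cannot be targeted by module-level replacement, but a named submodule can.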