 
 # Please refer to README.md in the same folder for more information.
 
-from typing import Optional
+from typing import Any, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 
-from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
+from executorch.examples.models.llama.attention import (
+    ATTENTION_REGISTRY,
+    ForwardOptions,
+)
 
 from executorch.examples.models.llama.model_args import ModelArgs
 
@@ -148,17 +151,17 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
-        h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos=input_pos
+    def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
+        h, attn_options_update = self.attention.forward(
+            self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )
 
         h = x + h
         if hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
             out = h + self.feed_forward(self.ffn_norm(h))
-        return out
+        return out, attn_options_update
 
 
 class Transformer(nn.Module):
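
The hunk above replaces the `input_pos` keyword on `TransformerBlock.forward` with a dict-like `ForwardOptions` mapping that is unpacked into the attention call, and the block now returns an `(output, update)` pair. A self-contained toy sketch of that calling convention (the `toy_layer` function, shapes, and default `input_pos` are illustrative assumptions, not executorch code):

# Toy sketch of the options-dict convention, assuming ForwardOptions behaves like a
# plain dict of optional attention kwargs (as the `**attn_options` unpacking suggests).
from typing import Any, Dict, Optional, Tuple

import torch


def toy_layer(
    x: torch.Tensor, attn_options: Dict[str, Any]
) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
    # Stand-in for TransformerBlock.forward: consume options, emit an optional update.
    input_pos = attn_options.get("input_pos", torch.tensor([0]))
    return x + 1.0, {"input_pos": input_pos + x.shape[1]}


attn_options: Dict[str, Any] = {"input_pos": torch.tensor([0])}
h = torch.zeros(1, 4, 8)  # 1 x seqlen x dim
for _ in range(2):  # mirrors the per-layer loop in Transformer.forward below
    h, update = toy_layer(h, attn_options)
    if update is not None:
        attn_options.update(**update)  # carry the update into the next layer's options
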
@@ -185,27 +188,28 @@ def __init__(self, params: ModelArgs):
     def forward(
         self,
         tokens: Optional[torch.LongTensor] = None,  # tokens
-        input_pos: Optional[
-            torch.LongTensor
-        ] = None,  # Scalar tensor indicating size of window of the caches
+        attn_options: Optional[ForwardOptions] = None,
         h: Optional[torch.FloatTensor] = None,  # embeddings
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[Any]]]:
         if (tokens is None) ^ (h is not None):
             raise ValueError(
                 "You cannot specify both tokens and h at the same time, and must specify either one"
             )
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
+
+        if attn_options is None:
+            attn_options = {}
         seqlen = h.shape[1]
-        freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)
+        freqs_cos, freqs_sin = self.rope.get_freqs(
+            attn_options.get("input_pos"), seqlen
+        )
 
+        attn_options_update = None
         for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+            h, attn_options_update = layer(h, freqs_cos, freqs_sin, attn_options)
+            if attn_options_update is not None:
+                attn_options.update(**attn_options_update)
 
         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
@@ -237,4 +241,7 @@ def forward(
                 expanded_logits[:, list(self.output_prune_map.values())] = logits
             logits = expanded_logits
 
+        if attn_options_update is not None:
+            return logits, attn_options_update
+
         return logits
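
With the final hunk, `Transformer.forward` takes `attn_options` in place of `input_pos` and may return a `(logits, update)` tuple. A hedged caller-side sketch under the assumption that `model` is an already-constructed `Transformer` (nothing here beyond the signatures above is taken from this PR):

# Hedged caller-side sketch; `model` is an assumed, already-constructed Transformer.
import torch

tokens = torch.tensor([[1, 2, 3]], dtype=torch.long)
attn_options = {"input_pos": torch.tensor([0])}  # the key read by rope.get_freqs above

out = model(tokens=tokens, attn_options=attn_options)
if isinstance(out, tuple):
    # An attention update was produced; merge it before the next decode step.
    logits, attn_options_update = out
    attn_options.update(**attn_options_update)
else:
    logits = out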