 
 # Please refer to README.md in the same folder for more information.
 
-from typing import Optional
+from typing import Any, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 
-from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
+from executorch.examples.models.llama.attention import (
+    ATTENTION_REGISTRY,
+    ForwardOptions,
+)
 
 from executorch.examples.models.llama.model_args import ModelArgs
 
@@ -148,17 +151,17 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
-        h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos=input_pos
+    def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
+        h, attn_options_update = self.attention.forward(
+            self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )
 
         h = x + h
         if hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
             out = h + self.feed_forward(self.ffn_norm(h))
-        return out
+        return out, attn_options_update
 
 
 class Transformer(nn.Module):
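Note on the contract change in the hunk above: TransformerBlock.forward now passes any ForwardOptions entries to self.attention.forward as keyword arguments and unpacks a (hidden_states, update) pair from its return value. A minimal sketch of an attention module matching that call shape follows; the class and its pass-through behavior are hypothetical illustrations, not code from this diff.

class PassthroughAttention(torch.nn.Module):
    def forward(self, x, freqs_cos, freqs_sin, **attn_options):
        # Return the (unchanged) hidden states plus an optional update mapping;
        # None means there is nothing to merge back into attn_options.
        return x, None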
@@ -185,27 +188,28 @@ def __init__(self, params: ModelArgs):
     def forward(
         self,
         tokens: Optional[torch.LongTensor] = None,  # tokens
-        input_pos: Optional[
-            torch.LongTensor
-        ] = None,  # Scalar tensor indicating size of window of the caches
+        attn_options: Optional[ForwardOptions] = None,
         h: Optional[torch.FloatTensor] = None,  # embeddings
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[Any]]]:
         if (tokens is None) ^ (h is not None):
             raise ValueError(
                 "You cannot specify both tokens and h at the same time, and must specify either one"
             )
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
+
+        if attn_options is None:
+            attn_options = {}
         seqlen = h.shape[1]
-        freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)
+        freqs_cos, freqs_sin = self.rope.get_freqs(
+            attn_options.get("input_pos"), seqlen
+        )
 
+        attn_options_update = None
         for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+            h, attn_options_update = layer(h, freqs_cos, freqs_sin, attn_options)
+            if attn_options_update is not None:
+                attn_options.update(**attn_options_update)
 
         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
@@ -237,4 +241,7 @@ def forward(
                 expanded_logits[:, list(self.output_prune_map.values())] = logits
             logits = expanded_logits
 
+        if attn_options_update is not None:
+            return logits, attn_options_update
+
         return logits
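With these changes, Transformer.forward takes a ForwardOptions mapping (carrying, for example, the "input_pos" entry read via attn_options.get above) instead of a bare input_pos tensor, and returns a (logits, attn_options_update) tuple whenever a layer reports an update. A minimal caller-side sketch under those assumptions; the token values and the model variable are illustrative only.

tokens = torch.tensor([[1, 2, 3]], dtype=torch.long)  # hypothetical prompt ids
attn_options = {"input_pos": torch.tensor([0])}       # key matches attn_options.get("input_pos") above

out = model(tokens, attn_options=attn_options)
if isinstance(out, tuple):   # a layer reported an options update
    logits, attn_options_update = out
else:                        # plain path: only the logits tensor is returned
    logits, attn_options_update = out, None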