@@ -42,6 +42,7 @@ class ModelArgs:
     ffn_dim_multiplier: Optional[float] = None
     norm_eps: float = 1e-5
     max_batch_size: int = 1
+    static_seq_len: int = 32
     max_seq_len: int = 128
     max_context_len: int = 2048
     moe: bool = False  # True to enable the MoE (Mixture of Experts)
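A note on the new field: `static_seq_len` is threaded through `load_model` below, which suggests the exported graph consumes fixed-shape token windows of this length, with the separate `input_length` argument tracking how many of those tokens are real. A minimal sketch of that padding pattern, under that assumption (the names `chunk` and `padded` are illustrative, not from this diff):

import torch

static_seq_len = 32
chunk = torch.tensor([[1, 2, 3]], dtype=torch.long)        # 3 real tokens
pad = static_seq_len - chunk.shape[1]
padded = torch.nn.functional.pad(chunk, (0, pad), value=0)  # shape (1, 32)
input_length = torch.tensor([chunk.shape[1]], dtype=torch.long)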
@@ -398,15 +399,18 @@ def __init__(self, params: ModelArgs):
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
         self.use_cache_list = params.use_cache_list
+        if self.use_cache_list:
+            # pyre-ignore: Incompatible attribute type [8]
+            self.forward = self.forward_use_cache_list
 
-    def forward(
+    def forward_use_cache_list(
         self,
         tokens: torch.LongTensor,  # tokens
         input_pos: torch.LongTensor,
-        input_length: torch.LongTensor,  # input_length
         k_caches: List[torch.FloatTensor],
         v_caches: List[torch.FloatTensor],
-        attn_mask: torch.LongTensor,
+        attn_mask: torch.FloatTensor,
+        input_length: torch.LongTensor,  # input_length
         h: Optional[torch.FloatTensor] = None,  # embeddings
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if (tokens is None) ^ (h is not None):
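The `self.forward = self.forward_use_cache_list` assignment picks the signature at construction time by instance-attribute rebinding: the attribute set on the instance shadows the class-level `forward` for that object only, which is why pyre needs the ignore. A minimal sketch of the pattern:

class M:
    def __init__(self, use_alt: bool):
        if use_alt:
            # The instance attribute shadows the class method for this object
            self.forward = self.forward_alt

    def forward(self, x):       # default path
        return x + 1

    def forward_alt(self, x):   # alternative path
        return x - 1

assert M(False).forward(0) == 1
assert M(True).forward(0) == -1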
@@ -425,8 +429,8 @@ def forward(
                 h,
                 freqs_cos,
                 freqs_sin,
-                k_caches[i] if self.use_cache_list else k_caches[i, :, :, :, :],
-                v_caches[i] if self.use_cache_list else v_caches[i, :, :, :, :],
+                k_caches[i],
+                v_caches[i],
                 attn_mask,
             )
             k_out.append(new_k)
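With the two paths split, the ternary is no longer needed here: in `forward_use_cache_list` the caches are always a Python list with one tensor per layer, and the stacked-tensor indexing moves into the new `forward` below. A quick illustration of why the two layouts agree per layer (shape values here are illustrative, not taken from this model):

import torch

n_layers, per_layer_shape = 2, (1, 8, 128, 64)
as_list = [torch.zeros(per_layer_shape) for _ in range(n_layers)]
as_tensor = torch.stack(as_list, dim=0)   # 5-D: (n_layers, 1, 8, 128, 64)
assert torch.equal(as_list[0], as_tensor[0, :, :, :, :])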
@@ -445,15 +449,64 @@ def forward(
             v_out = torch.stack(v_out, dim=0)
         return logits, k_out, v_out  # pyre-ignore[7]
 
+    def forward(
+        self,
+        tokens: torch.LongTensor,  # tokens
+        input_pos: torch.LongTensor,
+        k_caches: torch.FloatTensor,
+        v_caches: torch.FloatTensor,
+        attn_mask: torch.FloatTensor,
+        input_length: torch.LongTensor,  # input_length
+        h: Optional[torch.FloatTensor] = None,  # embeddings
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if (tokens is None) ^ (h is not None):
+            raise ValueError(
+                "You cannot specify both tokens and h at the same time, and must specify either one"
+            )
+        if tokens is not None and h is None:
+            h = self.tok_embeddings(tokens)
+        seqlen = h.shape[1]
+        freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)
 
-def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):
+        k_out = []
+        v_out = []
+        for i, layer in enumerate(self.layers):
+            h, new_k, new_v = layer(
+                h,
+                freqs_cos,
+                freqs_sin,
+                k_caches[i, :, :, :, :],
+                v_caches[i, :, :, :, :],
+                attn_mask,
+            )
+            k_out.append(new_k)
+            v_out.append(new_v)
+
+        if not self.generate_full_logits:
+            # Only the last logit is used for the new generated token
+            h = h[:, input_length - 1, :].squeeze(1)
+
+        h = self.norm(h)
+
+        logits = self.output(h)
+
+        if not self.use_cache_list:
+            k_out = torch.stack(k_out, dim=0)
+            v_out = torch.stack(v_out, dim=0)
+        return logits, k_out, v_out  # pyre-ignore[7]
+
+
+def load_model(
+    checkpoint_path, params_path, max_seq_length, use_cache_list, static_seq_len=32
+):
     import json
 
     with open(params_path, "r") as f:
         params = json.loads(f.read())
 
     args = ModelArgs(
         max_seq_len=max_seq_length,
+        static_seq_len=static_seq_len,
         generate_full_logits=False,
         use_cache_list=use_cache_list,
         **params,
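For reference, a hypothetical call site for the updated `load_model` signature (the paths are placeholders, not files named in this diff):

model = load_model(
    "consolidated.00.pth",
    "params.json",
    max_seq_length=128,
    use_cache_list=False,   # leaves forward() on the stacked-tensor path
    static_seq_len=32,
)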
@@ -618,14 +671,14 @@ def get_inputs(self, tokens: List[int]):
             ).reshape(1, -1),
             # input_pos
             torch.tensor([self.input_pos], dtype=torch.long),
             # k_cache
             self.k_caches,
             # v_cache
             self.v_caches,
             # attn_mask
-            self.attn_mask,
+            torch.zeros(self.attn_mask.shape, dtype=torch.float16),
+            # input_length
+            torch.tensor([input_length], dtype=torch.long),
         )
 
     def get_inputs_and_remaining_tokens(self, tokens: List[int]):
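The all-zeros `float16` mask here is consistent with the `attn_mask: torch.FloatTensor` annotation above, assuming the mask is applied additively to attention scores: zeros leave every position visible, while blocked positions would carry large negative values. A sketch of that convention (not this model's exact masking code):

import torch

scores = torch.randn(1, 4, 4)
no_mask = torch.zeros(4, 4, dtype=torch.float16)  # additive mask: 0 = attend
causal = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)  # blocks future positions
assert torch.equal(scores + no_mask, scores)      # all-zeros mask is a no-op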