@@ -114,6 +114,7 @@ class ModelArgs:
     num_experts: int = 8  # Number of experts
     num_activated_experts: int = 2  # Number of experts to activate
     use_kv_cache: bool = False  # Use key/value cache
+    prefill_return_kv: bool = False  # Return kv cache for prefill
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place
     )
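Note: prefill_return_kv defaults to False, so existing configs keep the single-tensor return and callers must opt in explicitly. A minimal sketch of opting in (assuming ModelArgs accepts keyword arguments; the flag combination is an illustration chosen to satisfy the assert added further down, not taken from the diff):

# Hypothetical opt-in: return per-layer K/V from the prefill forward pass.
# use_kv_cache stays False because Attention.forward asserts it must be off
# whenever return_kv is requested.
args = ModelArgs(use_kv_cache=False, prefill_return_kv=True)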
@@ -420,7 +421,11 @@ def forward(
         freqs_cos: torch.Tensor,
         freqs_sin: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
+        return_kv: bool = False,
     ):
+        if return_kv:
+            assert self.use_kv_cache == False, "Can't return kv when use_kv_cache is True"
+
         bsz, seqlen, _ = x.shape

         # QKV
@@ -442,6 +447,10 @@ def forward(
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

+        if return_kv:
+            k_ret = k
+            v_ret = v
+
         # grouped multiquery attention: expand out keys and values
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
@@ -456,6 +465,8 @@ def forward(

         output = self.wo(output)

+        if return_kv:
+            return output, k_ret, v_ret
         return output


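k_ret / v_ret are captured after RoPE and the transpose(1, 2) but before repeat_interleave, so they keep the un-expanded number of KV heads. A minimal shape sketch (sizes are made up for illustration; only torch is needed):

import torch

# Hypothetical sizes: batch 1, 16 prompt tokens, 8 KV heads of dim 64,
# each shared by 4 query heads (n_rep = 4).
bsz, seqlen, n_kv_heads, head_dim, n_rep = 1, 16, 8, 64, 4

# Same layout as the point where k_ret is taken:
# (bsz, seqlen, n_kv_heads, head_dim) -> transpose(1, 2) -> (bsz, n_kv_heads, seqlen, head_dim)
k = torch.randn(bsz, seqlen, n_kv_heads, head_dim).transpose(1, 2)
k_ret = k                               # returned as-is, with n_kv_heads heads
k = k.repeat_interleave(n_rep, dim=1)   # expanded copy used for grouped multiquery attention

print(k_ret.shape)  # torch.Size([1, 8, 16, 64])
print(k.shape)      # torch.Size([1, 32, 16, 64])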
@@ -533,16 +544,24 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
-        h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos
-        )
+    def forward(self, x, freqs_cos, freqs_sin, input_pos=None, return_kv=False):  # x: 1xN
+        if not return_kv:
+            h = self.attention.forward(
+                self.attention_norm(x), freqs_cos, freqs_sin, input_pos, return_kv=False,
+            )
+        else:
+            h, k, v = self.attention.forward(
+                self.attention_norm(x), freqs_cos, freqs_sin, input_pos, return_kv=True,
+            )

         h = x + h
         if hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
             out = h + self.feed_forward(self.ffn_norm(h))
+
+        if return_kv:
+            return out, k, v
         return out


@@ -565,6 +584,7 @@ def __init__(self, params: ModelArgs):
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
+        self.prefill_return_kv = params.prefill_return_kv

     def forward(
         self,
@@ -583,13 +603,30 @@ def forward(
         seqlen = h.shape[1]
         freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)

-        for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+        if not self.prefill_return_kv:
+            for layer in self.layers:
+                h = layer(
+                    h,
+                    freqs_cos,
+                    freqs_sin,
+                    input_pos,
+                    return_kv=False,
+                )
+        else:
+            k_caches = []
+            v_caches = []
+            for layer in self.layers:
+                h, k, v = layer(
+                    h,
+                    freqs_cos,
+                    freqs_sin,
+                    input_pos,
+                    return_kv=True,
+                )
+                k_caches.append(k)
+                v_caches.append(v)
+            k_ret = torch.stack(k_caches, dim=0)
+            v_ret = torch.stack(v_caches, dim=0)

         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
@@ -621,4 +658,6 @@ def forward(
                 expanded_logits[:, list(self.output_prune_map.values())] = logits
             logits = expanded_logits

+        if self.prefill_return_kv:
+            return logits, k_ret, v_ret
         return logits
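Taken together, a model built with prefill_return_kv=True returns a triple from forward, with the per-layer K/V stacked along a new leading dimension by torch.stack(..., dim=0). A hedged usage sketch (the model construction, token ids, and variable names are illustrative; only the return structure and stacking come from the diff):

import torch

# Assumes `model` is a Transformer from this file, built with
# ModelArgs(prefill_return_kv=True, use_kv_cache=False); construction not shown.
tokens = torch.tensor([[1, 15043, 29892, 3186]])  # (bsz=1, seqlen=4), made-up ids

logits, k_caches, v_caches = model(tokens)  # other forward args left at their defaults

# k_caches / v_caches: (n_layers, bsz, n_kv_heads, seqlen, head_dim)
# logits keeps its usual shape (last position only unless generate_full_logits is set)
print(logits.shape, k_caches.shape, v_caches.shape)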