@@ -114,6 +114,7 @@ class ModelArgs:
     num_experts: int = 8  # Number of experts
     num_activated_experts: int = 2  # Number of experts to activate
     use_kv_cache: bool = False  # Use key/value cache
+    prefill_return_kv: bool = False  # Return kv cache for prefill
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place
     )
@@ -420,7 +421,11 @@ def forward(
         freqs_cos: torch.Tensor,
         freqs_sin: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
+        return_kv: bool = False,
     ):
+        if return_kv:
+            assert self.use_kv_cache == False, "Can't return kv when use_kv_cache is True"
+
         bsz, seqlen, _ = x.shape

         # QKV
@@ -442,6 +447,10 @@ def forward(
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

+        if return_kv:
+            k_ret = k
+            v_ret = v
+
         # grouped multiquery attention: expand out keys and values
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
@@ -456,6 +465,8 @@ def forward(

         output = self.wo(output)

+        if return_kv:
+            return output, k_ret, v_ret
         return output

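Since `k_ret` and `v_ret` are captured after the `transpose(1, 2)` but before the `repeat_interleave` expansion, each layer hands back the un-expanded key/value heads. A minimal shape sketch of that invariant, assuming the usual `(bsz, heads, seqlen, head_dim)` layout after the transpose; the concrete dimensions below are illustrative, not values from this patch:

```python
import torch

# Illustrative dimensions only; the real values come from ModelArgs.
bsz, seqlen, n_kv_heads, head_dim, n_rep = 1, 16, 8, 64, 4

k_ret = torch.zeros(bsz, n_kv_heads, seqlen, head_dim)  # as returned, pre-expansion
k = k_ret.repeat_interleave(n_rep, dim=1)               # what the attention math then uses
assert k.shape == (bsz, n_kv_heads * n_rep, seqlen, head_dim)
```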
@@ -533,16 +544,24 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
-        h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos
-        )
+    def forward(self, x, freqs_cos, freqs_sin, input_pos=None, return_kv=False):  # x: 1xN
+        if not return_kv:
+            h = self.attention.forward(
+                self.attention_norm(x), freqs_cos, freqs_sin, input_pos, return_kv=False,
+            )
+        else:
+            h, k, v = self.attention.forward(
+                self.attention_norm(x), freqs_cos, freqs_sin, input_pos, return_kv=True,
+            )

         h = x + h
         if hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
             out = h + self.feed_forward(self.ffn_norm(h))
+
+        if return_kv:
+            return out, k, v
         return out

@@ -565,6 +584,7 @@ def __init__(self, params: ModelArgs):
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
+        self.prefill_return_kv = params.prefill_return_kv

     def forward(
         self,
@@ -583,13 +603,30 @@ def forward(
         seqlen = h.shape[1]
         freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)

-        for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+        if not self.prefill_return_kv:
+            for layer in self.layers:
+                h = layer(
+                    h,
+                    freqs_cos,
+                    freqs_sin,
+                    input_pos,
+                    return_kv=False,
+                )
+        else:
+            k_caches = []
+            v_caches = []
+            for layer in self.layers:
+                h, k, v = layer(
+                    h,
+                    freqs_cos,
+                    freqs_sin,
+                    input_pos,
+                    return_kv=True,
+                )
+                k_caches.append(k)
+                v_caches.append(v)
+            k_ret = torch.stack(k_caches, dim=0)
+            v_ret = torch.stack(v_caches, dim=0)

         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
@@ -621,4 +658,6 @@ def forward(
             expanded_logits[:, list(self.output_prune_map.values())] = logits
             logits = expanded_logits

+        if self.prefill_return_kv:
+            return logits, k_ret, v_ret
         return logits
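For context, a sketch of how the new flag could be exercised end to end. The setup below is an assumption for illustration: it presumes `ModelArgs` and `Transformer` are imported from the patched module, that `ModelArgs` also defines `vocab_size`, and that the first positional argument of `Transformer.forward` is the token tensor; none of that is shown in this diff.

```python
import torch

# Hypothetical setup; field names other than the two kv-cache flags are assumptions.
params = ModelArgs(
    use_kv_cache=False,      # return_kv asserts the in-place kv cache is disabled
    prefill_return_kv=True,  # new flag added by this patch
)
model = Transformer(params)

tokens = torch.randint(0, params.vocab_size, (1, 16))  # (bsz, seqlen) prompt
logits, k_ret, v_ret = model(tokens)

# k_ret / v_ret stack the per-layer tensors along dim 0:
# (n_layers, bsz, n_kv_heads, seqlen, head_dim)
```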