@@ -115,6 +115,8 @@ class ModelArgs:
     num_activated_experts: int = 2  # Number of experts to activate
     use_kv_cache: bool = False  # Use key/value cache
     prefill_return_kv: bool = False  # Return kv cache for prefill
+    decode_kv_cache_as_io: bool = False  # During decode, pass KV caches in and out as model inputs/outputs
+    use_additive_kv_cache_update: bool = False  # Update the KV cache with a masked add instead of an in-place write
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place
     )
@@ -367,6 +369,9 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
+        self.decode_kv_cache_as_io = args.decode_kv_cache_as_io
+        self.use_additive_kv_cache_update = args.use_additive_kv_cache_update
+        self.return_kv_values = (args.prefill_return_kv or args.decode_kv_cache_as_io)
         self.n_heads = args.n_heads
         self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
         assert self.n_heads % self.n_kv_heads == 0
@@ -397,7 +402,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         )
         self.register_buffer("mask", causal_mask, persistent=False)

-        if self.use_kv_cache:
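+        # When KV caches are passed in as IO, skip registering an internal KVCache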
+        if self.use_kv_cache and not self.decode_kv_cache_as_io:
             self.kv_cache = KVCache(
                 args.max_batch_size,
                 args.max_seq_len,
@@ -421,10 +426,19 @@ def forward(
         freqs_cos: torch.Tensor,
         freqs_sin: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
-        return_kv: bool = False,
+        k_cache: Optional[torch.Tensor] = None,
+        v_cache: Optional[torch.Tensor] = None,
+        cache_pos_mask: Optional[torch.Tensor] = None,
     ):
-        if return_kv:
-            assert self.use_kv_cache == False, "Can't return kv when use_kv_cache is True"
+        if self.decode_kv_cache_as_io:
+            assert self.use_kv_cache
+            assert k_cache is not None
+            assert v_cache is not None
+            assert self.return_kv_values
+
+        if self.use_additive_kv_cache_update:
+            assert self.decode_kv_cache_as_io
+            assert cache_pos_mask is not None

         bsz, seqlen, _ = x.shape

@@ -438,34 +452,53 @@ def forward(
         # RoPE relative positional embeddings
         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

-        if self.use_kv_cache:
+        if self.use_kv_cache and not self.decode_kv_cache_as_io:
             assert input_pos is not None
+            assert not self.return_kv_values
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)

         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

-        if return_kv:
+        if self.return_kv_values:
             k_ret = k
             v_ret = v
-
-        # grouped multiquery attention: expand out keys and values
-        k = k.repeat_interleave(self.n_rep, dim=1)
-        v = v.repeat_interleave(self.n_rep, dim=1)
-
+
         assert hasattr(self, "mask")
+        if self.decode_kv_cache_as_io:
+            assert self.use_kv_cache
+            mask = self.mask[None, None, input_pos]
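+            # Additive update path: cache_pos_mask is expected to be a 0/1 mask that is
+            # one-hot over the cache's sequence dimension, so k_cache + cache_pos_mask * k
+            # writes the new entry at the current position (assuming that slot is still zero)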
+            if self.use_additive_kv_cache_update:
+                assert cache_pos_mask is not None
+                assert seqlen == 1
+                k_update = cache_pos_mask * k
+                v_update = cache_pos_mask * v
+                k = k_cache + k_update
+                v = v_cache + v_update
+                assert k.shape == k_cache.shape
+                assert v.shape == v_cache.shape
+            else:
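+                # Scatter path: write k/v into the IO caches at input_pos with the
+                # out-of-place index_put, so the updated caches can be returned as outputs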
+                k = torch.ops.aten.index_put(k_cache, [None, None, input_pos, None], k)
+                v = torch.ops.aten.index_put(v_cache, [None, None, input_pos, None], v)
+        else:
+            assert not self.use_kv_cache
+            mask = self.mask[:seqlen, :seqlen]
+

-        mask = self.mask[:seqlen, :seqlen]
+        # grouped multiquery attention: expand out keys and values
+        if self.n_rep > 1:
+            k = k.repeat_interleave(self.n_rep, dim=1)
+            v = v.repeat_interleave(self.n_rep, dim=1)

         output = torch.ops.coreml.sdpa(q, k, v, mask)

         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

         output = self.wo(output)

-        if return_kv:
+        if self.return_kv_values:
             return output, k_ret, v_ret
         return output

@@ -533,6 +566,8 @@ class TransformerBlock(nn.Module):
     def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
+        self.decode_kv_cache_as_io = args.decode_kv_cache_as_io
+        self.return_kv_values = (args.prefill_return_kv or args.decode_kv_cache_as_io)
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
@@ -544,14 +579,19 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None, return_kv=False):  # x: 1xN
-        if not return_kv:
+    def forward(self, x, freqs_cos, freqs_sin, input_pos=None, k_cache=None, v_cache=None, cache_pos_mask=None):  # x: 1xN
+        if self.decode_kv_cache_as_io:
+            assert self.use_kv_cache
+            assert k_cache is not None
+            assert v_cache is not None
+
+        if not self.return_kv_values:
             h = self.attention.forward(
-                self.attention_norm(x), freqs_cos, freqs_sin, input_pos, return_kv=False,
+                self.attention_norm(x), freqs_cos, freqs_sin, input_pos, k_cache, v_cache, cache_pos_mask,
             )
         else:
             h, k, v = self.attention.forward(
-                self.attention_norm(x), freqs_cos, freqs_sin, input_pos, return_kv=True,
+                self.attention_norm(x), freqs_cos, freqs_sin, input_pos, k_cache, v_cache, cache_pos_mask,
             )

         h = x + h
@@ -560,7 +600,7 @@ def forward(self, x, freqs_cos, freqs_sin, input_pos=None, return_kv=False): #
         else:
             out = h + self.feed_forward(self.ffn_norm(h))

-        if return_kv:
+        if self.return_kv_values:
             return out, k, v
         return out

@@ -580,49 +620,71 @@ def __init__(self, params: ModelArgs):
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
+        self.decode_kv_cache_as_io = params.decode_kv_cache_as_io
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
-        self.prefill_return_kv = params.prefill_return_kv
+
+        # Whether the model returns newly computed KV values
+        self.return_kv_values = (params.prefill_return_kv or params.decode_kv_cache_as_io)

     def forward(
         self,
         tokens: Optional[torch.LongTensor] = None,  # tokens
         input_pos: Optional[
             torch.LongTensor
         ] = None,  # Scalar tensor indicating size of window of the caches
-        h: Optional[torch.FloatTensor] = None,  # embeddings
+        k_cache: Optional[torch.FloatTensor] = None,
+        v_cache: Optional[torch.FloatTensor] = None,
+        cache_pos_mask: Optional[torch.FloatTensor] = None,
     ) -> torch.Tensor:
-        if (tokens is None) ^ (h is not None):
-            raise ValueError(
-                "You cannot specify both tokens and h at the same time, and must specify either one"
-            )
-        if tokens is not None and h is None:
-            h = self.tok_embeddings(tokens)
+        h = self.tok_embeddings(tokens)
+        if self.decode_kv_cache_as_io:
+            assert self.use_kv_cache
+            assert k_cache is not None
+            assert v_cache is not None
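+            # When passed as IO, k_cache/v_cache are expected to hold all layers stacked
+            # along dim 0, e.g. (n_layers, bsz, n_kv_heads, cache_len, head_dim)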
+
         seqlen = h.shape[1]
         freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)

-        if not self.prefill_return_kv:
+        if not self.return_kv_values:
             for layer in self.layers:
                 h = layer(
                     h,
                     freqs_cos,
                     freqs_sin,
                     input_pos,
-                    return_kv=False,
+                    k_cache,
+                    v_cache,
+                    cache_pos_mask,
                 )
         else:
             k_caches = []
             v_caches = []
-            for layer in self.layers:
-                h, k, v = layer(
-                    h,
-                    freqs_cos,
-                    freqs_sin,
-                    input_pos,
-                    return_kv=True,
-                )
+            for i, layer in enumerate(self.layers):
+                if not self.decode_kv_cache_as_io:
+                    h, k, v = layer(
+                        h,
+                        freqs_cos,
+                        freqs_sin,
+                        input_pos,
+                        k_cache,
+                        v_cache,
+                        cache_pos_mask,
+                    )
+                else:
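+                    # Slice out this layer's cache; caches are stacked per layer along dim 0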
+                    h, k, v = layer(
+                        h,
+                        freqs_cos,
+                        freqs_sin,
+                        input_pos,
+                        k_cache[i, :, :, :, :],
+                        v_cache[i, :, :, :, :],
+                        cache_pos_mask,
+                    )
                 k_caches.append(k)
                 v_caches.append(v)
             k_ret = torch.stack(k_caches, dim=0)
@@ -658,6 +720,6 @@ def forward(
             expanded_logits[:, list(self.output_prune_map.values())] = logits
             logits = expanded_logits

-        if self.prefill_return_kv:
+        if self.return_kv_values:
             return logits, k_ret, v_ret
         return logits