@@ -272,7 +272,11 @@ def __init__(
         )
 
         self.norm_1 = nn.Identity() if not config.norm_1 else config.norm_class(config.n_embd, eps=config.norm_eps)
-        self.attn = CausalSelfAttention(config, block_idx)
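+        # Use multi-head latent attention (MLA) instead of standard causal self-attention
+        # when `config.latent_attention` is set.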
+        self.attn = (
+            CausalSelfAttention(config, block_idx)
+            if not config.latent_attention
+            else MultiheadLatentAttention(config, block_idx)
+        )
         self.post_attention_norm = (
             config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity()
         )
@@ -549,6 +553,146 @@ def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwa
         super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
 
 
+class MultiheadLatentAttention(nn.Module):
+    def __init__(self, config: Config, block_idx: int) -> None:
+        super().__init__()
+
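+        # Low-rank query path: compress the hidden state to q_lora_rank, normalize, then expand
+        # to n_head * qk_head_dim, where qk_head_dim = qk_nope_head_dim + qk_rope_head_dim.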
+        self.q_a_proj = nn.Linear(config.n_embd, config.q_lora_rank, bias=config.attn_bias)
+        self.q_a_norm = RMSNorm(config.q_lora_rank, eps=config.norm_eps)
+        self.q_b_proj = nn.Linear(config.q_lora_rank, config.n_head * config.qk_head_dim, bias=config.bias)
+
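+        # Keys and values share one compressed latent of size kv_lora_rank; the additional
+        # qk_rope_head_dim slice is a decoupled key that carries the rotary position signal.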
+        self.kv_a_proj_with_mqa = nn.Linear(
+            config.n_embd, config.kv_lora_rank + config.qk_rope_head_dim, bias=config.attn_bias
+        )
+        self.kv_a_norm = RMSNorm(config.kv_lora_rank, eps=config.norm_eps)
+        self.kv_b_proj = nn.Linear(
+            config.kv_lora_rank,
+            config.n_query_groups * (config.qk_nope_head_dim + config.v_head_dim),
+            bias=config.bias,
+        )
+
+        # output projection
+        self.proj = nn.Linear(config.n_head * config.v_head_dim, config.n_embd, bias=config.bias)
+        # disabled by default
+        self.kv_cache: Optional[KVCache] = None
+
+        self.config = config
+        self.block_idx = block_idx
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        input_pos: Optional[torch.Tensor] = None,
+        input_pos_maxp1: Optional[int] = None,
+    ) -> torch.Tensor:
+        # Notation:
+        # - B          | batch size
+        # - T          | time-step (sequence length)
+        # - C          | model's embedding size (n_embd)
+        # - C*         | attention's embedding size
+        # - hs         | head size
+        # - nh_(q,k,v) | number of heads for query, key and value
+        # - n_query_groups = nh_k = nh_v | number of query groups sharing key and value heads
+        #   alternative notation: num_kv_groups = n_query_groups
+        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+
+        q = self.q_b_proj(self.q_a_norm(self.q_a_proj(x)))  # (B, T, n_head * qk_head_dim)
+        q = q.view(B, T, -1, self.config.qk_head_dim)  # (B, T, n_head, qk_head_dim)
+        q = q.transpose(1, 2)  # (B, n_head, T, qk_head_dim)
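+        # Split each query head into a content part (qk_nope_head_dim, no positional encoding)
+        # and a rotary part (qk_rope_head_dim) that receives RoPE below.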
+        q_pass, q_rot = torch.split(q, [self.config.qk_nope_head_dim, self.config.qk_rope_head_dim], dim=-1)
+
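+        # The joint projection yields the compressed KV latent plus a single RoPE key slice
+        # shared across all heads (multi-query style).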
+        compressed_kv = self.kv_a_proj_with_mqa(x)  # (B, T, kv_lora_rank + qk_rope_head_dim)
+        k_pass, k_rot = torch.split(compressed_kv, [self.config.kv_lora_rank, self.config.qk_rope_head_dim], dim=-1)
+
+        k_pass = self.kv_b_proj(self.kv_a_norm(k_pass))
+        k_pass = k_pass.view(B, T, self.config.n_query_groups, -1)
+        k_pass = k_pass.transpose(1, 2)
+
+        k_pass, v = torch.split(k_pass, [self.config.qk_nope_head_dim, self.config.v_head_dim], dim=-1)
+        k_rot = k_rot.view(B, 1, T, self.config.qk_rope_head_dim)  # (B, 1, T, qk_rope_head_dim)
+
+        # Unlike standard positional embeddings, rotary embeddings must be applied at every layer.
+        q_roped = apply_rope(q_rot, cos, sin)
+        k_roped = apply_rope(k_rot, cos, sin)
+        k_roped = k_roped.expand(*k_pass.shape[:-1], -1)  # (B, n_head, T, qk_rope_head_dim)
+
+        q = torch.cat((q_pass, q_roped), dim=-1)
+        k = torch.cat((k_pass, k_roped), dim=-1)
+
+        # Apply kv-cache during inference.
+        if input_pos is not None:
+            if not isinstance(self.kv_cache, KVCache):
+                raise TypeError("You need to call `gpt.set_kv_cache()`")
+            k, v = self.kv_cache(input_pos, k, v)
+            if input_pos_maxp1 is not None:
+                # Subselect along sequence dimension
+                k = k[..., :input_pos_maxp1, :]
+                v = v[..., :input_pos_maxp1, :]
+            # k, v: (B, nh_k, input_pos_maxp1, hs)
+            # If input_pos_maxp1 is None -> max_seq_length
+
+        # Grouped queries: balance the number of heads across all three matrices.
+        # NOTE: flash attention requires it in training mode.
+        # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting.
+        if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):
+            q_per_kv = self.config.n_head // self.config.n_query_groups
+            k = k.repeat_interleave(q_per_kv, dim=1)  # (B, nh_q, T, hs)
+            v = v.repeat_interleave(q_per_kv, dim=1)  # (B, nh_q, T, hs)
+
+        # Efficient attention using Flash Attention CUDA kernels.
+        # NOTE: the efficient implementation is disabled if `mask` is not None or softcapping is enabled.
+        # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
+        y = self.scaled_dot_product_attention(q, k, v, mask)
+
+        # Re-assemble all head outputs side by side.
+        y = y.reshape(B, T, self.config.n_head * self.config.v_head_dim)
+
+        # Output projection.
+        return self.proj(y)  # (B, T, C)
+
+    def scaled_dot_product_attention(
+        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.qk_head_dim)
+
+        # with softcapping we cannot use SDPA
+        if self.config.attention_logit_softcapping is not None:
+            scores = q @ k.mT * scale
+            scores = do_softcapping(scores, self.config.attention_logit_softcapping)
+            if mask is None:
+                mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1)
+                mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min)
+            scores = scores + mask
+            scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype)
+            y = scores @ v
+        else:
+            y = F.scaled_dot_product_attention(
+                q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
+            )
+        return y.transpose(1, 2)
+
+    def build_kv_cache(
+        self,
+        batch_size: int,
+        max_seq_length: int,
+        rope_cache_length: Optional[int] = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> "KVCache":
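+        # The cache stores the decompressed per-head keys (qk_head_dim) and values (v_head_dim),
+        # not the compressed kv_lora_rank latent.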
+        v_shape = (batch_size, self.config.n_head, max_seq_length, self.config.v_head_dim)
+        k_shape = (batch_size, self.config.n_head, max_seq_length, self.config.qk_head_dim)
+
+        if rope_cache_length is not None:
+            print("Warning: `rope_cache_length` has no effect on MultiheadLatentAttention!")
+        if self.config.rotary_percentage != 1.0:
+            print("Warning: `rotary_percentage` has no effect on MultiheadLatentAttention!")
+
+        return KVCache(k_shape, v_shape, device=device, dtype=dtype)
+
+
 class GptNeoxMLP(nn.Module):
     def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
         super().__init__()