@@ -140,6 +140,14 @@ def _bind_attention(self):
             Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
         )
 
+    def _pre_cache_kv(
+        self, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+    ) -> torch.Tensor:
+        # When q_lora_rank is not None, q_a_proj and kv_a_proj_with_mqa are fused, so no cache_kv buffer is pre-allocated here.
+        if self.q_lora_rank is None:
+            return super()._pre_cache_kv(infer_state, layer_weight)
+        return None
+
     def _get_qkv(
         self,
         input: torch.Tensor,
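
A minimal sketch of the fusion this hunk relies on (the `w_q_a` / `w_kv_a` names and shapes below are illustrative stand-ins, not the actual lightllm weight objects): when `q_lora_rank` is set, the fused `qkv_a_proj_with_mqa` weight is assumed to be the column-wise concatenation of `q_a_proj` and `kv_a_proj_with_mqa`, so a single GEMM already produces the kv-cache slice and nothing has to be pre-allocated.

```python
import torch

# Illustrative shapes only; real values come from the DeepSeek-V2 config.
hidden, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 512, 128, 64, 32

w_q_a = torch.randn(hidden, q_lora_rank)                       # stand-in for q_a_proj
w_kv_a = torch.randn(hidden, kv_lora_rank + qk_rope_head_dim)  # stand-in for kv_a_proj_with_mqa
w_fused = torch.cat([w_q_a, w_kv_a], dim=1)                    # assumed layout of qkv_a_proj_with_mqa

x = torch.randn(4, hidden)       # 4 tokens of hidden states
fused_out = x @ w_fused          # one GEMM instead of two
assert fused_out.shape[-1] == q_lora_rank + kv_lora_rank + qk_rope_head_dim
```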
@@ -151,13 +159,16 @@ def _get_qkv(
 
         if self.q_lora_rank is None:
             q = layer_weight.q_weight_.mm(input)
+            layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
         else:
-            q = layer_weight.q_a_proj_.mm(input)
-            rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_, out=q)
+            q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
+            q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_)
             q = layer_weight.q_b_proj_.mm(q)
+            cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim)
         q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
         q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
         rmsnorm_forward(
             cache_kv[:, :, : self.kv_lora_rank],
             weight=layer_weight.kv_a_layernorm_.weight,
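
To see why splitting the fused output reproduces the old two-projection path, here is a hedged equivalence check in plain torch (illustrative names again; it assumes the fused weight is the concatenation shown above):

```python
import torch

hidden, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 512, 128, 64, 32
x = torch.randn(4, hidden)
w_q_a = torch.randn(hidden, q_lora_rank)
w_kv_a = torch.randn(hidden, kv_lora_rank + qk_rope_head_dim)
w_fused = torch.cat([w_q_a, w_kv_a], dim=1)

# Old path: two separate GEMMs.
q_ref, kv_ref = x @ w_q_a, x @ w_kv_a
# New path: one fused GEMM, then split along the last dim.
q, cache_kv = (x @ w_fused).split([q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1)

assert torch.allclose(q, q_ref, atol=1e-5) and torch.allclose(cache_kv, kv_ref, atol=1e-5)
```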
@@ -185,16 +196,18 @@ def _tpsp_get_qkv(
         input = gather_input[0 : len(infer_state.position_cos), :]
 
         input = input.view(-1, self.embed_dim_)
-
         if self.q_lora_rank is None:
             q = layer_weight.q_weight_.mm(input)
+            layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
         else:
-            q = layer_weight.q_a_proj_.mm(input)
-            rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_, out=q)
+            q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
+            q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_)
             q = layer_weight.q_b_proj_.mm(q)
+            cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim)
         q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
         q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
         rmsnorm_forward(
             cache_kv[:, :, : self.kv_lora_rank],
             weight=layer_weight.kv_a_layernorm_.weight,
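
Finally, note that both `_get_qkv` and `_tpsp_get_qkv` normalize only the first `kv_lora_rank` channels of `cache_kv`; the trailing `qk_rope_head_dim` channels are left untouched for rotary embedding. A small sketch of that slicing (the `rmsnorm` below is a generic reference implementation, not lightllm's `rmsnorm_forward`):

```python
import torch

def rmsnorm(x, weight, eps=1e-6):
    # Reference RMSNorm over the last dimension.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

kv_lora_rank, qk_rope_head_dim = 64, 32
cache_kv = torch.randn(4, 1, kv_lora_rank + qk_rope_head_dim)  # [tokens, 1, latent + rope]
weight = torch.ones(kv_lora_rank)

rope_part_before = cache_kv[:, :, kv_lora_rank:].clone()
cache_kv[:, :, :kv_lora_rank] = rmsnorm(cache_kv[:, :, :kv_lora_rank], weight)
assert torch.equal(cache_kv[:, :, kv_lora_rank:], rope_part_before)  # rope slice stays raw
```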