Commit fa4456b

merge q_a_proj and kv_a_proj to reduce the kernel launch overhead (#1055)
1 parent 6c32d4f commit fa4456b


2 files changed (+35, -20 lines)
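Note: the change is easiest to see outside of lightllm's weight classes. When two projections consume the same input, their weight matrices can be concatenated once at load time so that a single GEMM (one kernel launch) replaces two. A minimal standalone PyTorch sketch of that idea follows; the sizes are made up and none of the names below come from the repository.

    import torch

    # Made-up sizes standing in for hidden_size, q_lora_rank and
    # kv_lora_rank + qk_rope_head_dim; the real values come from the model config.
    hidden, q_lora_rank, kv_dim = 1024, 512, 576
    x = torch.randn(8, hidden)

    w_q_a = torch.randn(q_lora_rank, hidden)  # stands in for q_a_proj.weight
    w_kv_a = torch.randn(kv_dim, hidden)      # stands in for kv_a_proj_with_mqa.weight

    # Before: two separate GEMMs on the same input -> two kernel launches per layer.
    q_ref = x @ w_q_a.t()
    kv_ref = x @ w_kv_a.t()

    # After: concatenate the weights along the output dim once, run one GEMM,
    # then split the result along the last dimension.
    w_fused = torch.cat([w_q_a, w_kv_a], dim=0)  # (q_lora_rank + kv_dim, hidden)
    q, kv = (x @ w_fused.t()).split([q_lora_rank, kv_dim], dim=-1)

    # Identical up to floating-point accumulation order.
    assert torch.allclose(q, q_ref, atol=1e-3)
    assert torch.allclose(kv, kv_ref, atol=1e-3)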

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 20 additions & 7 deletions
@@ -140,6 +140,14 @@ def _bind_attention(self):
             Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
         )

+    def _pre_cache_kv(
+        self, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+    ) -> torch.Tensor:
+        # When q_lora_rank is not None, q_a_proj and kv_a_proj_with_mqa are fused
+        if self.q_lora_rank is None:
+            return super()._pre_cache_kv(infer_state, layer_weight)
+        return None
+
     def _get_qkv(
         self,
         input: torch.Tensor,
@@ -151,13 +159,16 @@ def _get_qkv(

         if self.q_lora_rank is None:
             q = layer_weight.q_weight_.mm(input)
+            layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
         else:
-            q = layer_weight.q_a_proj_.mm(input)
-            rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_, out=q)
+            q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
+            q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_)
             q = layer_weight.q_b_proj_.mm(q)
+            cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim)
         q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
         q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
         rmsnorm_forward(
             cache_kv[:, :, : self.kv_lora_rank],
             weight=layer_weight.kv_a_layernorm_.weight,
@@ -185,16 +196,18 @@ def _tpsp_get_qkv(
         input = gather_input[0 : len(infer_state.position_cos), :]

         input = input.view(-1, self.embed_dim_)
-
         if self.q_lora_rank is None:
             q = layer_weight.q_weight_.mm(input)
+            layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
         else:
-            q = layer_weight.q_a_proj_.mm(input)
-            rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_, out=q)
+            q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
+            q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_)
             q = layer_weight.q_b_proj_.mm(q)
+            cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim)
         q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
         q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))
         rmsnorm_forward(
             cache_kv[:, :, : self.kv_lora_rank],
             weight=layer_weight.kv_a_layernorm_.weight,
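For reference, the shape bookkeeping in the new else-branch of _get_qkv / _tpsp_get_qkv looks like this when written with plain tensors. The sizes follow DeepSeek-V2's published config (q_lora_rank=1536, kv_lora_rank=512, qk_rope_head_dim=64) but are only illustrative here, and torch.nn.functional.rms_norm (PyTorch >= 2.4) stands in for lightllm's rmsnorm_forward kernel.

    import torch
    import torch.nn.functional as F

    q_lora_rank, kv_lora_rank, qk_rope_head_dim = 1536, 512, 64  # illustrative config values
    hidden, tokens = 5120, 4

    x = torch.randn(tokens, hidden)
    w_qkv_a = torch.randn(q_lora_rank + kv_lora_rank + qk_rope_head_dim, hidden)  # fused weight

    # One GEMM over the fused weight, then the same split the diff performs.
    q, cache_kv = (x @ w_qkv_a.t()).split(
        [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1
    )
    q = F.rms_norm(q, (q_lora_rank,), eps=1e-6)  # stand-in for rmsnorm_forward
    cache_kv = cache_kv.view(-1, 1, kv_lora_rank + qk_rope_head_dim)

    print(q.shape)         # torch.Size([4, 1536])
    print(cache_kv.shape)  # torch.Size([4, 1, 576])

This is presumably also why the new _pre_cache_kv override returns None on the fused path: cache_kv comes out of the split rather than being written into a pre-allocated buffer.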

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 15 additions & 13 deletions
@@ -148,13 +148,25 @@ def _init_qkvo(self):
                 layer_num=self.layer_num_,
                 name="q_weight",
             )
+            self.kv_a_proj_with_mqa_ = ROWMMWeight(
+                weight_name=f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight",
+                data_type=self.data_type_,
+                quant_cfg=self.quant_cfg,
+                layer_num=self.layer_num_,
+                name="kv_a_proj_with_mqa",
+                tp_rank=0,
+                tp_world_size=1,
+            )
         else:
-            self.q_a_proj_ = ROWMMWeight(
-                weight_name=f"model.layers.{self.layer_num_}.self_attn.q_a_proj.weight",
+            self.qkv_a_proj_with_mqa_ = MultiROWMMWeight(
+                weight_names=[
+                    f"model.layers.{self.layer_num_}.self_attn.q_a_proj.weight",
+                    f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight",
+                ],
                 data_type=self.data_type_,
                 quant_cfg=self.quant_cfg,
                 layer_num=self.layer_num_,
-                name="q_a_proj",
+                name="qkv_a_proj_with_mqa",
                 tp_rank=0,
                 tp_world_size=1,
             )
@@ -165,16 +177,6 @@ def _init_qkvo(self):
                 layer_num=self.layer_num_,
                 name="q_b_proj",
             )
-
-        self.kv_a_proj_with_mqa_ = ROWMMWeight(
-            weight_name=f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight",
-            data_type=self.data_type_,
-            quant_cfg=self.quant_cfg,
-            layer_num=self.layer_num_,
-            name="kv_a_proj_with_mqa",
-            tp_rank=0,
-            tp_world_size=1,
-        )
         self.k_b_proj_ = ROWBMMWeight(
             weight_name=f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight",
             data_type=self.data_type_,
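On the weight side, MultiROWMMWeight is given both checkpoint tensor names so the fused projection can be materialized at load time; note that both the old and the new weight objects use tp_rank=0, tp_world_size=1, i.e. the low-rank "a" projections stay unsharded. The class internals are not part of this diff; conceptually, the load-time fusion amounts to something like the following sketch (hypothetical helper, random stand-in tensors; the real class presumably also handles quantization and weight layout, which this sketch ignores).

    import torch

    def fuse_projection_weights(state_dict, weight_names):
        # Hypothetical illustration: stack several projection weights along the
        # output dimension so they can be applied with a single GEMM at runtime.
        return torch.cat([state_dict[name] for name in weight_names], dim=0)

    # Random stand-ins for the two layer-0 checkpoint tensors named in the diff.
    sd = {
        "model.layers.0.self_attn.q_a_proj.weight": torch.randn(1536, 5120),
        "model.layers.0.self_attn.kv_a_proj_with_mqa.weight": torch.randn(576, 5120),
    }

    w_fused = fuse_projection_weights(
        sd,
        [
            "model.layers.0.self_attn.q_a_proj.weight",
            "model.layers.0.self_attn.kv_a_proj_with_mqa.weight",
        ],
    )
    print(w_fused.shape)  # torch.Size([2112, 5120])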
