
Commit e38c664

refine code
1 parent 7e11e11 commit e38c664

2 files changed: +8 / -24 lines


lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 6 additions & 20 deletions
@@ -185,14 +185,9 @@ def _ACC_method(
         num_local_heads //= self.world_size_
         num_local_kv_heads //= self.world_size_
         # ACC
-        q_nope_up_ = self.alloc_tensor([q_nope.shape[1], q_nope.shape[0], self.kv_lora_rank], dtype=q_nope.dtype)
-        q_nope_up_ = torch.bmm(  # TODO: convert to einsum or cuBLAS
-            q_nope.transpose(0, 1),  # (h, b*s, qk_n)
-            layer_weight.k_b_proj_.weight,  # (h, qk_n, kv_lora)
-            out=q_nope_up_.view(q_nope.shape[1], q_nope.shape[0], self.kv_lora_rank),
-        ).transpose(
-            0, 1
-        )  # (b*s, h, kv_lora)
+        q_nope = layer_weight.k_b_proj_.weight.bmm(
+            q_nope.transpose(0, 1),
+        ).transpose(0, 1)
         if self.enable_opt_decoding_mha:
             import lightllm_ppl_mla
 
@@ -213,19 +208,10 @@ def _ACC_method(
             output_parallel = o_tensor
         else:
             output_parallel = self._token_gqa_decode_attention_flashdecoding_origin(
-                (q_nope_up_, q_rope), infer_state, layer_weight
+                (q_nope, q_rope), infer_state, layer_weight
             )
-        o_tensor = self.alloc_tensor(
-            [output_parallel.shape[1], output_parallel.shape[0], self.qk_nope_head_dim], dtype=q_rope.dtype
-        )
-        o_tensor = torch.bmm(  # TODO: convert to einsum or cuBLAS
-            output_parallel.transpose(0, 1),  # (h, b*s, kv_lora)
-            layer_weight.v_b_proj_.weight,  # (h, kv_lora, vo_d)
-            out=o_tensor,
-        ).transpose(
-            0, 1
-        )  # (b*s, h, vo_d)
-        return o_tensor
+        vo = layer_weight.v_b_proj_.bmm(output_parallel.transpose(0, 1)).transpose(0, 1)
+        return vo
 
     def _context_attention_kernel(
         self, q, kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
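Both hunks above trade the hand-rolled torch.bmm calls (with buffers from self.alloc_tensor) for a bmm helper on the projection-weight objects; the per-head batched matmul itself is unchanged, the allocation and layout handling just move behind the weight class. Below is a minimal sketch of the pattern with a hypothetical BmmWeight wrapper; names and shapes are placeholders, not lightllm's actual weight classes.

import torch

class BmmWeight:
    """Hypothetical wrapper owning a (heads, in_dim, out_dim) projection tensor."""

    def __init__(self, weight: torch.Tensor):  # weight: (h, in_dim, out_dim)
        self.weight = weight

    def bmm(self, x: torch.Tensor) -> torch.Tensor:  # x: (h, tokens, in_dim)
        # Same math as the removed torch.bmm(..., out=...) block.
        return torch.bmm(x, self.weight)  # -> (h, tokens, out_dim)

h, tokens, qk_nope, kv_lora = 16, 8, 128, 512
k_b_proj = BmmWeight(torch.randn(h, qk_nope, kv_lora))
q_nope = torch.randn(tokens, h, qk_nope)

# Absorb q_nope into the kv_lora space: head-major in, token-major out,
# matching the (b*s, h, kv_lora) result of the old code.
q_nope_absorbed = k_b_proj.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
assert q_nope_absorbed.shape == (tokens, h, kv_lora)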

lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py

Lines changed: 2 additions & 4 deletions
@@ -97,9 +97,7 @@ def _fwd_kernel_with_v(
         other=0,
     )
     off_k = k_loc[None, :] * stride_k_bs + cur_k_head * stride_k_h + offs_d[:, None] * stride_k_d
-    off_k_rope = (
-        k_loc[None, :] * stride_k_rope_bs + cur_k_head * stride_k_rope_h + offs_rope_d[:, None] * stride_k_rope_d
-    )
+    off_k_rope = k_loc[None, :] * stride_k_rope_bs + offs_rope_d[:, None] * stride_k_rope_d
     k = tl.load(K_nope + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0)
     k_rope = tl.load(K_rope + off_k_rope, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0)
 
@@ -290,7 +288,7 @@ def _fwd_kernel_no_prompt_cache_with_v(
         + offs_rope_d[None, :] * stride_q_rope_d
     )
     off_k = offs_n[None, :] * stride_k_bs + cur_k_head * stride_k_h + offs_d[:, None] * stride_k_d
-    off_rope_k = offs_n[None, :] * stride_k_rope_bs + 0 * stride_k_rope_h + offs_rope_d[:, None] * stride_k_rope_d
+    off_rope_k = offs_n[None, :] * stride_k_rope_bs + offs_rope_d[:, None] * stride_k_rope_d
     off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_d[None, :] * stride_vd
 
     q = tl.load(Q_nope + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
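Both kernels drop the head-stride term from the rope-key offset; presumably the rope K cache has a single shared head in this MLA layout (the second kernel already multiplied the stride by 0), so the term never contributed to the address. A small sketch of that reasoning under the assumed (num_tokens, 1, rope_head_dim) layout; the tensor and variable names here are illustrative, not the kernel's.

import torch

num_tokens, rope_head_dim = 32, 64
k_rope_cache = torch.randn(num_tokens, 1, rope_head_dim)  # single rope-K head
stride_bs, stride_h, stride_d = k_rope_cache.stride()

tok, head, d = 7, 0, 5  # the only valid head index is 0
full = tok * stride_bs + head * stride_h + d * stride_d
reduced = tok * stride_bs + d * stride_d
assert full == reduced  # head == 0, so the stride_h term is dead code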
