
Commit e4ec0c0

remove unused att op
1 parent 8d334ea commit e4ec0c0

4 files changed: +3 additions, -414 deletions

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 27 deletions
@@ -10,7 +10,6 @@
     context_attention_fwd_no_prompt_cache,
 )
 
-from lightllm.models.deepseek2.triton_kernel.flash_decoding import token_decode_attention_flash_decoding
 from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
 from lightllm.models.deepseek2.layer_infer.fused_moe import fused_experts, grouped_topk
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
@@ -60,14 +59,9 @@ def __init__(
 
     def _bind_attention(self):
         self._context_attention_kernel = partial(Deepseek2TransformerLayerInfer._context_attention_kernel, self)
-        if "triton_flashdecoding" in self.mode:
-            self._token_attention_kernel = partial(
-                Deepseek2TransformerLayerInfer._token_decode_attention_flashdecoding, self
-            )
-        else:
-            self._token_attention_kernel = partial(
-                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
-            )
+        self._token_attention_kernel = partial(
+            Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
+        )
         self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
         if self.is_moe:
             self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn, self)
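
The net effect of the hunk above is that _bind_attention no longer checks for a "triton_flashdecoding" mode and always binds the GQA flash-decoding kernel. A minimal, self-contained sketch of the partial-binding pattern involved is shown below; the class and the kernel body are illustrative stand-ins, not the real lightllm classes.

from functools import partial


class ToyLayerInfer:
    """Illustrative stand-in for Deepseek2TransformerLayerInfer."""

    def __init__(self):
        self._bind_attention()

    def _bind_attention(self):
        # Mirrors the post-commit binding: always the GQA flash-decoding path,
        # with no "triton_flashdecoding" mode branch.
        self._token_attention_kernel = partial(
            ToyLayerInfer._token_gqa_decode_attention_flashdecoding, self
        )

    def _token_gqa_decode_attention_flashdecoding(self, q, infer_state=None, layer_weight=None, out=None):
        return f"gqa_flash_decoding({q})"


layer = ToyLayerInfer()
print(layer._token_attention_kernel("q"))  # -> gqa_flash_decoding(q)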
@@ -171,24 +165,6 @@ def _context_attention_kernel(
         q_rope = None
         return o_tensor
 
-    def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
-        q_nope, q_rope = q
-        kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank]
-        kv_rope = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank :]
-        return token_decode_attention_flash_decoding(
-            q_nope,
-            q_rope,
-            kv,
-            kv_rope,
-            infer_state,
-            self.tp_q_head_num_,
-            self.kv_lora_rank,
-            self.qk_rope_head_dim,
-            self.qk_nope_head_dim,
-            self.softmax_scale,
-            alloc_tensor_func=self.alloc_tensor,
-        )
-
     def _token_gqa_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
         q_nope, q_rope = q
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank]
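
For reference, the kv_buffer slicing that survives in _token_gqa_decode_attention_flashdecoding splits the last dimension into the compressed latent part (the first kv_lora_rank columns) and the rotary-position part (the remainder). A small sketch with assumed, illustrative shapes and values, not lightllm's real memory manager:

import torch

kv_lora_rank = 512       # illustrative value, not read from a model config
qk_rope_head_dim = 64    # illustrative value
num_tokens, kv_heads = 8, 1

kv_buffer = torch.randn(num_tokens, kv_heads, kv_lora_rank + qk_rope_head_dim)

kv = kv_buffer[:, :, :kv_lora_rank]       # compressed latent KV
kv_rope = kv_buffer[:, :, kv_lora_rank:]  # rotary-position part

print(kv.shape, kv_rope.shape)  # torch.Size([8, 1, 512]) torch.Size([8, 1, 64])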

lightllm/models/deepseek2/triton_kernel/flash_decoding.py

Lines changed: 0 additions & 121 deletions
This file was deleted.

lightllm/models/deepseek2/triton_kernel/flash_decoding_stage1.py

Lines changed: 0 additions & 184 deletions
This file was deleted.
