Skip to content

Commit f7cd225

Browse files
committed
feat: gqa flash decode able to run
1 parent 6ede09e commit f7cd225

File tree

5 files changed

+675
-1
lines changed

5 files changed

+675
-1
lines changed

lightllm/models/llama/layer_infer/transformer_layer_infer.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ def _bind_attention(self):
9696
LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding, self
9797
)
9898
self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
99+
elif "triton_gqa_flashdecoding_vsm" in self.mode:
100+
self._token_attention_kernel = partial(
101+
LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding_vsm, self
102+
)
103+
self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
99104
else:
100105
self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self)
101106
self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
@@ -587,3 +592,20 @@ def _token_decode_attention_ppl_int4kv_flashdecoding(
587592
out=out,
588593
alloc_tensor_func=self.alloc_tensor,
589594
)
595+
596+
def _token_decode_attention_gqa_flashdecoding_vsm(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
    """Token-decode attention via the GQA flash-decoding VSM Triton kernel.

    Slices the per-layer KV buffer into its K and V head ranges, reshapes
    the query to (batch, q_heads, head_dim), and delegates to
    ``gqa_token_decode_attention_flash_decoding_vsm``.

    ``layer_weight`` is accepted for signature parity with the other
    attention-kernel bindings but is not used here.
    """
    # Imported lazily, matching how the other Triton kernels are pulled in.
    from lightllm.models.llama.triton_kernel.gqa_flash_decoding_vsm import (
        gqa_token_decode_attention_flash_decoding_vsm,
    )

    # Single lookup of this layer's KV cache; K heads come first, then V heads
    # along the head axis (axis 1). NOTE(review): assumes kv_buffer layout is
    # (tokens, k_heads + v_heads, head_dim) — confirm against the mem manager.
    kv_buffer = infer_state.mem_manager.kv_buffer[self.layer_num_]
    cache_k = kv_buffer[:, 0 : self.tp_k_head_num_, :]
    cache_v = kv_buffer[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :]

    batched_q_shape = (infer_state.batch_size, self.tp_q_head_num_, self.head_dim_)
    return gqa_token_decode_attention_flash_decoding_vsm(
        q.view(batched_q_shape),
        cache_k,
        cache_v,
        infer_state,
        out=out,
        alloc_tensor_func=self.alloc_tensor,
    )

0 commit comments

Comments
 (0)