@@ -96,6 +96,11 @@ def _bind_attention(self):
                 LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding, self
             )
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+        elif "triton_gqa_flashdecoding_vsm" in self.mode:
+            self._token_attention_kernel = partial(
+                LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding_vsm, self
+            )
+            self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
         else:
             self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self)
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
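For context, the new branch follows the same dispatch pattern as the existing modes: the mode string selects which unbound method of LlamaTransformerLayerInfer gets bound to the instance via functools.partial. The sketch below is a minimal, self-contained illustration of that pattern only; the class and kernel names in it are invented for the example and are not lightllm code.

    from functools import partial

    class FakeLayerInfer:
        """Toy stand-in for LlamaTransformerLayerInfer (illustrative only)."""

        def __init__(self, mode):
            self.mode = mode
            self._bind_attention()

        def _decode_gqa_flashdecoding_vsm(self, q):
            return f"vsm kernel on {q}"

        def _decode_normal(self, q):
            return f"normal kernel on {q}"

        def _bind_attention(self):
            # Same idea as the diff: a membership test on mode picks the kernel
            # (the `in` check works whether mode is a list of flags or a string),
            # and partial(...) binds the chosen unbound method to self.
            if "triton_gqa_flashdecoding_vsm" in self.mode:
                self._token_attention_kernel = partial(FakeLayerInfer._decode_gqa_flashdecoding_vsm, self)
            else:
                self._token_attention_kernel = partial(FakeLayerInfer._decode_normal, self)

    layer = FakeLayerInfer(mode=["triton_gqa_flashdecoding_vsm"])
    print(layer._token_attention_kernel("q_tensor"))  # -> vsm kernel on q_tensor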
@@ -587,3 +592,20 @@ def _token_decode_attention_ppl_int4kv_flashdecoding(
             out=out,
             alloc_tensor_func=self.alloc_tensor,
         )
+
+    def _token_decode_attention_gqa_flashdecoding_vsm(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
+        from lightllm.models.llama.triton_kernel.gqa_flash_decoding_vsm import gqa_token_decode_attention_flash_decoding_vsm
+
+        cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
+        cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
+            :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
+        ]
+        q_shape = (infer_state.batch_size, self.tp_q_head_num_, self.head_dim_)
+        return gqa_token_decode_attention_flash_decoding_vsm(
+            q.view(q_shape),
+            cache_k,
+            cache_v,
+            infer_state,
+            out=out,
+            alloc_tensor_func=self.alloc_tensor,
+        )
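As the slicing in the new method shows, the K heads and V heads share a single per-layer buffer and are packed along the head dimension: K occupies head slots [0, tp_k_head_num_) and V the next tp_v_head_num_ slots, while q is reshaped to (batch_size, tp_q_head_num_, head_dim_) before the kernel call. Below is a standalone sketch of just that slicing, with all sizes invented for illustration; it is not lightllm code.

    import torch

    # Invented sizes; real buffers are allocated by the memory manager.
    total_tokens, k_heads, v_heads, head_dim = 16, 2, 2, 64
    kv_buffer = torch.randn(total_tokens, k_heads + v_heads, head_dim)

    # K heads occupy head slots [0, k_heads), V heads the following v_heads slots.
    cache_k = kv_buffer[:, 0:k_heads, :]
    cache_v = kv_buffer[:, k_heads : k_heads + v_heads, :]

    # GQA decode: more query heads than KV heads; q arrives flattened per token.
    batch_size, q_heads = 4, 8
    q = torch.randn(batch_size, q_heads * head_dim).view(batch_size, q_heads, head_dim)

    assert cache_k.shape == (total_tokens, k_heads, head_dim)
    assert cache_v.shape == (total_tokens, v_heads, head_dim)
    assert q.shape == (batch_size, q_heads, head_dim)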