     context_attention_fwd,
     context_attention_fwd_no_prompt_cache,
 )
+from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import context_attention_fwd_with_v
+from lightllm.models.deepseek2.triton_kernel.sample_kv import sample_kv
 
 from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
@@ -54,6 +56,8 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         if mscale_all_dim:
             mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim)
             self.softmax_scale = self.softmax_scale * mscale * mscale
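+        # ENABLE_CC_METHOD selects the CC prefill path, which decompresses the cached latent kv before attention.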
+        self.enable_cc_method = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
         super().__init__(layer_num, tp_rank, world_size, network_config, mode)
         self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
         if self.enable_dp:
@@ -65,7 +69,15 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         return
 
     def _bind_attention(self):
-        self._context_attention_kernel = partial(Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self)
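+        # Dispatch the prefill kernel: the CC variant when ENABLE_CC_METHOD is set, otherwise the original kernel.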
+        if self.enable_cc_method:
+            self._context_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
+            )
+        else:
+            self._context_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
+            )
         self._token_attention_kernel = partial(
             Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
         )
@@ -123,6 +135,68 @@ def _get_o(
         o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim))
         return o_tensor
 
+    def _context_attention_kernel_with_CC(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
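+        # With the dynamic prompt cache, kv lives in this layer's paged buffer; gather the batch's rows into contiguous compressed_kv / k_rope tensors.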
+        if infer_state.use_dynamic_prompt_cache:
+            kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+            compressed_kv = self.alloc_tensor(
+                [infer_state.total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype
+            )
+            k_rope = self.alloc_tensor([infer_state.total_token_num, 1, self.qk_rope_head_dim], dtype=kv.dtype)
+            sample_kv(
+                kv,
+                compressed_kv,
+                k_rope,
+                infer_state.b_req_idx,
+                infer_state.b_seq_len,
+                infer_state.req_manager.req_to_token_indexs,
+            )
+        else:
+            compressed_kv, k_rope = torch.split(  # (b*s, 1, kv_lora + qk_r)
+                kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1
+            )
+
+        # CC: decompress the latent kv into full per-head k_nope and v via the k/v up-projection weights
+        k_nope = self.alloc_tensor(
+            [compressed_kv.shape[0], q.shape[1], self.qk_nope_head_dim],
+            dtype=compressed_kv.dtype,
+        )
+        v = self.alloc_tensor(
+            k_nope.shape,
+            dtype=compressed_kv.dtype,
+        )
+        compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank)
+        wk = layer_weight.k_b_proj_.weight.view(-1, layer_weight.kv_lora_rank)
+        wv = layer_weight.v_b_proj_.weight.transpose(1, 2).view(-1, layer_weight.kv_lora_rank)
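+        # Decompress with one GEMM per projection, writing straight into the preallocated k_nope / v buffers.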
+        torch.mm(compressed_kv, wk.transpose(0, 1), out=k_nope.reshape(compressed_kv.shape[0], -1))
+        torch.mm(compressed_kv, wv.transpose(0, 1), out=v.reshape(compressed_kv.shape[0], -1))
+
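+        # Split q into its non-rotary and rotary parts and run prefill flash attention with an explicit v.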
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
+        context_attention_fwd_with_v(
+            q_nope,
+            q_rope,
+            k_nope,
+            k_rope,
+            v,
+            o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
+            infer_state.b_start_loc,
+            infer_state.b_seq_len,
+            infer_state.b_ready_cache_len,
+            infer_state.max_len_in_batch,
+            self.softmax_scale,
+        )
+        return o_tensor
+
     def _context_attention_kernel_origin(
         self,
         q: torch.Tensor,