@@ -56,6 +56,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         if mscale_all_dim:
             mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim)
             self.softmax_scale = self.softmax_scale * mscale * mscale
+        self.enable_cc_method = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
         super().__init__(layer_num, tp_rank, world_size, network_config, mode)
         self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
         if self.enable_dp:
@@ -67,7 +68,14 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         return
 
     def _bind_attention(self):
-        self._context_attention_kernel = partial(Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self)
+        if self.enable_cc_method:
+            self._context_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
+            )
+        else:
+            self._context_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
+            )
         self._token_attention_kernel = partial(
             Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
         )
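
For reference, a minimal self-contained sketch of the environment-variable gate this change introduces; _LayerInferSketch and its print-only kernels are placeholders for illustration only, not the real lightllm Deepseek2TransformerLayerInfer implementation:

import os
from functools import partial


class _LayerInferSketch:
    # Stand-in class used only to show the ENABLE_CC_METHOD-gated binding pattern.

    def _context_attention_kernel_with_CC(self):
        print("context attention: CC path")

    def _context_attention_kernel_origin(self):
        print("context attention: original path")

    def _bind_attention(self):
        # Same gate as in the diff above: ENABLE_CC_METHOD is off unless set to ON/TRUE/1.
        enable_cc = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
        if enable_cc:
            self._context_attention_kernel = partial(_LayerInferSketch._context_attention_kernel_with_CC, self)
        else:
            self._context_attention_kernel = partial(_LayerInferSketch._context_attention_kernel_origin, self)


layer = _LayerInferSketch()
layer._bind_attention()
layer._context_attention_kernel()  # prints whichever path ENABLE_CC_METHOD selected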