@@ -69,15 +69,15 @@ def _bind_norm(self):
6969
7070 def _bind_attention (self ):
7171 if get_env_start_args ().enable_fa3 :
72- if "calibration_fp8kv " in self .mode :
72+ if "offline_calibration_fp8kv " in self .mode :
7373 self ._context_attention_kernel = partial (
7474 LlamaTransformerLayerInfer ._context_attention_flashattention_fp8 , self
7575 )
7676 self ._token_attention_kernel = partial (
7777 LlamaTransformerLayerInfer ._token_decode_attention_flashattention_fp8 , self
7878 )
7979 self ._copy_kv_to_mem_cache = partial (LlamaTransformerLayerInfer ._copy_kv_to_mem_cache_fp8kv , self )
80- else :
80+ elif not self . mode :
8181 self ._context_attention_kernel = partial (
8282 LlamaTransformerLayerInfer ._context_attention_flashattention , self
8383 )
@@ -90,6 +90,8 @@ def _bind_attention(self):
9090 )
9191 else :
9292 self ._copy_kv_to_mem_cache = partial (LlamaTransformerLayerInfer ._copy_kv_to_mem_cache_normal , self )
93+ else :
94+ raise Exception (f"Unsupported mode for fa3 backend: { self .mode } " )
9395 return
9496 elif get_env_start_args ().enable_flashinfer_prefill :
9597 self ._context_attention_kernel = partial (
@@ -127,7 +129,7 @@ def _bind_attention(self):
127129 elif "triton_int8kv" in self .mode :
128130 self ._token_attention_kernel = partial (LlamaTransformerLayerInfer ._token_decode_attention_int8kv , self )
129131 self ._copy_kv_to_mem_cache = partial (LlamaTransformerLayerInfer ._copy_kv_to_mem_cache_int8kv , self )
130- elif "calibration_fp8kv " in self .mode :
132+ elif "offline_calibration_fp8kv " in self .mode :
131133 raise Exception ("calibration fp8 kvcache only support fa3 backend" )
132134 elif "triton_flashdecoding" in self .mode :
133135 self ._token_attention_kernel = partial (
@@ -147,14 +149,16 @@ def _bind_attention(self):
147149 LlamaTransformerLayerInfer ._token_decode_attention_gqa_flashdecoding_vsm , self
148150 )
149151 self ._copy_kv_to_mem_cache = partial (LlamaTransformerLayerInfer ._copy_kv_to_mem_cache_normal , self )
150- else :
152+ elif not self . mode :
151153 if get_env_start_args ().enable_flashinfer_decode :
152154 self ._token_attention_kernel = partial (
153155 LlamaTransformerLayerInfer ._token_decode_attention_flashinfer , self
154156 )
155157 else :
156158 self ._token_attention_kernel = partial (LlamaTransformerLayerInfer ._token_decode_attention_normal , self )
157159 self ._copy_kv_to_mem_cache = partial (LlamaTransformerLayerInfer ._copy_kv_to_mem_cache_normal , self )
160+ else :
161+ raise Exception (f"Unsupported mode: { self .mode } " )
158162
159163 return
160164
0 commit comments