@@ -30,6 +30,14 @@ def __init__(self, model):
3030 self.head_dim = model.config["hidden_size"] // model.config["num_attention_heads"]
3131 self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(get_current_device_id())
3232 self.max_seq_length = model.max_seq_length
33+ self.kv_indices_buffer = [
34+ torch.empty(model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32).to(
35+ get_current_device_id()
36+ ),
37+ torch.empty(model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32).to(
38+ get_current_device_id()
39+ ),
40+ ]
3341 self.q_data_type = model.data_type
3442 self.kv_data_type = model.data_type
3543
@@ -51,8 +59,6 @@ def __init__(self, kvargs):
5159 self.enable_flashinfer = (
5260 get_env_start_args().enable_flashinfer_prefill or get_env_start_args().enable_flashinfer_decode
5361 )
54- if self.enable_flashinfer:
55- self.infer_state_class = LlamaFlashInferStateInfo
5662 super().__init__(kvargs)
5763 return
5864
@@ -61,8 +67,6 @@ def _init_config(self):
6167 # rename key
6268 # repair_config()
6369 self._reset_num_key_value_heads()
64- if self.enable_flashinfer:
65- self.flashinfer_extra_state = LlamaFlashInferStateExtraInfo(self)
6670 return
6771
6872 def _reset_num_key_value_heads(self):
@@ -90,6 +94,9 @@ def _init_mem_manager(self):
9094 def _init_inferstate_cls(self):
9195 if get_env_start_args().enable_fa3:
9296 self.infer_state_class = FlashAttentionStateInfo
97+ elif self.enable_flashinfer:
98+ self.infer_state_class = LlamaFlashInferStateInfo
99+ self.flashinfer_extra_state = LlamaFlashInferStateExtraInfo(self)
93100
94101 def _init_custom(self):
95102 """
0 commit comments