@@ -13,20 +13,14 @@ def __init__(self):
         super().__init__()
         self.prefill_wrapper = None
         self.decode_wrapper = None
+        self.flashinfer_extra_state = None

     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         super().init_some_extra_state(model, input_ids)
+        self.flashinfer_extra_state = model.flashinfer_extra_state

         if not self.is_prefill:
             if enable_env_vars("ENABLE_FLASHINFER_DECODE_MLA"):
-                self.tp_q_head_num = model.flashinfer_state.tp_q_head_num
-                self.kv_lora_rank = model.flashinfer_state.kv_lora_rank
-                self.qk_rope_head_dim = model.flashinfer_state.qk_rope_head_dim
-                self.qk_nope_head_dim = model.flashinfer_state.qk_nope_head_dim
-                self.softmax_scale = model.flashinfer_state.softmax_scale
-                self.q_data_type = model.flashinfer_state.data_type
-                self.kv_data_type = model.flashinfer_state.data_type
-
                 self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device)
                 self.kv_indices = torch.empty(
                     self.batch_size * model.flashinfer_state.max_seq_length, dtype=torch.int32
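Aside: the change above replaces seven per-field copies with a single reference, so wrapper setup reads every constant from one shared object. A minimal sketch of what such a container could hold, reusing the field names visible in this diff (the dataclass itself is illustrative, not the repo's actual definition):

# Hedged sketch only: field names come from the diff above; the
# dataclass itself is an assumption, not lightllm's real class.
from dataclasses import dataclass

import torch


@dataclass
class FlashInferExtraState:
    tp_q_head_num: int              # query heads on this tensor-parallel rank
    kv_lora_rank: int               # MLA compressed (latent) KV dimension
    qk_rope_head_dim: int           # rotary slice of the per-head Q/K dim
    qk_nope_head_dim: int           # non-rotary slice of the per-head Q/K dim
    softmax_scale: float            # attention softmax scaling factor
    q_data_type: torch.dtype
    kv_data_type: torch.dtype
    max_seq_length: int             # bound used to size kv_indices above
    workspace_buffer: torch.Tensor  # shared FlashInfer workspace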
@@ -41,7 +35,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 )
                 if self.decode_wrapper is None:
                     self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
-                        model.flashinfer_state.workspace_buffer,
+                        self.flashinfer_extra_state.workspace_buffer,
                         use_cuda_graph=True,
                         qo_indptr=self.q_indptr,
                         kv_indices=self.kv_indices,
@@ -53,23 +47,17 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                     self.kv_starts,
                     self.kv_indices,
                     self.b_seq_len,
-                    self.tp_q_head_num,
-                    self.kv_lora_rank,
-                    self.qk_rope_head_dim,
+                    self.flashinfer_extra_state.tp_q_head_num,
+                    self.flashinfer_extra_state.kv_lora_rank,
+                    self.flashinfer_extra_state.qk_rope_head_dim,
                     1,
                     False,  # causal
-                    self.softmax_scale,
-                    self.q_data_type,
-                    self.kv_data_type,
+                    self.flashinfer_extra_state.softmax_scale,
+                    self.flashinfer_extra_state.q_data_type,
+                    self.flashinfer_extra_state.kv_data_type,
                 )
         else:
             if enable_env_vars("ENABLE_FLASHINFER_PREFILLED"):
-                self.tp_q_head_num = model.flashinfer_state.tp_q_head_num
-                self.qk_rope_head_dim = model.flashinfer_state.qk_rope_head_dim
-                self.qk_nope_head_dim = model.flashinfer_state.qk_nope_head_dim
-                self.softmax_scale = model.flashinfer_state.softmax_scale
-                self.q_data_type = model.flashinfer_state.data_type
-
                 q_starts = torch.cat(
                     [self.b_start_loc, self.b_start_loc[-1:] + (self.b_seq_len - self.b_ready_cache_len)[-1:]], dim=0
                 ).int()
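To make the positional plan(...) call above easier to read: the arguments follow flashinfer.mla.BatchMLAPagedAttentionWrapper.plan in exactly the order shown in the hunk. A hedged sketch with the roles annotated (the helper function and its parameter names are illustrative):

# Sketch: annotates the decode-side plan(...) arguments from the diff.
# Argument order mirrors the hunk above; the helper itself is hypothetical.
def plan_decode_mla(decode_wrapper, q_indptr, kv_starts, kv_indices, b_seq_len, extra):
    decode_wrapper.plan(
        q_indptr,                # qo_indptr: [batch_size + 1] query offsets
        kv_starts,               # kv_indptr: per-request KV prefix sums
        kv_indices,              # kv_indices: flattened KV page table
        b_seq_len,               # kv_len_arr: KV length per request
        extra.tp_q_head_num,     # number of query heads on this rank
        extra.kv_lora_rank,      # head_dim_ckv: compressed KV head dim
        extra.qk_rope_head_dim,  # head_dim_kpe: rotary key head dim
        1,                       # page_size: one token per KV page
        False,                   # causal: a single decode token sees all KV
        extra.softmax_scale,     # sm_scale
        extra.q_data_type,
        extra.kv_data_type,
    )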
@@ -78,18 +66,19 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 ).int()
                 if self.prefill_wrapper is None:
                     self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
-                        model.flashinfer_state.workspace_buffer, "NHD"
+                        self.flashinfer_extra_state.workspace_buffer, "NHD"
                     )
                 self.prefill_wrapper.plan(
                     qo_indptr=q_starts,
                     kv_indptr=kv_starts,
-                    num_qo_heads=self.tp_q_head_num,
-                    num_kv_heads=self.tp_q_head_num,
-                    head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
-                    head_dim_vo=self.qk_nope_head_dim,
-                    q_data_type=self.q_data_type,
+                    num_qo_heads=self.flashinfer_extra_state.tp_q_head_num,
+                    num_kv_heads=self.flashinfer_extra_state.tp_q_head_num,
+                    head_dim_qk=self.flashinfer_extra_state.qk_nope_head_dim
+                    + self.flashinfer_extra_state.qk_rope_head_dim,
+                    head_dim_vo=self.flashinfer_extra_state.qk_nope_head_dim,
+                    q_data_type=self.flashinfer_extra_state.q_data_type,
                     causal=True,
-                    sm_scale=self.softmax_scale,
+                    sm_scale=self.flashinfer_extra_state.softmax_scale,
                 )
         return

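One detail worth calling out in the prefill plan above: MLA prefill runs as ragged attention in which Q and K carry both the nope and rope slices while V only has the nope width, hence head_dim_qk = qk_nope_head_dim + qk_rope_head_dim and head_dim_vo = qk_nope_head_dim. A tiny sketch of the arithmetic with DeepSeek-V2-style dimensions (the concrete numbers are illustrative, not read from this repo):

# Illustrative numbers only; the real values come from the model config
# via flashinfer_extra_state.
qk_nope_head_dim = 128  # non-rotary Q/K head dim
qk_rope_head_dim = 64   # rotary Q/K head dim

head_dim_qk = qk_nope_head_dim + qk_rope_head_dim  # 192: Q/K width in plan()
head_dim_vo = qk_nope_head_dim                     # 128: V/output width
assert (head_dim_qk, head_dim_vo) == (192, 128)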
@@ -101,13 +90,13 @@ def copy_for_cuda_graph(self, new_infer_state):
                 new_infer_state.kv_starts,
                 new_infer_state.kv_indices,
                 new_infer_state.b_seq_len,
-                new_infer_state.tp_q_head_num,
-                new_infer_state.kv_lora_rank,
-                new_infer_state.qk_rope_head_dim,
+                new_infer_state.flashinfer_extra_state.tp_q_head_num,
+                new_infer_state.flashinfer_extra_state.kv_lora_rank,
+                new_infer_state.flashinfer_extra_state.qk_rope_head_dim,
                 1,
                 False,  # causal
-                new_infer_state.softmax_scale,
-                new_infer_state.q_data_type,
-                new_infer_state.kv_data_type,
+                new_infer_state.flashinfer_extra_state.softmax_scale,
+                new_infer_state.flashinfer_extra_state.q_data_type,
+                new_infer_state.flashinfer_extra_state.kv_data_type,
             )
         return
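Finally, copy_for_cuda_graph re-plans the decode wrapper with the metadata of the incoming batch. Because the wrapper was built with use_cuda_graph=True and fixed qo_indptr/kv_indices buffers, re-planning is expected to refresh those buffers in place, so a previously captured decode graph can simply be replayed. A hedged usage sketch (the driver names are hypothetical; only copy_for_cuda_graph comes from the diff):

# Sketch of the replay flow implied by copy_for_cuda_graph; the
# surrounding driver function and its arguments are illustrative.
def replay_decode_step(graph_infer_state, new_infer_state, captured_graph):
    # Refresh the shared wrapper's plan for the new batch.
    graph_infer_state.copy_for_cuda_graph(new_infer_state)
    # Replay the decode kernel sequence captured earlier.
    captured_graph.replay()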