@@ -24,13 +24,13 @@ def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
2424 ]
2525 return cls ._shared_page_table_buffer
2626
27- def init_some_extra_state (self , model , model_input : ModelInput ):
28- super ().init_some_extra_state (model , model_input )
27+ def init_some_extra_state (self , model , input_ids : torch . Tensor ):
28+ super ().init_some_extra_state (model , input_ids )
2929 if self .is_prefill :
3030 self .cu_seqlens_q = self .b1_cu_q_seq_len .int ()
3131 self .cu_seqlens_k = self .b1_cu_kv_seq_len .int ()
3232 self .page_table = torch .empty (
33- (self .batch_size , self .max_seq_len ), dtype = torch .int32 , device = model_input . input_ids .device
33+ (self .batch_size , self .max_seq_len ), dtype = torch .int32 , device = input_ids .device
3434 )
3535 self .page_table .copy_ (model .req_manager .req_to_token_indexs [self .b_req_idx , : self .max_seq_len ])
3636 else :
@@ -47,7 +47,7 @@ def init_some_extra_state(self, model, model_input: ModelInput):
4747 ].reshape (self .batch_size , model .graph_max_len_in_batch )
4848 else :
4949 self .page_table = torch .empty (
50- (self .batch_size , self .max_len_in_batch ), dtype = torch .int32 , device = model_input . input_ids .device
50+ (self .batch_size , self .max_len_in_batch ), dtype = torch .int32 , device = input_ids .device
5151 )
5252
5353 self .page_table [:, :max_seq_len_k ].copy_ (
@@ -58,7 +58,7 @@ def init_some_extra_state(self, model, model_input: ModelInput):
5858
5959 if "offline_calibration_fp8kv" in model .mode :
6060 if self .is_prefill :
61- device = model_input . input_ids .device
61+ device = input_ids .device
6262 # q_scale和token_batch_ids在对q做per head量化使用,为了节省资源在推理外部初始化
6363 self .q_scale = torch .empty (
6464 (self .batch_size , self .mem_manager .head_num ), dtype = torch .float32 , device = device
@@ -78,7 +78,7 @@ def init_some_extra_state(self, model, model_input: ModelInput):
7878 else torch .ones (
7979 (self .mem_manager .layer_num , self .batch_size , head_num ),
8080 dtype = torch .float32 ,
81- device = model_input . input_ids .device ,
81+ device = input_ids .device ,
8282 )
8383 )
8484 self .v_descale = (
@@ -89,7 +89,7 @@ def init_some_extra_state(self, model, model_input: ModelInput):
8989 else torch .ones (
9090 (self .mem_manager .layer_num , self .batch_size , head_num ),
9191 dtype = torch .float32 ,
92- device = model_input . input_ids .device ,
92+ device = input_ids .device ,
9393 )
9494 )
9595 return
0 commit comments