@@ -213,6 +213,21 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
213213 audio_tokens += self .tokenizer .get_audio_token_length (audio )
214214 return len (prompt_ids ) + image_tokens + img_count + audio_tokens + audio_count
215215
216+ def _calculate_multimodal_tokens (self , multimodal_params : MultimodalParams ) -> Tuple [int , int ]:
217+ image_tokens = 0
218+ audio_tokens = 0
219+
220+ if self .enable_multimodal and self .pd_mode .is_P_or_NORMAL ():
221+ for img in multimodal_params .images :
222+ image_tokens += self .tokenizer .get_image_token_length (img )
223+ for audio in multimodal_params .audios :
224+ audio_tokens += self .tokenizer .get_audio_token_length (audio )
225+ else :
226+ image_tokens = len (multimodal_params .images )
227+ audio_tokens = len (multimodal_params .audios )
228+
229+ return image_tokens , audio_tokens
230+
216231 async def loop_for_request (self ):
217232 assert self .args .node_rank > 0
218233 while True :
@@ -311,6 +326,16 @@ async def generate(
311326 req_objs .append (req_obj )
312327
313328 req_status = ReqStatus (group_request_id , multimodal_params , req_objs , start_time )
329+
330+ # Compute input-token usage statistics for this request
331+ text_tokens = len (prompt_ids )
332+ image_tokens , audio_tokens = self ._calculate_multimodal_tokens (multimodal_params )
333+ input_usage = {
334+ "input_text_tokens" : text_tokens ,
335+ "input_audio_tokens" : audio_tokens ,
336+ "input_image_tokens" : image_tokens ,
337+ }
338+
314339 self .req_id_to_out_inf [group_request_id ] = req_status
315340
316341 await self .transfer_to_next_module_or_node (
@@ -326,6 +351,7 @@ async def generate(
326351 request ,
327352 )
328353 async for sub_req_id , request_output , metadata , finish_status in results_generator :
354+ metadata ["usage" ] = {** input_usage , ** metadata .get ("usage" , {})}
329355 yield sub_req_id , request_output , metadata , finish_status
330356
331357 except Exception as e :
@@ -513,6 +539,8 @@ async def _wait_to_token_package(
513539 if self .pd_mode == NodeRole .P and is_first_token :
514540 metadata ["prompt_ids" ] = prompt_ids
515541
542+ metadata ["usage" ] = {"output_tokens" : out_token_counter }
543+
516544 prompt_cache_len = metadata .pop ("prompt_cache_len" , 0 )
517545 if is_first_token :
518546 first_token_cost_ms = (time .time () - start_time ) * 1000
0 commit comments