@@ -213,21 +213,6 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
213213 audio_tokens += self .tokenizer .get_audio_token_length (audio )
214214 return len (prompt_ids ) + image_tokens + img_count + audio_tokens + audio_count
215215
216- def _calculate_multimodal_tokens (self , multimodal_params : MultimodalParams ) -> Tuple [int , int ]:
217- image_tokens = 0
218- audio_tokens = 0
219-
220- if self .enable_multimodal and self .pd_mode .is_P_or_NORMAL ():
221- for img in multimodal_params .images :
222- image_tokens += self .tokenizer .get_image_token_length (img )
223- for audio in multimodal_params .audios :
224- audio_tokens += self .tokenizer .get_audio_token_length (audio )
225- else :
226- image_tokens = len (multimodal_params .images )
227- audio_tokens = len (multimodal_params .audios )
228-
229- return image_tokens , audio_tokens
230-
231216 async def loop_for_request (self ):
232217 assert self .args .node_rank > 0
233218 while True :
@@ -327,15 +312,6 @@ async def generate(
327312
328313 req_status = ReqStatus (group_request_id , multimodal_params , req_objs , start_time )
329314
330- # 计算输入 token 使用量统计
331- text_tokens = len (prompt_ids )
332- image_tokens , audio_tokens = self ._calculate_multimodal_tokens (multimodal_params )
333- input_usage = {
334- "input_text_tokens" : text_tokens ,
335- "input_audio_tokens" : audio_tokens ,
336- "input_image_tokens" : image_tokens ,
337- }
338-
339315 self .req_id_to_out_inf [group_request_id ] = req_status
340316
341317 await self .transfer_to_next_module_or_node (
@@ -350,8 +326,18 @@ async def generate(
350326 req_status ,
351327 request ,
352328 )
329+
330+ # 计算输入 token 使用量统计
331+ image_tokens , audio_tokens = self ._count_multimodal_tokens (multimodal_params )
332+ text_tokens = len (prompt_ids ) - (image_tokens + audio_tokens )
333+ input_usage = {
334+ "input_text_tokens" : text_tokens ,
335+ "input_audio_tokens" : audio_tokens ,
336+ "input_image_tokens" : image_tokens ,
337+ }
338+
353339 async for sub_req_id , request_output , metadata , finish_status in results_generator :
354- metadata ["usage " ] = { ** input_usage , ** metadata . get ( "usage" , {})}
340+ metadata ["input_usage " ] = input_usage
355341 yield sub_req_id , request_output , metadata , finish_status
356342
357343 except Exception as e :
@@ -366,6 +352,20 @@ async def generate(
366352 raise e
367353 return
368354
355+ def _count_multimodal_tokens (self , multimodal_params : MultimodalParams ) -> Tuple [int , int ]:
356+ image_tokens = 0
357+ audio_tokens = 0
358+
359+ if self .enable_multimodal and self .pd_mode .is_P_or_NORMAL ():
360+ for img in multimodal_params .images :
361+ if img .token_num is not None :
362+ image_tokens += img .token_num
363+ for audio in multimodal_params .audios :
364+ if audio .token_num is not None :
365+ audio_tokens += audio .token_num
366+
367+ return image_tokens , audio_tokens
368+
369369 async def _log_req_header (self , request_headers , group_request_id : int ):
370370
371371 x_request_id = request_headers .get ("X-Request-Id" , "" )
@@ -539,8 +539,6 @@ async def _wait_to_token_package(
539539 if self .pd_mode == NodeRole .P and is_first_token :
540540 metadata ["prompt_ids" ] = prompt_ids
541541
542- metadata ["usage" ] = {"output_tokens" : out_token_counter }
543-
544542 prompt_cache_len = metadata .pop ("prompt_cache_len" , 0 )
545543 if is_first_token :
546544 first_token_cost_ms = (time .time () - start_time ) * 1000
0 commit comments