Skip to content

Commit 41b9193

Browse files
Add multimodal token usage (#1016)
Co-authored-by: pigKiller <zhhang.bian@gmail.com>
1 parent 120c833 commit 41b9193

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

lightllm/server/api_lightllm.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,22 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
5252
prompt_tokens = 0
5353
prompt_token_ids = None
5454
is_first_metadata = True
55+
input_usage = None
5556
async for sub_req_id, request_output, metadata, finish_status in results_generator:
5657
# when set "--return_all_prompt_logprobs", the first token metadata will contain
5758
# prompt_logprobs and prompt_token_ids
5859
if is_first_metadata:
5960
prompt_logprobs = metadata.get("prompt_logprobs", None)
6061
prompt_token_ids = metadata.get("prompt_token_ids", None)
6162
prompt_tokens = metadata.get("prompt_tokens", 0)
63+
input_usage = metadata.get("input_usage", None)
6264
if prompt_logprobs is not None:
6365
del metadata["prompt_logprobs"]
6466
if prompt_token_ids is not None:
6567
del metadata["prompt_token_ids"]
68+
if input_usage is not None:
69+
del metadata["input_usage"]
70+
6671
is_first_metadata = False
6772

6873
count_output_tokens_dict[sub_req_id] += 1
@@ -95,6 +100,9 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
95100
ret["prompt_token_ids"] = prompt_token_ids
96101
if prompt_logprobs is not None:
97102
ret["prompt_logprobs"] = prompt_logprobs
103+
if input_usage is not None:
104+
ret["input_usage"] = input_usage
105+
98106
return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8"))
99107

100108

@@ -116,7 +124,12 @@ async def lightllm_generate_stream(request: Request, httpserver_manager: HttpSer
116124

117125
# Streaming case
118126
async def stream_results() -> AsyncGenerator[bytes, None]:
127+
# input_usage only appears in the first metadata, so save it here
128+
input_usage = None
119129
async for _, request_output, metadata, finish_status in results_generator:
130+
if input_usage is None:
131+
input_usage = metadata.get("input_usage", None)
132+
120133
ret = {
121134
"token": {
122135
"id": metadata.get("id", None),
@@ -130,6 +143,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
130143
"finished": finish_status.is_finished(),
131144
"finish_reason": finish_status.get_finish_reason(),
132145
"details": None,
146+
"input_usage": input_usage,
133147
}
134148

135149
yield ("data:" + json.dumps(ret, ensure_ascii=False) + "\n\n").encode("utf-8")

lightllm/server/httpserver/manager.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,23 @@ async def generate(
325325
req_status,
326326
request,
327327
)
328+
329+
# Compute input token usage statistics
330+
image_tokens, audio_tokens = self._count_multimodal_tokens(multimodal_params)
331+
text_tokens = len(prompt_ids) - (image_tokens + audio_tokens)
332+
input_usage = {
333+
"input_text_tokens": text_tokens,
334+
"input_audio_tokens": audio_tokens,
335+
"input_image_tokens": image_tokens,
336+
}
337+
338+
is_first_gen_token = True
328339
async for sub_req_id, request_output, metadata, finish_status in results_generator:
340+
# Only the first generated token's metadata contains input_usage
341+
if is_first_gen_token:
342+
metadata["input_usage"] = input_usage
343+
is_first_gen_token = False
344+
329345
yield sub_req_id, request_output, metadata, finish_status
330346

331347
except Exception as e:
@@ -340,6 +356,20 @@ async def generate(
340356
raise e
341357
return
342358

359+
def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple[int, int]:
    """Count how many input tokens come from multimodal (image/audio) inputs.

    Returns:
        A ``(image_tokens, audio_tokens)`` pair, summing ``token_num`` over all
        images and audios whose token count has been resolved; items with
        ``token_num is None`` contribute nothing.
    """
    # Counting is only meaningful when multimodal mode is enabled and this
    # node actually processes the prompt (P or NORMAL role); otherwise both
    # counts are zero.
    if not (self.enable_multimodal and self.pd_mode.is_P_or_NORMAL() and multimodal_params is not None):
        return 0, 0

    image_tokens = sum(img.token_num for img in multimodal_params.images if img.token_num is not None)
    audio_tokens = sum(aud.token_num for aud in multimodal_params.audios if aud.token_num is not None)
    return image_tokens, audio_tokens
343373
async def _log_req_header(self, request_headers, group_request_id: int):
344374

345375
x_request_id = request_headers.get("X-Request-Id", "")

0 commit comments

Comments
 (0)