Commit 3af68a0
Add multimodal token usage (#1011)
1 parent 120c833 commit 3af68a0

2 files changed: +34 -0

lightllm/server/api_lightllm.py

Lines changed: 6 additions & 0 deletions
@@ -52,6 +52,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
     prompt_tokens = 0
     prompt_token_ids = None
     is_first_metadata = True
+    usage = None
     async for sub_req_id, request_output, metadata, finish_status in results_generator:
         # when set "--return_all_prompt_logprobs", the first token metadata will contains
         # prompt_logprobs and prompt_token_ids
@@ -65,6 +66,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
             del metadata["prompt_token_ids"]
             is_first_metadata = False

+        usage = metadata.get("usage", None)
+
         count_output_tokens_dict[sub_req_id] += 1
         final_output_dict[sub_req_id].append(request_output)
         if return_details:
@@ -95,6 +98,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana
         ret["prompt_token_ids"] = prompt_token_ids
     if prompt_logprobs is not None:
         ret["prompt_logprobs"] = prompt_logprobs
+    if usage is not None:
+        ret["usage"] = usage
     return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8"))

@@ -130,6 +135,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
                 "finished": finish_status.is_finished(),
                 "finish_reason": finish_status.get_finish_reason(),
                 "details": None,
+                "usage": metadata.get("usage", None),
             }

             yield ("data:" + json.dumps(ret, ensure_ascii=False) + "\n\n").encode("utf-8")

lightllm/server/httpserver/manager.py

Lines changed: 28 additions & 0 deletions
@@ -213,6 +213,21 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
             audio_tokens += self.tokenizer.get_audio_token_length(audio)
         return len(prompt_ids) + image_tokens + img_count + audio_tokens + audio_count

+    def _calculate_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple[int, int]:
+        image_tokens = 0
+        audio_tokens = 0
+
+        if self.enable_multimodal and self.pd_mode.is_P_or_NORMAL():
+            for img in multimodal_params.images:
+                image_tokens += self.tokenizer.get_image_token_length(img)
+            for audio in multimodal_params.audios:
+                audio_tokens += self.tokenizer.get_audio_token_length(audio)
+        else:
+            image_tokens = len(multimodal_params.images)
+            audio_tokens = len(multimodal_params.audios)
+
+        return image_tokens, audio_tokens
+
     async def loop_for_request(self):
         assert self.args.node_rank > 0
         while True:
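The helper has two paths: when multimodal support is enabled and the node actually tokenizes the prompt (P or NORMAL role), it asks the tokenizer for the real per-item token lengths; otherwise it falls back to counting one unit per image or audio clip, since those nodes never tokenized the items. A standalone sketch of the same branch structure (the flattened inputs and the token lengths below are illustrative only):

from typing import List, Tuple

def calc_multimodal_tokens(
    image_lengths: List[int],
    audio_lengths: List[int],
    tokenizes_locally: bool,
) -> Tuple[int, int]:
    # Mirrors _calculate_multimodal_tokens with the tokenizer and
    # MultimodalParams flattened into plain lists (illustrative only).
    if tokenizes_locally:
        # P or NORMAL node with multimodal enabled: real per-item lengths.
        return sum(image_lengths), sum(audio_lengths)
    # Other roles only know how many items were attached.
    return len(image_lengths), len(audio_lengths)

print(calc_multimodal_tokens([576, 1024], [128], True))   # -> (1600, 128)
print(calc_multimodal_tokens([576, 1024], [128], False))  # -> (2, 1)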
@@ -311,6 +326,16 @@ async def generate(
             req_objs.append(req_obj)

         req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time)
+
+        # compute input token usage statistics
+        text_tokens = len(prompt_ids)
+        image_tokens, audio_tokens = self._calculate_multimodal_tokens(multimodal_params)
+        input_usage = {
+            "input_text_tokens": text_tokens,
+            "input_audio_tokens": audio_tokens,
+            "input_image_tokens": image_tokens,
+        }
+
         self.req_id_to_out_inf[group_request_id] = req_status

         await self.transfer_to_next_module_or_node(
@@ -326,6 +351,7 @@ async def generate(
                 request,
             )
             async for sub_req_id, request_output, metadata, finish_status in results_generator:
+                metadata["usage"] = {**input_usage, **metadata.get("usage", {})}
                 yield sub_req_id, request_output, metadata, finish_status

         except Exception as e:
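The merge is plain dict unpacking, so keys already present in the per-token metadata (here, the output_tokens field written downstream in _wait_to_token_package) win on collision, while disjoint keys are simply combined. A minimal illustration with made-up values:

# Values are illustrative.
input_usage = {"input_text_tokens": 42, "input_image_tokens": 1600, "input_audio_tokens": 0}
metadata = {"usage": {"output_tokens": 7}}

# Later unpacks override earlier ones, so output_tokens survives the merge.
merged = {**input_usage, **metadata.get("usage", {})}
print(merged)
# {'input_text_tokens': 42, 'input_image_tokens': 1600,
#  'input_audio_tokens': 0, 'output_tokens': 7}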
@@ -513,6 +539,8 @@ async def _wait_to_token_package(
         if self.pd_mode == NodeRole.P and is_first_token:
             metadata["prompt_ids"] = prompt_ids

+        metadata["usage"] = {"output_tokens": out_token_counter}
+
         prompt_cache_len = metadata.pop("prompt_cache_len", 0)
         if is_first_token:
             first_token_cost_ms = (time.time() - start_time) * 1000
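Since _wait_to_token_package stamps output_tokens onto every token package before generate() merges in the input statistics, each streamed event's "usage" carries both the input breakdown and the output count at that point. A hedged end-to-end consumer sketch (the /generate_stream endpoint name and payload are assumptions based on lightllm's standard streaming API; the "data:<json>\n\n" framing matches the stream_results handler above):

import json
import requests

# Assumed SSE endpoint and payload; each event now includes "usage".
with requests.post(
    "http://localhost:8000/generate_stream",
    json={"inputs": "What is AI?", "parameters": {"max_new_tokens": 8}},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data:"):
            continue
        event = json.loads(line[len(b"data:"):])
        print(event.get("finished"), event.get("usage"))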
