Skip to content

Commit 6f73a16

Browse files
author
sangchengmeng
committed
[add]add tokens_num api
1 parent 8eabdfd commit 6f73a16

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

lightllm/server/api_http.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,57 @@ async def tokens(request: Request):
385385
return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}")
386386

387387

388+
# for special cases: report the token count of a request without generating
@app.get("/tokens_num")
@app.post("/tokens_num")
async def tokens_num(request: Request):
    """Return the number of tokens a text (+ images) request would occupy.

    The JSON body is read for:
      * ``text`` (required): the prompt string to tokenize.
      * ``parameters`` (optional): keyword arguments forwarded to
        ``SamplingParams.init``.
      * ``multimodal_params.images`` (optional): a sequence whose entries are
        indexed as ``[0]`` and ``[1]`` -- presumably image width and height;
        confirm against the caller.

    The count is ``len(prompt ids) + image tokens + one extra token per
    image``. On success the response is ``{"ntokens": <count>}`` with status
    200; any exception is turned into an ``EXPECTATION_FAILED`` error response.
    """
    try:
        payload = await request.json()
        text = payload.pop("text")
        sampling_kwargs = payload.pop("parameters", {})

        sampling_params = SamplingParams()
        tokenizer = g_objs.httpserver_manager.tokenizer
        sampling_params.init(tokenizer=tokenizer, **sampling_kwargs)
        sampling_params.verify()

        image_sizes = payload.get("multimodal_params", {}).get("images", [])

        prompt_ids = tokenizer.encode(text, None, add_special_tokens=False)

        # An explicit non-negative patch limit in the sampling params wins;
        # otherwise the limit is tiered by how many images were sent.
        requested_patch_limit = sampling_params.image_max_patch_num
        if requested_patch_limit >= 0:
            patch_limit = requested_patch_limit
        else:
            image_total = len(image_sizes)
            if image_total == 1:
                patch_limit = 12
            elif 1 < image_total <= 6:
                patch_limit = 6
            else:
                # No images at all, or more than six of them.
                patch_limit = 0

        tokens_per_patch = int(os.environ.get("INTERNVL_IMAGE_LENGTH", 256))

        per_image_tokens = [
            tokenizer.get_image_patch_func(size[0], size[1], max_num=patch_limit, use_thumbnail=True)
            * tokens_per_patch
            for size in image_sizes
        ]

        # Each image contributes its patch tokens plus one additional token.
        ntokens = len(prompt_ids) + sum(per_image_tokens) + len(per_image_tokens)

        return JSONResponse({"ntokens": ntokens}, status_code=200)
    except Exception as e:
        return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}")
437+
438+
388439
@app.get("/metrics")
389440
async def metrics() -> Response:
390441
data = await g_objs.metric_client.generate_latest()

0 commit comments

Comments
 (0)