
Commit 8a30bd2

fix: vLLM provider recalculates token counts
1 parent 2d4deda commit 8a30bd2

File tree

  • apps/setting/models_provider/impl/vllm_model_provider/model

2 files changed: +32 -2 lines

apps/setting/models_provider/impl/vllm_model_provider/model/image.py

Lines changed: 16 additions & 1 deletion
@@ -1,5 +1,8 @@
-from typing import Dict
+from typing import Dict, List
 
+from langchain_core.messages import get_buffer_string, BaseMessage
+
+from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
@@ -21,3 +24,15 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
 
     def is_cache_model(self):
         return False
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+        return self.usage_metadata.get('input_tokens', 0)
+
+    def get_num_tokens(self, text: str) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
+        return self.get_last_generation_info().get('output_tokens', 0)
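Both new methods share the same two-path logic: when the model instance carries server-reported usage metadata, return the recorded count; otherwise fall back to a local tokenizer from TokenizerManage. The sketch below reproduces that fallback pattern standalone, runnable with only langchain_core installed. StubTokenizer, the example messages, and the truthiness check standing in for the diff's explicit `is None or == {}` test are illustrative assumptions, not project code.

from typing import List, Optional

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, get_buffer_string


class StubTokenizer:
    """Illustrative stand-in for the project's tokenizer: 'encodes' by splitting on whitespace."""

    def encode(self, text: str) -> List[str]:
        return text.split()


def count_prompt_tokens(messages: List[BaseMessage], usage_metadata: Optional[dict]) -> int:
    # A falsy check covers both None and {}, the two cases the diff tests explicitly.
    if not usage_metadata:
        # Fallback: render each message to text and count tokens locally,
        # mirroring the sum(...) expression in the methods above.
        tokenizer = StubTokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)
    # Preferred path: trust the input token count the vLLM server reported.
    return usage_metadata.get('input_tokens', 0)


msgs = [HumanMessage(content="What is vLLM?"), AIMessage(content="A fast inference engine.")]
print(count_prompt_tokens(msgs, usage_metadata=None))                  # local fallback count
print(count_prompt_tokens(msgs, usage_metadata={'input_tokens': 42}))  # 42, from the server

One detail visible in the diff itself: when metadata is present, get_num_tokens ignores its text argument and returns the last generation's 'output_tokens', so it reports observed usage rather than re-encoding the string it was given.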

apps/setting/models_provider/impl/vllm_model_provider/model/llm.py

Lines changed: 16 additions & 1 deletion
@@ -1,8 +1,11 @@
 # coding=utf-8
 
-from typing import Dict
+from typing import Dict, List
 from urllib.parse import urlparse, ParseResult
 
+from langchain_core.messages import BaseMessage, get_buffer_string
+
+from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
@@ -33,3 +36,15 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
             stream_usage=True,
         )
         return vllm_chat_open_ai
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+        return self.usage_metadata.get('input_tokens', 0)
+
+    def get_num_tokens(self, text: str) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
+        return self.get_last_generation_info().get('output_tokens', 0)
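In both files the fallback path flattens each message with get_buffer_string before encoding, so the role prefixes that function prepends are counted as part of the text. A quick standalone illustration of that rendering (get_buffer_string is the real langchain_core API; the example messages are made up):

from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string

msgs = [HumanMessage(content="hello"), AIMessage(content="hi there")]

# Joined form: one "<prefix>: <content>" line per message.
print(get_buffer_string(msgs))
# Human: hello
# AI: hi there

# The diff encodes one message at a time, so the per-message strings are:
for m in msgs:
    print(repr(get_buffer_string([m])))
# 'Human: hello'
# 'AI: hi there'

Encoding message by message keeps the count identical to encoding the joined buffer, apart from the newline separators the joined form would add between messages.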
