Skip to content

Commit 1310c8a

Browse files
committed
refactor: image model get_num_tokens override
1 parent 8db35c4 commit 1310c8a

File tree

9 files changed

+156
-10
lines changed
  • apps/setting/models_provider/impl

9 files changed

+156
-10
lines changed

apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
# coding=utf-8
22

3-
from typing import Dict
3+
from typing import Dict, List
44

55
from langchain_community.chat_models import ChatOpenAI
6+
from langchain_core.messages import BaseMessage, get_buffer_string
67

8+
from common.config.tokenizer_manage_config import TokenizerManage
79
from setting.models_provider.base_model_provider import MaxKBBaseModel
810

911

12+
def custom_get_token_ids(text: str):
13+
tokenizer = TokenizerManage.get_tokenizer()
14+
return tokenizer.encode(text)
15+
16+
1017
class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI):
1118

1219
@staticmethod
@@ -21,3 +28,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
2128
**optional_params,
2229
)
2330
return chat_tong_yi
31+
32+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
33+
try:
34+
return super().get_num_tokens_from_messages(messages)
35+
except Exception as e:
36+
tokenizer = TokenizerManage.get_tokenizer()
37+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
38+
39+
def get_num_tokens(self, text: str) -> int:
40+
try:
41+
return super().get_num_tokens(text)
42+
except Exception as e:
43+
tokenizer = TokenizerManage.get_tokenizer()
44+
return len(tokenizer.encode(text))

apps/setting/models_provider/impl/azure_model_provider/model/image.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import Dict
1+
from typing import Dict, List
22

3+
from langchain_core.messages import BaseMessage, get_buffer_string
34
from langchain_openai import AzureChatOpenAI
4-
from langchain_openai.chat_models import ChatOpenAI
55

66
from common.config.tokenizer_manage_config import TokenizerManage
77
from setting.models_provider.base_model_provider import MaxKBBaseModel
@@ -26,3 +26,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
2626
streaming=True,
2727
**optional_params,
2828
)
29+
30+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
31+
try:
32+
return super().get_num_tokens_from_messages(messages)
33+
except Exception as e:
34+
tokenizer = TokenizerManage.get_tokenizer()
35+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
36+
37+
def get_num_tokens(self, text: str) -> int:
38+
try:
39+
return super().get_num_tokens(text)
40+
except Exception as e:
41+
tokenizer = TokenizerManage.get_tokenizer()
42+
return len(tokenizer.encode(text))

apps/setting/models_provider/impl/ollama_model_provider/model/image.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Dict
1+
from typing import Dict, List
22
from urllib.parse import urlparse, ParseResult
33

4+
from langchain_core.messages import get_buffer_string, BaseMessage
45
from langchain_openai.chat_models import ChatOpenAI
56

67
from common.config.tokenizer_manage_config import TokenizerManage
@@ -36,3 +37,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
3637
streaming=True,
3738
**optional_params,
3839
)
40+
41+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
42+
try:
43+
return super().get_num_tokens_from_messages(messages)
44+
except Exception as e:
45+
tokenizer = TokenizerManage.get_tokenizer()
46+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
47+
48+
def get_num_tokens(self, text: str) -> int:
49+
try:
50+
return super().get_num_tokens(text)
51+
except Exception as e:
52+
tokenizer = TokenizerManage.get_tokenizer()
53+
return len(tokenizer.encode(text))

apps/setting/models_provider/impl/openai_model_provider/model/image.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Dict
1+
from typing import Dict, List
22

3+
from langchain_core.messages import BaseMessage, get_buffer_string
34
from langchain_openai.chat_models import ChatOpenAI
45

56
from common.config.tokenizer_manage_config import TokenizerManage
@@ -24,3 +25,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
2425
streaming=True,
2526
**optional_params,
2627
)
28+
29+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
30+
try:
31+
return super().get_num_tokens_from_messages(messages)
32+
except Exception as e:
33+
tokenizer = TokenizerManage.get_tokenizer()
34+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
35+
36+
def get_num_tokens(self, text: str) -> int:
37+
try:
38+
return super().get_num_tokens(text)
39+
except Exception as e:
40+
tokenizer = TokenizerManage.get_tokenizer()
41+
return len(tokenizer.encode(text))

apps/setting/models_provider/impl/qwen_model_provider/model/image.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
# coding=utf-8
22

3-
from typing import Dict
3+
from typing import Dict, List
44

55
from langchain_community.chat_models import ChatOpenAI
6+
from langchain_core.messages import BaseMessage, get_buffer_string
67

8+
from common.config.tokenizer_manage_config import TokenizerManage
79
from setting.models_provider.base_model_provider import MaxKBBaseModel
810

911

12+
def custom_get_token_ids(text: str):
13+
tokenizer = TokenizerManage.get_tokenizer()
14+
return tokenizer.encode(text)
15+
16+
1017
class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI):
1118

1219
@staticmethod
@@ -21,3 +28,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
2128
**optional_params,
2229
)
2330
return chat_tong_yi
31+
32+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
33+
try:
34+
return super().get_num_tokens_from_messages(messages)
35+
except Exception as e:
36+
tokenizer = TokenizerManage.get_tokenizer()
37+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
38+
39+
def get_num_tokens(self, text: str) -> int:
40+
try:
41+
return super().get_num_tokens(text)
42+
except Exception as e:
43+
tokenizer = TokenizerManage.get_tokenizer()
44+
return len(tokenizer.encode(text))

apps/setting/models_provider/impl/tencent_model_provider/model/image.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Dict
1+
from typing import Dict, List
22

3+
from langchain_core.messages import BaseMessage, get_buffer_string
34
from langchain_openai.chat_models import ChatOpenAI
45

56
from common.config.tokenizer_manage_config import TokenizerManage
@@ -24,3 +25,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
2425
streaming=True,
2526
**optional_params,
2627
)
28+
29+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
30+
try:
31+
return super().get_num_tokens_from_messages(messages)
32+
except Exception as e:
33+
tokenizer = TokenizerManage.get_tokenizer()
34+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
35+
36+
def get_num_tokens(self, text: str) -> int:
37+
try:
38+
return super().get_num_tokens(text)
39+
except Exception as e:
40+
tokenizer = TokenizerManage.get_tokenizer()
41+
return len(tokenizer.encode(text))

apps/setting/models_provider/impl/volcanic_engine_model_provider/model/image.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Dict
1+
from typing import Dict, List
22

3+
from langchain_core.messages import BaseMessage, get_buffer_string
34
from langchain_openai.chat_models import ChatOpenAI
45

56
from common.config.tokenizer_manage_config import TokenizerManage
@@ -24,3 +25,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
2425
streaming=True,
2526
**optional_params,
2627
)
28+
29+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
30+
try:
31+
return super().get_num_tokens_from_messages(messages)
32+
except Exception as e:
33+
tokenizer = TokenizerManage.get_tokenizer()
34+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
35+
36+
def get_num_tokens(self, text: str) -> int:
37+
try:
38+
return super().get_num_tokens(text)
39+
except Exception as e:
40+
tokenizer = TokenizerManage.get_tokenizer()
41+
return len(tokenizer.encode(text))

apps/setting/models_provider/impl/xinference_model_provider/model/image.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Dict
1+
from typing import Dict, List
22

3+
from langchain_core.messages import BaseMessage, get_buffer_string
34
from langchain_openai.chat_models import ChatOpenAI
45

56
from common.config.tokenizer_manage_config import TokenizerManage
@@ -24,3 +25,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
2425
streaming=True,
2526
**optional_params,
2627
)
28+
29+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
30+
try:
31+
return super().get_num_tokens_from_messages(messages)
32+
except Exception as e:
33+
tokenizer = TokenizerManage.get_tokenizer()
34+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
35+
36+
def get_num_tokens(self, text: str) -> int:
37+
try:
38+
return super().get_num_tokens(text)
39+
except Exception as e:
40+
tokenizer = TokenizerManage.get_tokenizer()
41+
return len(tokenizer.encode(text))

apps/setting/models_provider/impl/zhipu_model_provider/model/image.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Dict
1+
from typing import Dict, List
22

3+
from langchain_core.messages import get_buffer_string, BaseMessage
34
from langchain_openai.chat_models import ChatOpenAI
45

56
from common.config.tokenizer_manage_config import TokenizerManage
@@ -24,3 +25,17 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
2425
streaming=True,
2526
**optional_params,
2627
)
28+
29+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
30+
try:
31+
return super().get_num_tokens_from_messages(messages)
32+
except Exception as e:
33+
tokenizer = TokenizerManage.get_tokenizer()
34+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
35+
36+
def get_num_tokens(self, text: str) -> int:
37+
try:
38+
return super().get_num_tokens(text)
39+
except Exception as e:
40+
tokenizer = TokenizerManage.get_tokenizer()
41+
return len(tokenizer.encode(text))

0 commit comments

Comments (0)