Skip to content

Commit eab0284

Browse files
committed
refactor: image model get_num_tokens override
1 parent 8db35c4 commit eab0284

File tree

9 files changed

+165
-10
lines changed
  • apps/setting/models_provider/impl

9 files changed

+165
-10
lines changed

apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
# coding=utf-8
22

3-
from typing import Dict
3+
from typing import Dict, List
44

55
from langchain_community.chat_models import ChatOpenAI
6+
from langchain_core.messages import BaseMessage, get_buffer_string
67

8+
from common.config.tokenizer_manage_config import TokenizerManage
79
from setting.models_provider.base_model_provider import MaxKBBaseModel
810

911

12+
def custom_get_token_ids(text: str):
13+
tokenizer = TokenizerManage.get_tokenizer()
14+
return tokenizer.encode(text)
15+
16+
1017
class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI):
1118

1219
@staticmethod
@@ -18,6 +25,21 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
1825
openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
1926
# stream_options={"include_usage": True},
2027
streaming=True,
28+
custom_get_token_ids=custom_get_token_ids,
2129
**optional_params,
2230
)
2331
return chat_tong_yi
32+
33+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
34+
try:
35+
return super().get_num_tokens_from_messages(messages)
36+
except Exception as e:
37+
tokenizer = TokenizerManage.get_tokenizer()
38+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
39+
40+
def get_num_tokens(self, text: str) -> int:
41+
try:
42+
return super().get_num_tokens(text)
43+
except Exception as e:
44+
tokenizer = TokenizerManage.get_tokenizer()
45+
return len(tokenizer.encode(text))

apps/setting/models_provider/impl/azure_model_provider/model/image.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import Dict
1+
from typing import Dict, List
22

3+
from langchain_core.messages import BaseMessage, get_buffer_string
34
from langchain_openai import AzureChatOpenAI
4-
from langchain_openai.chat_models import ChatOpenAI
55

66
from common.config.tokenizer_manage_config import TokenizerManage
77
from setting.models_provider.base_model_provider import MaxKBBaseModel
@@ -24,5 +24,20 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
2424
openai_api_version=model_credential.get('api_version'),
2525
openai_api_type="azure",
2626
streaming=True,
27+
custom_get_token_ids=custom_get_token_ids,
2728
**optional_params,
2829
)
30+
31+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
32+
try:
33+
return super().get_num_tokens_from_messages(messages)
34+
except Exception as e:
35+
tokenizer = TokenizerManage.get_tokenizer()
36+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
37+
38+
def get_num_tokens(self, text: str) -> int:
39+
try:
40+
return super().get_num_tokens(text)
41+
except Exception as e:
42+
tokenizer = TokenizerManage.get_tokenizer()
43+
return len(tokenizer.encode(text))

apps/setting/models_provider/impl/ollama_model_provider/model/image.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Dict
1+
from typing import Dict, List
22
from urllib.parse import urlparse, ParseResult
33

4+
from langchain_core.messages import get_buffer_string, BaseMessage
45
from langchain_openai.chat_models import ChatOpenAI
56

67
from common.config.tokenizer_manage_config import TokenizerManage
@@ -34,5 +35,20 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
3435
openai_api_key=model_credential.get('api_key'),
3536
# stream_options={"include_usage": True},
3637
streaming=True,
38+
custom_get_token_ids=custom_get_token_ids,
3739
**optional_params,
3840
)
41+
42+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
43+
try:
44+
return super().get_num_tokens_from_messages(messages)
45+
except Exception as e:
46+
tokenizer = TokenizerManage.get_tokenizer()
47+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
48+
49+
def get_num_tokens(self, text: str) -> int:
50+
try:
51+
return super().get_num_tokens(text)
52+
except Exception as e:
53+
tokenizer = TokenizerManage.get_tokenizer()
54+
return len(tokenizer.encode(text))

apps/setting/models_provider/impl/openai_model_provider/model/image.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Dict
1+
from typing import Dict, List
22

3+
from langchain_core.messages import BaseMessage, get_buffer_string
34
from langchain_openai.chat_models import ChatOpenAI
45

56
from common.config.tokenizer_manage_config import TokenizerManage
@@ -22,5 +23,20 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
2223
openai_api_key=model_credential.get('api_key'),
2324
# stream_options={"include_usage": True},
2425
streaming=True,
26+
custom_get_token_ids=custom_get_token_ids,
2527
**optional_params,
2628
)
29+
30+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
31+
try:
32+
return super().get_num_tokens_from_messages(messages)
33+
except Exception as e:
34+
tokenizer = TokenizerManage.get_tokenizer()
35+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
36+
37+
def get_num_tokens(self, text: str) -> int:
38+
try:
39+
return super().get_num_tokens(text)
40+
except Exception as e:
41+
tokenizer = TokenizerManage.get_tokenizer()
42+
return len(tokenizer.encode(text))

apps/setting/models_provider/impl/qwen_model_provider/model/image.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
# coding=utf-8
22

3-
from typing import Dict
3+
from typing import Dict, List
44

55
from langchain_community.chat_models import ChatOpenAI
6+
from langchain_core.messages import BaseMessage, get_buffer_string
67

8+
from common.config.tokenizer_manage_config import TokenizerManage
79
from setting.models_provider.base_model_provider import MaxKBBaseModel
810

911

12+
def custom_get_token_ids(text: str):
13+
tokenizer = TokenizerManage.get_tokenizer()
14+
return tokenizer.encode(text)
15+
16+
1017
class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI):
1118

1219
@staticmethod
@@ -18,6 +25,21 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
1825
openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
1926
# stream_options={"include_usage": True},
2027
streaming=True,
28+
custom_get_token_ids=custom_get_token_ids,
2129
**optional_params,
2230
)
2331
return chat_tong_yi
32+
33+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
34+
try:
35+
return super().get_num_tokens_from_messages(messages)
36+
except Exception as e:
37+
tokenizer = TokenizerManage.get_tokenizer()
38+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
39+
40+
def get_num_tokens(self, text: str) -> int:
41+
try:
42+
return super().get_num_tokens(text)
43+
except Exception as e:
44+
tokenizer = TokenizerManage.get_tokenizer()
45+
return len(tokenizer.encode(text))

apps/setting/models_provider/impl/tencent_model_provider/model/image.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Dict
1+
from typing import Dict, List
22

3+
from langchain_core.messages import BaseMessage, get_buffer_string
34
from langchain_openai.chat_models import ChatOpenAI
45

56
from common.config.tokenizer_manage_config import TokenizerManage
@@ -22,5 +23,20 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
2223
openai_api_key=model_credential.get('api_key'),
2324
# stream_options={"include_usage": True},
2425
streaming=True,
26+
custom_get_token_ids=custom_get_token_ids,
2527
**optional_params,
2628
)
29+
30+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
31+
try:
32+
return super().get_num_tokens_from_messages(messages)
33+
except Exception as e:
34+
tokenizer = TokenizerManage.get_tokenizer()
35+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
36+
37+
def get_num_tokens(self, text: str) -> int:
38+
try:
39+
return super().get_num_tokens(text)
40+
except Exception as e:
41+
tokenizer = TokenizerManage.get_tokenizer()
42+
return len(tokenizer.encode(text))

apps/setting/models_provider/impl/volcanic_engine_model_provider/model/image.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Dict
1+
from typing import Dict, List
22

3+
from langchain_core.messages import BaseMessage, get_buffer_string
34
from langchain_openai.chat_models import ChatOpenAI
45

56
from common.config.tokenizer_manage_config import TokenizerManage
@@ -22,5 +23,20 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
2223
openai_api_base=model_credential.get('api_base'),
2324
# stream_options={"include_usage": True},
2425
streaming=True,
26+
custom_get_token_ids=custom_get_token_ids,
2527
**optional_params,
2628
)
29+
30+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
31+
try:
32+
return super().get_num_tokens_from_messages(messages)
33+
except Exception as e:
34+
tokenizer = TokenizerManage.get_tokenizer()
35+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
36+
37+
def get_num_tokens(self, text: str) -> int:
38+
try:
39+
return super().get_num_tokens(text)
40+
except Exception as e:
41+
tokenizer = TokenizerManage.get_tokenizer()
42+
return len(tokenizer.encode(text))

apps/setting/models_provider/impl/xinference_model_provider/model/image.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Dict
1+
from typing import Dict, List
22

3+
from langchain_core.messages import BaseMessage, get_buffer_string
34
from langchain_openai.chat_models import ChatOpenAI
45

56
from common.config.tokenizer_manage_config import TokenizerManage
@@ -22,5 +23,20 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
2223
openai_api_key=model_credential.get('api_key'),
2324
# stream_options={"include_usage": True},
2425
streaming=True,
26+
custom_get_token_ids=custom_get_token_ids,
2527
**optional_params,
2628
)
29+
30+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
31+
try:
32+
return super().get_num_tokens_from_messages(messages)
33+
except Exception as e:
34+
tokenizer = TokenizerManage.get_tokenizer()
35+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
36+
37+
def get_num_tokens(self, text: str) -> int:
38+
try:
39+
return super().get_num_tokens(text)
40+
except Exception as e:
41+
tokenizer = TokenizerManage.get_tokenizer()
42+
return len(tokenizer.encode(text))

apps/setting/models_provider/impl/zhipu_model_provider/model/image.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Dict
1+
from typing import Dict, List
22

3+
from langchain_core.messages import get_buffer_string, BaseMessage
34
from langchain_openai.chat_models import ChatOpenAI
45

56
from common.config.tokenizer_manage_config import TokenizerManage
@@ -22,5 +23,20 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
2223
openai_api_base='https://open.bigmodel.cn/api/paas/v4',
2324
# stream_options={"include_usage": True},
2425
streaming=True,
26+
custom_get_token_ids=custom_get_token_ids,
2527
**optional_params,
2628
)
29+
30+
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
31+
try:
32+
return super().get_num_tokens_from_messages(messages)
33+
except Exception as e:
34+
tokenizer = TokenizerManage.get_tokenizer()
35+
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
36+
37+
def get_num_tokens(self, text: str) -> int:
38+
try:
39+
return super().get_num_tokens(text)
40+
except Exception as e:
41+
tokenizer = TokenizerManage.get_tokenizer()
42+
return len(tokenizer.encode(text))

0 commit comments

Comments
 (0)