Skip to content

Commit e12b1fe

Browse files
committed
refactor: gemini
1 parent 7ce66a7 commit e12b1fe

File tree

1 file changed

+13
-3
lines changed
  • apps/setting/models_provider/impl/gemini_model_provider/model

1 file changed

+13
-3
lines changed

apps/setting/models_provider/impl/gemini_model_provider/model/llm.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
Tool as GoogleTool,
1414
)
1515
from langchain_core.callbacks import CallbackManagerForLLMRun
16-
from langchain_core.messages import BaseMessage
16+
from langchain_core.messages import BaseMessage, get_buffer_string
1717
from langchain_core.outputs import ChatGenerationChunk
1818
from langchain_google_genai import ChatGoogleGenerativeAI
1919
from langchain_google_genai._function_utils import _ToolConfigDict, _ToolDict
@@ -22,6 +22,8 @@
2222
from langchain_google_genai._common import (
2323
SafetySettingDict,
2424
)
25+
26+
from common.config.tokenizer_manage_config import TokenizerManage
2527
from setting.models_provider.base_model_provider import MaxKBBaseModel
2628

2729

@@ -46,10 +48,18 @@ def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
4648
return self.__dict__.get('_last_generation_info')
4749

4850
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Return the prompt (input) token count for *messages*.

    Prefers the exact ``input_tokens`` figure Gemini reported for the last
    generation; when no generation info is available yet, falls back to a
    local tokenizer-based estimate.

    :param messages: chat messages whose combined token count is wanted.
    :return: token count (exact when provider usage info exists, otherwise
        an estimate; 0 when usage info exists but lacks ``input_tokens``).
    """
    try:
        # get_last_generation_info() may return None before any generation
        # has completed; the resulting AttributeError (or any other failure)
        # deliberately routes us to the local-tokenizer fallback.
        return self.get_last_generation_info().get('input_tokens', 0)
    except Exception:
        # Fallback estimate: encode each message's buffer string with the
        # shared tokenizer. Generator expression — no throwaway list.
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)
5056

5157
def get_num_tokens(self, text: str) -> int:
    """Return the completion (output) token count.

    Prefers the exact ``output_tokens`` figure Gemini reported for the last
    generation; when no generation info is available yet, falls back to
    encoding *text* with the local tokenizer.

    NOTE(review): on the provider path the count comes from the last
    generation's usage info, not from *text* itself — *text* is only used
    by the fallback. This mirrors the companion
    ``get_num_tokens_from_messages`` design.

    :param text: generated text to count tokens for (fallback path only).
    :return: token count (0 when usage info exists but lacks
        ``output_tokens``).
    """
    try:
        # May raise AttributeError when no generation has run yet
        # (get_last_generation_info() returns None) — handled below.
        return self.get_last_generation_info().get('output_tokens', 0)
    except Exception:
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))
5363

5464
def _stream(
5565
self,

0 commit comments

Comments
 (0)