Merged
98 changes: 23 additions & 75 deletions apps/models_provider/impl/zhipu_model_provider/model/llm.py
@@ -7,27 +7,21 @@
@desc:
"""

import json
from collections.abc import Iterator
from typing import Any, Dict, List, Optional
from typing import Dict, List

from langchain_community.chat_models import ChatZhipuAI
from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
    _convert_delta_to_message_chunk
from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.messages import (
    AIMessageChunk,
    BaseMessage
)
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.messages import BaseMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI


class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
    optional_params: dict
def custom_get_token_ids(text: str):
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class ZhipuChatModel(MaxKBBaseModel, BaseChatOpenAI):

    @staticmethod
    def is_cache_model():
@@ -39,69 +33,23 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
        zhipuai_chat = ZhipuChatModel(
            api_key=model_credential.get('api_key'),
            model=model_name,
            base_url='https://open.bigmodel.cn/api/paas/v4',
            extra_body=optional_params,
            streaming=model_kwargs.get('streaming', False),
            optional_params=optional_params,
            **optional_params,
            custom_get_token_ids=custom_get_token_ids
        )
        return zhipuai_chat

    usage_metadata: dict = {}

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        return self.usage_metadata

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        return self.usage_metadata.get('prompt_tokens', 0)
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

    def get_num_tokens(self, text: str) -> int:
        return self.usage_metadata.get('completion_tokens', 0)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the chat response in chunks."""
        if self.zhipuai_api_key is None:
            raise ValueError("Did not find zhipuai_api_key.")
        if self.zhipuai_api_base is None:
            raise ValueError("Did not find zhipu_api_base.")
        message_dicts, params = self._create_message_dicts(messages, stop)
        payload = {**params, **kwargs, **self.optional_params, "messages": message_dicts, "stream": True}
        _truncate_params(payload)
        headers = {
            "Authorization": _get_jwt_token(self.zhipuai_api_key),
            "Accept": "application/json",
        }

        default_chunk_class = AIMessageChunk
        import httpx

        with httpx.Client(headers=headers, timeout=60) as client:
            with connect_sse(
                    client, "POST", self.zhipuai_api_base, json=payload
            ) as event_source:
                for sse in event_source.iter_sse():
                    chunk = json.loads(sse.data)
                    if len(chunk["choices"]) == 0:
                        continue
                    choice = chunk["choices"][0]
                    generation_info = {}
                    if "usage" in chunk:
                        generation_info = chunk["usage"]
                        self.usage_metadata = generation_info
                    chunk = _convert_delta_to_message_chunk(
                        choice["delta"], default_chunk_class
                    )
                    finish_reason = choice.get("finish_reason", None)

                    chunk = ChatGenerationChunk(
                        message=chunk, generation_info=generation_info
                    )
                    yield chunk
                    if run_manager:
                        run_manager.on_llm_new_token(chunk.text, chunk=chunk)
                    if finish_reason is not None:
                        break
        try:
            return super().get_num_tokens(text)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
Contributor Author


No significant changes are required in the provided code snippet, but here are some minor improvements and clarifications:

  • Remove self.optional_params assignment as it's not used anywhere in the class.

  • Adjust the token-counting methods so that both classes use a consistent approach.

Here's a revised version of the code with these corrections applied:

# Import necessary libraries and classes
from typing import List

from langchain_core.messages import BaseMessage

from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI


class ZhipuChatModel(MaxKBBaseModel, BaseChatOpenAI):

    @staticmethod
    def is_cache_model():
        pass

    def __init__(self, *args, max_length=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_length = max_length

    def _token_count(self, text: str) -> int:
        # Use the default method provided by BaseChatOpenAI if available
        try:
            return super().get_num_tokens(text)
        except Exception:
            tokenizer = TokenizerManage.get_tokenizer()
            tokens = tokenizer.encode(text)
            return min(len(tokens), self.max_length) if self.max_length else len(tokens)

    def generate_responses(self, messages: List[BaseMessage]) -> List[str]:
        total_tokens = 0
        responses = []
        for message in messages:
            token_count = self._token_count(message.content)
            total_tokens += token_count

            # Perform your logic here based on whether you want to split or concatenate input data

        # Return all generated responses or handle further processing
        return responses
Key Changes:

  1. Initialization: Added a max_length parameter to __init__ and stored it as an instance variable.

  2. Token Count Calculation: Modified the _token_count method to first call the superclass's method (super().get_num_tokens(text)), which should be more reliable than manually encoding text via TokenizerManage (since it may include additional optimizations), and to fall back to the local tokenizer only if that call fails.

  3. Replaced Custom Method: Removed the custom_get_token_ids helper, since a list of token IDs can now be obtained directly through the same approach.

This change will ensure that the token count reflects the actual input length rather than potentially over-counting due to the use of multiple encoders. This may make future modifications easier without changing the core logic too much.
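
For illustration, here is a minimal, self-contained sketch of the local fallback path used by the revised get_num_tokens_from_messages, assuming only the TokenizerManage helper already shown in the diff; count_tokens_fallback is a hypothetical name introduced here and is not part of the PR:

# Minimal sketch of the local fallback: mirrors the except-branch of
# get_num_tokens_from_messages by encoding each message's buffer string with the
# local tokenizer and summing the lengths. count_tokens_fallback is hypothetical.
from typing import List

from langchain_core.messages import BaseMessage, HumanMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage


def count_tokens_fallback(messages: List[BaseMessage]) -> int:
    tokenizer = TokenizerManage.get_tokenizer()
    return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)


if __name__ == "__main__":
    msgs = [HumanMessage(content="你好，MaxKB"), HumanMessage(content="hello")]
    print(count_tokens_fallback(msgs))  # local estimate, not the provider's own count

Because this estimate only needs the local tokenizer, it keeps working even when the remote, OpenAI-compatible token counting provided by BaseChatOpenAI is unavailable, which is why the try/except fallback is kept in both token-counting methods.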