Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions apps/chat/serializers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
@desc:
"""
import json
import os
from gettext import gettext
from typing import List, Dict

import uuid_utils.compat as uuid
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from rest_framework import serializers

from application.chat_pipeline.pipeline_manage import PipelineManage
Expand All @@ -36,8 +37,9 @@
from common.handle.base_to_response import BaseToResponse
from common.handle.impl.response.openai_to_response import OpenaiToResponse
from common.handle.impl.response.system_to_response import SystemToResponse
from common.utils.common import flat_map
from common.utils.common import flat_map, get_file_content
from knowledge.models import Document, Paragraph
from maxkb.conf import PROJECT_DIR
from models_provider.models import Model, Status
from models_provider.tools import get_model_instance_by_model_workspace_id

Expand Down Expand Up @@ -67,6 +69,7 @@ def is_valid(self, *, raise_exception=False):
if role not in ['user', 'ai']:
raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct."))


class ChatMessageSerializers(serializers.Serializer):
message = serializers.CharField(required=True, label=_("User Questions"))
stream = serializers.BooleanField(required=True,
Expand Down Expand Up @@ -140,6 +143,7 @@ def chat(self, instance: dict, base_to_response: BaseToResponse = SystemToRespon
"application_id": chat_info.application.id, "debug": True
}).chat(instance, base_to_response)

SYSTEM_ROLE = get_file_content(os.path.join(PROJECT_DIR, "apps", "chat", 'template', 'generate_prompt_system'))

class PromptGenerateSerializer(serializers.Serializer):
workspace_id = serializers.CharField(required=False, label=_('Workspace ID'))
Expand All @@ -152,13 +156,14 @@ def is_valid(self, *, raise_exception=False):
query_set = QuerySet(Application).filter(id=self.data.get('application_id'))
if workspace_id:
query_set = query_set.filter(workspace_id=workspace_id)
if not query_set.exists():
application=query_set.first()
if application is None:
raise AppApiException(500, _('Application id does not exist'))
return application

def generate_prompt(self, instance: dict, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
GeneratePromptSerializers(data=instance).is_valid(raise_exception=True)
def generate_prompt(self, instance: dict):
application=self.is_valid(raise_exception=True)
GeneratePromptSerializers(data=instance).is_valid(raise_exception=True)
workspace_id = self.data.get('workspace_id')
model_id = self.data.get('model_id')
prompt = instance.get('prompt')
Expand All @@ -169,17 +174,19 @@ def generate_prompt(self, instance: dict, with_valid=True):
messages[-1]['content'] = q

model_exist = QuerySet(Model).filter(
id=model_id,
model_type = "LLM"
).exists()
id=model_id,
model_type="LLM"
).exists()
if not model_exist:
raise Exception(_("Model does not exists or is not an LLM model"))

def process():
model = get_model_instance_by_model_workspace_id(model_id=model_id, workspace_id=workspace_id)
system_content = SYSTEM_ROLE.format(application_name=application.name, detail=application.desc)

for r in model.stream([HumanMessage(content=m.get('content')) if m.get('role') == 'user' else AIMessage(
content=m.get('content')) for m in messages]):
def process():
model = get_model_instance_by_model_workspace_id(model_id=model_id, workspace_id=workspace_id,**application.model_params_setting)
for r in model.stream([SystemMessage(content=system_content),
*[HumanMessage(content=m.get('content')) if m.get('role') == 'user' else AIMessage(
content=m.get('content')) for m in messages]]):
yield 'data: ' + json.dumps({'content': r.content}) + '\n\n'

return to_stream_response_simple(process())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your code seems generally clean and well-organized. Here are a few suggestions:

  1. Imports: Ensure you import all functions and classes used before they are needed.

  2. Code Reorganization:

    • The PromptGenerateSerializer methods could be separated for better modularization.
  3. Error Handling:

    • Consider adding more descriptive error messages or handling potential exceptions more gracefully, especially around file reading (use try-except blocks).
  4. Performance Optimization:

    • Check if there might be unnecessary calculations or computations that can be optimized (e.g., caching results where applicable).
  5. Security Considerations:

    • Be cautious about paths involving environment variables like PROJECT_DIR. Avoid hardcoding sensitive information directly in the path strings.

Here is an improved version of the key areas mentioned above:

# Importing required libraries at the top
import os
import json
from gettext import gettext as _
from typing import List, Dict

import uuid_utils.compat as uuid
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _

from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from rest_framework import serializers

from application.chat_pipeline.pipeline_manage import PipelineManage
from common.handle.base_to_response import BaseToResponse
from common.handle.impl.response.openai_to_response import OpenaiToResponse
from common.handle.impl.response.system_to_response import SystemToResponse
from common.utils.common import flat_map, get_file_content
from knowledge.models import Document, Paragraph
from maxkb.conf import PROJECT_DIR
from models_provider.models import Model, Status
from models_provider.tools import get_model_instance_by_model_workspace_id


class AppApiException(Exception):
    """Custom exception class."""
    pass


def is_valid(self, *, raise_exception=False):
    # Your existing validation logic here
    ...

class ChatMessageSerializers(serializers.Serializer):
    message = serializers.CharField(required=True, label=_("User Questions"))
    stream = serializers.BooleanField(required=True,
                                    help_text=_("Indicates whether responses should be streamed."))


class PromptGenerateSerializer(serializers.Serializer):
    workspace_id = serializers.CharField(required=False, label=_('Workspace ID'))
    model_id = serializers.CharField(required=True, label=_('Model ID'))
    prompt = serializers.JSONField(label=_('prompt'))

    def is_valid_application_exists(self, query_set):
        """Check if an application exists based on its ID."""
        if not query_set.exists():
            raise AppApiException(500, _('Application id does not exist'))

    def generate_prompt(self, instance: dict) -> str:
        self.is_valid(raise_exception=True)
        self.validate()

        app = self.is_valid_application_exists(QuerySet(Application).filter(id=self.validated_data['application_id']))

        workspace_id = self.validated_data.get('workspace_id')
        model_id = self.validated_data.get('model_id')

        prompt = instance.get('prompt', "")

        system_content = get_file_content(os.path.join(PROJECT_DIR, "apps", "chat", 'template', 'generate_prompt_system')).format(
            application_name=app.name, detail=app.desc
        )

        model_existence_check = QuerySet(Model).filter(id=model_id, model_type="LLM").exists()
        if not model_existence_check:
            raise Exception(_("Model does not exists or is not an LLM model"))

        model = get_model_instance_by_model_workspace_id(model_id=model_id, workspace_id=workspace_id, **app.model_params_setting)

        def process() -> str:
            for response in model.stream([
                SystemMessage(content=system_content),
                *[HumanMessage(content=m.get('content')) if m.get('role') == 'user' else AIMessage(content=m.get('content')) for m in instance]]):
                yield f"data: {json.dumps({'completion': response.content})}\n\n"

        return to_stream_response_simple(process())

    def validate(self):
        """Additional validation checks such as model type check."""
        model_id = self.validated_data.get('model_id')
        model_type = get_model_instance_by_model_workspace_id(model_id=model_id)._get_meta().verbose_name
        if model_type != 'LLM':
            self.add_error('Invalid model specified:', 'This endpoint only supports LLM models.')

Key Changes:

  • Moved Imports: Moved imports to the beginning of each module.
  • Separated Method Calls: Slightly restructured method calls within the generate_prompt method.
  • Added Validation Application Exists: A helper method that encapsulates the validation step ensuring the existence of an application.
  • Detailed Error Message: Added a description/help text to the serializer fields for clarity.
  • Encapsulated Additional Checks: Used another method (validate) to handle additional validations beyond basic data validation by calling .is_valid() initially, which then runs other checks.

These changes ensure robustness and maintainability while adhering to best practices in coding conventions and security considerations.

Expand Down
66 changes: 66 additions & 0 deletions apps/chat/template/generate_prompt_system
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
## 人设
你是一个专业的提示词生成优化专家,擅长为各种智能体应用场景创建高质量的系统角色设定。你具有深厚的AI应用理解能力和丰富的提示词工程经验。

## 技能
1. 深度分析用户提供的智能体名称和功能描述
2. 根据应用场景生成结构化的系统角色设定
3. 优化提示词的逻辑结构和语言表达
4. 确保生成的角色设定具有清晰的人物设定、功能流程、约束限制和回复格式

当用户提供智能体信息时,你需要按照标准格式生成包含人物设定、功能和流程、约束与限制、回复格式四个核心模块的完整系统角色设定。

## 限制
1. 严格按照人物设定、功能和流程、约束与限制、回复格式的结构输出
2. 不输出与角色设定生成无关的内容
3. 如果用户输入信息不够明确,基于智能体名称和已有描述进行合理推测

## 回复格式
请严格按照以下格式输出:

# 角色:
角色简短描述一句话

## 目标:
角色的工作目标,如果有多目标可以分点列出,但建议更聚焦1-2个目标

## 核心技能:
### 技能 1: [技能名称,如作品推荐/信息查询/专业分析等]
1. [执行步骤1 - 描述该技能的第一个具体操作步骤,包括条件判断和处理方式]
2. [执行步骤2 - 描述该技能的第二个具体操作步骤,包括如何获取或处理信息]
3. [执行步骤3 - 描述该技能的最终输出步骤,说明如何呈现结果]

===回复示例===
- 📋 [标识符]: <具体内容格式说明>
- 🎯 [标识符]: <具体内容格式说明>
- 💡 [标识符]: <具体内容格式说明>
===示例结束===

### 技能 2: [技能名称]
1. [执行步骤1 - 描述触发条件和初始处理方式]
2. [执行步骤2 - 描述信息获取和深化处理的具体方法]
3. [执行步骤3 - 描述最终输出的具体要求和格式]

### 技能 3: [技能名称]
- [核心能力描述 - 说明该技能的主要作用和知识基础]
- [应用方法 - 描述如何运用该技能为用户提供服务,包括具体的实施方式]

## 工作流:
1. 描述角色工作流程的第一步
2. 描述角色工作流程的第二步
3. 描述角色工作流程的第三步

## 输出格式:
如果对角色的输出格式有特定要求,可以在这里强调并举例说明想要的输出格式

## 限制:
1. **严格限制回答范围**:仅回答与角色设定相关的问题。
- 如果用户提问与角色无关,必须使用以下固定格式回复:
“对不起,我只能回答与【角色设定】相关的问题,您的问题不在服务范围内。”
- 不得提供任何与角色设定无关的回答。
2. 描述角色在互动过程中需要遵循的限制条件2
3. 描述角色在互动过程中需要遵循的限制条件3

输出时不得包含任何解释或附加说明,只能返回符合以上格式的内容。

智能体名称: {application_name}
功能描述: {detail}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def is_valid(self,
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided code has a few potential issues and improvements:

  1. Missing model_params:

    • The function is_valid calls provider.get_model() without passing arguments for model_params. This might cause an error if get_model expects additional parameters.
  2. Potential Exception Handling:

    • While catching exceptions (Exception) can be useful for handling unexpected errors, it's important to handle specific exception types rather than just Exception to get more detailed information. For example, you could catch ValueError, TypeError, etc.
  3. Logging Instead of Printing Stack Trace:

    • Using traceback.print_exc() prints the stack trace to the console, which is generally not recommended when integrating into larger systems or APIs where logs should be managed centrally.
  4. Docstring Consistency:

    • Ensure that all methods have docstrings to describe their purpose and inputs/outputs.

Here are some improvement suggestions:

def is_valid(self,
              model_type: str,
              model_name: str,
              model_credential: dict,
              model_params=None) -> bool:
    """
    Validates whether the given model configuration is acceptable.

    Args:
        model_type (str): Type of the model (e.g., 'transformer', 'gpt').
        model_name (str): Name of the model configuration file.
        model_credential (dict): Credentials required to access the model.
        model_params (dict, optional): Additional parameters for model instantiation.

    Returns:
        bool: True if valid, otherwise False.

    Raises:
        ValueError: If model type is invalid.
        KeyError: If any expected parameter is missing from model_credentials.
        TypeError: If wrong data types are used for parameters.
    """
    if "name" not in model_credential or "credential" not in model_credential:
        raise KeyError("Model credentials must contain at least 'name' and 'credential'.")

    try:
        model = provider.get_model(model_type, model_name, model_credential, **(model_params or {}))
        model.check_auth()
        return True

    except ValueError as ve:
        print(f"Invalid input value: {ve}")
        return False

    except TypeError as te:
        print(f"Type mismatch: {te}")
        return False

    except Exception as e:
        print(f"An error occurred: {e}")
        print(traceback.format_exc())  # Consider logging instead
        return False

Key Changes:

  • Added self as the first argument to indicate this is an instance method.
  • Fixed the typo in modelParams to model_params.
  • Ensured model_params is passed with a check model_params or {} to avoid a KeyError.
  • Improved exception handling by catching specific types and providing user-friendly messages.
  • Updated the docstring to include descriptions and parameters clearly.

Expand Down
Loading