feat: Intent classify #4026
Changes from all commits
@@ -0,0 +1,6 @@

```python
# coding=utf-8

from .impl import *
```
@@ -0,0 +1,46 @@

```python
# coding=utf-8
from typing import Type

from django.utils.translation import gettext_lazy as _
from rest_framework import serializers

from application.flow.i_step_node import INode, NodeResult


class IntentBranchSerializer(serializers.Serializer):
    id = serializers.CharField(required=True, label=_("Branch id"))
    content = serializers.CharField(required=True, label=_("content"))
    isOther = serializers.BooleanField(required=True, label=_("Branch Type"))


class IntentNodeSerializer(serializers.Serializer):
    model_id = serializers.CharField(required=True, label=_("Model id"))
    content_list = serializers.ListField(required=True, label=_("Text content"))
    dialogue_number = serializers.IntegerField(required=True,
                                               label=_("Number of multi-round conversations"))
    model_params_setting = serializers.DictField(required=False,
                                                 label=_("Model parameter settings"))
    branch = IntentBranchSerializer(many=True)


class IIntentNode(INode):
    type = 'intent-node'

    def save_context(self, details, workflow_manage):
        pass

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        return IntentNodeSerializer

    def _run(self):
        question = self.workflow_manage.get_reference_field(
            self.node_params_serializer.data.get('content_list')[0],
            self.node_params_serializer.data.get('content_list')[1:],
        )
        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data,
                            user_input=str(question))

    def execute(self, model_id, dialogue_number, history_chat_record, user_input, branch,
                model_params_setting=None, **kwargs) -> NodeResult:
        pass
```
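For orientation, the sketch below shows a hypothetical node configuration that `IntentNodeSerializer` would accept. The field values are invented, and running it assumes a configured Django/DRF environment in which the serializer above is importable.

```python
# Hypothetical node configuration (illustrative values only).
node_params = {
    "model_id": "model-123",                      # id of the LLM used for classification
    "content_list": ["start-node", "question"],   # reference to the field holding the user question
    "dialogue_number": 3,                         # rounds of history passed to the model
    "model_params_setting": {"temperature": 0.3},
    "branch": [
        {"id": "branch-1", "content": "Ask about pricing", "isOther": False},
        {"id": "branch-other", "content": "Other", "isOther": True},
    ],
}

serializer = IntentNodeSerializer(data=node_params)
serializer.is_valid(raise_exception=True)
```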
@@ -0,0 +1,3 @@

```python
from .base_intent_node import BaseIntentNode
```
@@ -0,0 +1,242 @@

```python
# coding=utf-8
import json
import re
import time
from typing import List, Dict, Any
from functools import reduce

from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage

from application.flow.i_step_node import INode, NodeResult
from application.flow.step_node.intent_node.i_intent_node import IIntentNode
from models_provider.models import Model
from models_provider.tools import get_model_instance_by_model_workspace_id, get_model_credential
from .prompt_template import PROMPT_TEMPLATE


def get_default_model_params_setting(model_id):
    """Load the default parameter settings for the model identified by model_id."""
    model = QuerySet(Model).filter(id=model_id).first()
    credential = get_model_credential(model.provider, model.model_type, model.model_name)
    model_params_setting = credential.get_model_params_setting_form(
        model.model_name).get_default_form_data()
    return model_params_setting


def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
    """Write the classification result and token usage into the node context."""
    chat_model = node_variable.get('chat_model')
    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
    answer_tokens = chat_model.get_num_tokens(answer)

    node.context['message_tokens'] = message_tokens
    node.context['answer_tokens'] = answer_tokens
    node.context['answer'] = answer
    node.context['history_message'] = node_variable['history_message']
    node.context['user_input'] = node_variable['user_input']
    node.context['branch_id'] = node_variable.get('branch_id')
    node.context['reason'] = node_variable.get('reason')
    node.context['category'] = node_variable.get('category')
    node.context['run_time'] = time.time() - node.context['start_time']


def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
    """Callback passed to NodeResult: extract the answer and delegate to _write_context."""
    response = node_variable.get('result')
    answer = response.content
    _write_context(node_variable, workflow_variable, node, workflow, answer)


class BaseIntentNode(IIntentNode):

    def save_context(self, details, workflow_manage):
        self.context['branch_id'] = details.get('branch_id')
        self.context['category'] = details.get('category')

    def execute(self, model_id, dialogue_number, history_chat_record, user_input, branch,
                model_params_setting=None, **kwargs) -> NodeResult:
        # Fall back to the model's default parameter settings
        if model_params_setting is None:
            model_params_setting = get_default_model_params_setting(model_id)

        # Resolve the chat model instance for the current workspace
        workspace_id = self.workflow_manage.get_body().get('workspace_id')
        chat_model = get_model_instance_by_model_workspace_id(
            model_id, workspace_id, **model_params_setting
        )

        # Collect the recent conversation history
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        self.context['history_message'] = history_message

        # Store the user question in the node context
        self.context['user_input'] = user_input

        # Build the classification prompt
        prompt = self.build_classification_prompt(user_input, branch)

        # Assemble the message list
        system = self.build_system_prompt()
        message_list = self.generate_message_list(system, prompt, history_message)
        self.context['message_list'] = message_list

        # Call the model to classify the input
        try:
            r = chat_model.invoke(message_list)
            classification_result = r.content.strip()

            # Parse the classification result and resolve the matched branch
            matched_branch = self.parse_classification_result(classification_result, branch)

            # Return the result
            return NodeResult({
                'result': r,
                'chat_model': chat_model,
                'message_list': message_list,
                'history_message': history_message,
                'user_input': user_input,
                'branch_id': matched_branch['id'],
                'reason': json.loads(r.content).get('reason'),
                'category': matched_branch.get('content', matched_branch['id'])
            }, {}, _write_context=write_context)

        except Exception as e:
            # Error handling: fall back to the "other" branch
            other_branch = self.find_other_branch(branch)
            if other_branch:
                return NodeResult({
                    'branch_id': other_branch['id'],
                    'category': other_branch.get('content', other_branch['id']),
                    'error': str(e)
                }, {})
            else:
                raise Exception(f"error: {str(e)}")

    @staticmethod
    def get_history_message(history_chat_record, dialogue_number):
        """Return the last `dialogue_number` rounds of history as a flat message list."""
        start_index = len(history_chat_record) - dialogue_number
        history_message = reduce(lambda x, y: [*x, *y], [
            [history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
            for index in
            range(start_index if start_index > 0 else 0, len(history_chat_record))], [])

        for message in history_message:
            if isinstance(message.content, str):
                message.content = re.sub(r'<form_rander>[\d\D]*?</form_rander>', '', message.content)
        return history_message

    def build_system_prompt(self) -> str:
        """Build the system prompt.

        The prompt string is kept in Chinese; it reads roughly: "You are a
        professional intent-recognition assistant. Based on the user's input and
        the intent options, accurately identify the user's real intent."
        """
        return "你是一个专业的意图识别助手,请根据用户输入和意图选项,准确识别用户的真实意图。"

    def build_classification_prompt(self, user_input: str, branch: List[Dict]) -> str:
        """Build the classification prompt from the configured branches."""
        classification_list = []

        other_branch = self.find_other_branch(branch)
        # The "other" branch always takes classification id 0
        if other_branch:
            classification_list.append({
                "classificationId": 0,
                "content": other_branch.get('content')
            })
        # Normal branches are numbered from 1 in declaration order
        classification_id = 1
        for b in branch:
            if not b.get('isOther'):
                classification_list.append({
                    "classificationId": classification_id,
                    "content": b['content']
                })
                classification_id += 1

        return PROMPT_TEMPLATE.format(
            classification_list=classification_list,
            user_input=user_input
        )

    def generate_message_list(self, system: str, prompt: str, history_message):
        """Build the LangChain message list: system prompt, history, then the question."""
        if system is None or len(system) == 0:
            return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]
        else:
            return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
                    HumanMessage(self.workflow_manage.generate_prompt(prompt))]

    def parse_classification_result(self, result: str, branch: List[Dict]) -> Dict[str, Any]:
        """Parse the model output and map the classification id back to a branch."""
        other_branch = self.find_other_branch(branch)
        normal_intents = [
            b
            for b in branch
            if not b.get('isOther')
        ]

        def get_branch_by_id(category_id: int):
            if category_id == 0:
                return other_branch
            elif 1 <= category_id <= len(normal_intents):
                return normal_intents[category_id - 1]
            return None

        try:
            result_json = json.loads(result)
            classification_id = result_json.get('classificationId', 0)  # default to 0 ("other")
            # Id 0 resolves to the "other" branch
            matched_branch = get_branch_by_id(classification_id)
            if matched_branch:
                return matched_branch

        except Exception:
            # JSON parsing failed: try to extract the id with a regular expression
            numbers = re.findall(r'"classificationId":\s*(\d+)', result)
            if numbers:
                classification_id = int(numbers[0])
                matched_branch = get_branch_by_id(classification_id)
                if matched_branch:
                    return matched_branch

        # If everything above fails, fall back to the "other" branch
        return other_branch or (normal_intents[0] if normal_intents else {'id': 'unknown', 'content': 'unknown'})

    def find_other_branch(self, branch: List[Dict]) -> Dict[str, Any] | None:
        """Return the branch flagged as "other", if one is configured."""
        for b in branch:
            if b.get('isOther'):
                return b
        return None

    def get_details(self, index: int, **kwargs):
        """Collect execution details for the run log."""
        return {
            'name': self.node.properties.get('stepName'),
            'index': index,
            'run_time': self.context.get('run_time'),
            'system': self.context.get('system'),
            'history_message': [
                {'content': message.content, 'role': message.type}
                for message in (self.context.get('history_message') or [])
            ],
            'user_input': self.context.get('user_input'),
            'answer': self.context.get('answer'),
            'branch_id': self.context.get('branch_id'),
            'category': self.context.get('category'),
            'type': self.node.type,
            'message_tokens': self.context.get('message_tokens'),
            'answer_tokens': self.context.get('answer_tokens'),
            'status': self.status,
            'err_message': self.err_message
        }
```
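To make the branch mapping in `parse_classification_result` concrete, here is a small standalone sketch with invented branch data. It mirrors the convention used above rather than calling the node class: `classificationId` 0 selects the `isOther` branch, and ids 1..N select the non-other branches in declaration order.

```python
import json

# Invented branch configuration, in the same shape the node receives.
branches = [
    {"id": "b1", "content": "Billing question", "isOther": False},
    {"id": "b2", "content": "Technical support", "isOther": False},
    {"id": "b_other", "content": "Other", "isOther": True},
]
other = next((b for b in branches if b.get("isOther")), None)
normal = [b for b in branches if not b.get("isOther")]

# A well-formed model reply and the id-to-branch lookup used above.
model_reply = '{"classificationId": 2, "reason": "the user reports a technical problem"}'
cid = json.loads(model_reply).get("classificationId", 0)
matched = other if cid == 0 else (normal[cid - 1] if 1 <= cid <= len(normal) else other)
print(matched["id"])  # -> "b2"
```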
@@ -0,0 +1,30 @@

````python
PROMPT_TEMPLATE = """# Role
You are an intention classification expert, good at being able to judge which classification the user's input belongs to.

## Skills
Skill 1: Clearly determine which of the following intention classifications the user's input belongs to.
Intention classification list:
{classification_list}

Note:
- Please determine the match only between the user's input content and the Intention classification list content, without judging or categorizing the match with the classification ID.

## User Input
{user_input}

## Reply requirements
- The answer must be returned in JSON format.
- Strictly ensure that the output is in a valid JSON format.
- Do not add prefix ```json or suffix ```
- The answer needs to include the following fields such as:
{{
    "classificationId": 0,
    "reason": ""
}}

## Limit
- Please do not reply in text."""
````
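As a rough illustration of how the template is consumed, the sketch below renders it with an invented classification list and question. The import path is an assumption and should point at wherever this module lives in the package.

```python
# Adjust the import to the actual package path of this module.
from prompt_template import PROMPT_TEMPLATE

classification_list = [
    {"classificationId": 0, "content": "Other"},
    {"classificationId": 1, "content": "Ask about pricing"},
]
prompt = PROMPT_TEMPLATE.format(
    classification_list=classification_list,
    user_input="How much does the enterprise plan cost?",
)
print(prompt)
# A compliant model reply is bare JSON, for example:
# {"classificationId": 1, "reason": "The user is asking about plan pricing."}
```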
The provided code has several areas that could be improved to enhance its structure, readability, and maintainability:

1. Imports and comments: The code includes a `# noqa` comment at the end of the file to suppress flake8 errors related to missing newline characters. This is not necessary in modern Python.
2. Comments: Some comments are redundant or misspelled and should be revised for clarity and accuracy. For example, `model_params_setting` could be renamed to `model_params`.
3. Function documentation: While some functions have docstrings, they are minimal. Adding more detailed descriptions would help others understand their purpose.
4. Class implementations: The class methods need actual implementations for logic such as `save_context`, `get_node_params_serializer_class`, `_run`, and `execute`. These depend on how you plan to handle these functionalities within your application.
5. Error handling: If the implementation needs to handle exceptions or specific conditions, appropriate error messages or exception handling should be added.

Here's an updated version of the code with improved comments, function documentation, and potential implementation suggestions:

Key changes:
- Asynchronous operations (`async`) where applicable.
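A minimal sketch of the error-handling suggestion (not the reviewer's snippet; it reuses only helpers already imported in the PR) might look like this:

```python
# Illustrative only: fail loudly when the model id does not resolve, instead of
# raising AttributeError on a None model further down.
from django.db.models import QuerySet

from models_provider.models import Model
from models_provider.tools import get_model_credential


def get_default_model_params_setting(model_id):
    """Return the default parameter form data for the model identified by `model_id`."""
    model = QuerySet(Model).filter(id=model_id).first()
    if model is None:
        raise ValueError(f"Model {model_id} does not exist or is not accessible")
    credential = get_model_credential(model.provider, model.model_type, model.model_name)
    return credential.get_model_params_setting_form(model.model_name).get_default_form_data()
```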