Skip to content

Commit ee9e963

Browse files
committed
feat: Intent classify
1 parent bd668e7 commit ee9e963

File tree

18 files changed

+861
-1
lines changed

18 files changed

+861
-1
lines changed

apps/application/flow/step_node/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,14 @@
2424
from .tool_lib_node import *
2525
from .tool_node import *
2626
from .variable_assign_node import BaseVariableAssignNode
27+
from .intent_node import *
2728

2829
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchKnowledgeNode, BaseQuestionNode,
2930
BaseConditionNode, BaseReplyNode,
3031
BaseToolNodeNode, BaseToolLibNodeNode, BaseRerankerNode, BaseApplicationNode,
3132
BaseDocumentExtractNode,
3233
BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode,
33-
BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode]
34+
BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode,BaseIntentNode]
3435

3536

3637
def get_node(node_type):
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# coding=utf-8
2+
3+
4+
5+
6+
from .impl import *
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# coding=utf-8
2+
3+
from typing import Type
4+
5+
from django.utils.translation import gettext_lazy as _
6+
from rest_framework import serializers
7+
8+
from application.flow.i_step_node import INode, NodeResult
9+
10+
11+
class IntentBranchSerializer(serializers.Serializer):
12+
13+
id = serializers.CharField(required=True, label=_("Branch id"))
14+
content = serializers.CharField(required=True, label=_("content"))
15+
isOther = serializers.BooleanField(required=True, label=_("Branch Type"))
16+
17+
18+
class IntentNodeSerializer(serializers.Serializer):
19+
model_id = serializers.CharField(required=True, label=_("Model id"))
20+
content_list = serializers.ListField(required=True, label=_("Text content"))
21+
dialogue_number = serializers.IntegerField(required=True, label=
22+
_("Number of multi-round conversations"))
23+
model_params_setting = serializers.DictField(required=False,
24+
label=_("Model parameter settings"))
25+
branch = IntentBranchSerializer(many=True)
26+
27+
class IIntentNode(INode):
28+
type = 'intent-node'
29+
def save_context(self, details, workflow_manage):
30+
pass
31+
32+
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
33+
return IntentNodeSerializer
34+
35+
def _run(self):
36+
question = self.workflow_manage.get_reference_field(
37+
self.node_params_serializer.data.get('content_list')[0],
38+
self.node_params_serializer.data.get('content_list')[1:],
39+
)
40+
41+
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data, user_input=str(question))
42+
43+
44+
def execute(self, model_id, dialogue_number, history_chat_record, user_input, branch,
45+
model_params_setting=None, **kwargs) -> NodeResult:
46+
pass
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
2+
3+
from .base_intent_node import BaseIntentNode
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
# coding=utf-8
2+
import json
3+
import re
4+
import time
5+
from typing import List, Dict, Any
6+
from functools import reduce
7+
8+
from django.db.models import QuerySet
9+
from langchain.schema import HumanMessage, SystemMessage
10+
11+
from application.flow.i_step_node import INode, NodeResult
12+
from application.flow.step_node.intent_node.i_intent_node import IIntentNode
13+
from models_provider.models import Model
14+
from models_provider.tools import get_model_instance_by_model_workspace_id, get_model_credential
15+
from .prompt_template import PROMPT_TEMPLATE
16+
17+
def get_default_model_params_setting(model_id):
18+
19+
model = QuerySet(Model).filter(id=model_id).first()
20+
credential = get_model_credential(model.provider, model.model_type, model.model_name)
21+
model_params_setting = credential.get_model_params_setting_form(
22+
model.model_name).get_default_form_data()
23+
return model_params_setting
24+
25+
26+
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
27+
28+
chat_model = node_variable.get('chat_model')
29+
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
30+
answer_tokens = chat_model.get_num_tokens(answer)
31+
32+
node.context['message_tokens'] = message_tokens
33+
node.context['answer_tokens'] = answer_tokens
34+
node.context['answer'] = answer
35+
node.context['history_message'] = node_variable['history_message']
36+
node.context['user_input'] = node_variable['user_input']
37+
node.context['branch_id'] = node_variable.get('branch_id')
38+
node.context['reason'] = node_variable.get('reason')
39+
node.context['category'] = node_variable.get('category')
40+
node.context['run_time'] = time.time() - node.context['start_time']
41+
42+
43+
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
44+
45+
response = node_variable.get('result')
46+
answer = response.content
47+
_write_context(node_variable, workflow_variable, node, workflow, answer)
48+
49+
50+
class BaseIntentNode(IIntentNode):
51+
52+
53+
def save_context(self, details, workflow_manage):
54+
55+
self.context['branch_id'] = details.get('branch_id')
56+
self.context['category'] = details.get('category')
57+
58+
59+
def execute(self, model_id, dialogue_number, history_chat_record, user_input, branch,
60+
model_params_setting=None, **kwargs) -> NodeResult:
61+
62+
# 设置默认模型参数
63+
if model_params_setting is None:
64+
model_params_setting = get_default_model_params_setting(model_id)
65+
66+
# 获取模型实例
67+
workspace_id = self.workflow_manage.get_body().get('workspace_id')
68+
chat_model = get_model_instance_by_model_workspace_id(
69+
model_id, workspace_id, **model_params_setting
70+
)
71+
72+
# 获取历史对话
73+
history_message = self.get_history_message(history_chat_record, dialogue_number)
74+
self.context['history_message'] = history_message
75+
76+
# 保存问题到上下文
77+
self.context['user_input'] = user_input
78+
79+
# 构建分类提示词
80+
prompt = self.build_classification_prompt(user_input, branch)
81+
82+
83+
# 生成消息列表
84+
system = self.build_system_prompt()
85+
message_list = self.generate_message_list(system, prompt, history_message)
86+
self.context['message_list'] = message_list
87+
88+
# 调用模型进行分类
89+
try:
90+
r = chat_model.invoke(message_list)
91+
classification_result = r.content.strip()
92+
93+
# 解析分类结果获取分支信息
94+
matched_branch = self.parse_classification_result(classification_result, branch)
95+
96+
# 返回结果
97+
return NodeResult({
98+
'result': r,
99+
'chat_model': chat_model,
100+
'message_list': message_list,
101+
'history_message': history_message,
102+
'user_input': user_input,
103+
'branch_id': matched_branch['id'],
104+
'reason': json.loads(r.content).get('reason'),
105+
'category': matched_branch.get('content', matched_branch['id'])
106+
}, {}, _write_context=write_context)
107+
108+
except Exception as e:
109+
# 错误处理:返回"其他"分支
110+
other_branch = self.find_other_branch(branch)
111+
if other_branch:
112+
return NodeResult({
113+
'branch_id': other_branch['id'],
114+
'category': other_branch.get('content', other_branch['id']),
115+
'error': str(e)
116+
}, {})
117+
else:
118+
raise Exception(f"error: {str(e)}")
119+
120+
@staticmethod
121+
def get_history_message(history_chat_record, dialogue_number):
122+
"""获取历史消息"""
123+
start_index = len(history_chat_record) - dialogue_number
124+
history_message = reduce(lambda x, y: [*x, *y], [
125+
[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
126+
for index in
127+
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
128+
129+
for message in history_message:
130+
if isinstance(message.content, str):
131+
message.content = re.sub('<form_rander>[\d\D]*?<\/form_rander>', '', message.content)
132+
return history_message
133+
134+
135+
def build_system_prompt(self) -> str:
136+
"""构建系统提示词"""
137+
return "你是一个专业的意图识别助手,请根据用户输入和意图选项,准确识别用户的真实意图。"
138+
139+
def build_classification_prompt(self, user_input: str, branch: List[Dict]) -> str:
140+
"""构建分类提示词"""
141+
142+
classification_list = []
143+
144+
other_branch = self.find_other_branch(branch)
145+
# 添加其他分支
146+
if other_branch:
147+
classification_list.append({
148+
"classificationId": 0,
149+
"content": other_branch.get('content')
150+
})
151+
# 添加正常分支
152+
classification_id = 1
153+
for b in branch:
154+
if not b.get('isOther'):
155+
classification_list.append({
156+
"classificationId": classification_id,
157+
"content": b['content']
158+
})
159+
classification_id += 1
160+
161+
return PROMPT_TEMPLATE.format(
162+
classification_list=classification_list,
163+
user_input=user_input
164+
)
165+
166+
167+
def generate_message_list(self, system: str, prompt: str, history_message):
168+
"""生成消息列表"""
169+
if system is None or len(system) == 0:
170+
return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]
171+
else:
172+
return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
173+
HumanMessage(self.workflow_manage.generate_prompt(prompt))]
174+
175+
def parse_classification_result(self, result: str, branch: List[Dict]) -> Dict[str, Any]:
176+
"""解析分类结果"""
177+
178+
other_branch = self.find_other_branch(branch)
179+
normal_intents = [
180+
b
181+
for b in branch
182+
if not b.get('isOther')
183+
]
184+
185+
def get_branch_by_id(category_id: int):
186+
if category_id == 0:
187+
return other_branch
188+
elif 1 <= category_id <= len(normal_intents):
189+
return normal_intents[category_id - 1]
190+
return None
191+
192+
try:
193+
result_json = json.loads(result)
194+
classification_id = result_json.get('classificationId', 0) # 0 兜底
195+
# 如果是 0 ,返回其他分支
196+
matched_branch = get_branch_by_id(classification_id)
197+
if matched_branch:
198+
return matched_branch
199+
200+
except Exception as e:
201+
# json 解析失败,re 提取
202+
numbers = re.findall(r'"classificationId":\s*(\d+)', result)
203+
if numbers:
204+
classification_id = int(numbers[0])
205+
206+
matched_branch = get_branch_by_id(classification_id)
207+
if matched_branch:
208+
return matched_branch
209+
210+
# 如果都解析失败,返回“other”
211+
return other_branch or (normal_intents[0] if normal_intents else {'id': 'unknown', 'content': 'unknown'})
212+
213+
214+
def find_other_branch(self, branch: List[Dict]) -> Dict[str, Any] | None:
215+
"""查找其他分支"""
216+
for b in branch:
217+
if b.get('isOther'):
218+
return b
219+
return None
220+
221+
222+
def get_details(self, index: int, **kwargs):
223+
"""获取节点执行详情"""
224+
return {
225+
'name': self.node.properties.get('stepName'),
226+
'index': index,
227+
'run_time': self.context.get('run_time'),
228+
'system': self.context.get('system'),
229+
'history_message': [
230+
{'content': message.content, 'role': message.type}
231+
for message in (self.context.get('history_message') or [])
232+
],
233+
'user_input': self.context.get('user_input'),
234+
'answer': self.context.get('answer'),
235+
'branch_id': self.context.get('branch_id'),
236+
'category': self.context.get('category'),
237+
'type': self.node.type,
238+
'message_tokens': self.context.get('message_tokens'),
239+
'answer_tokens': self.context.get('answer_tokens'),
240+
'status': self.status,
241+
'err_message': self.err_message
242+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
2+
3+
4+
5+
PROMPT_TEMPLATE = """# Role
6+
You are an intention classification expert, good at being able to judge which classification the user's input belongs to.
7+
8+
## Skills
9+
Skill 1: Clearly determine which of the following intention classifications the user's input belongs to.
10+
Intention classification list:
11+
{classification_list}
12+
13+
Note:
14+
- Please determine the match only between the user's input content and the Intention classification list content, without judging or categorizing the match with the classification ID.
15+
16+
## User Input
17+
{user_input}
18+
19+
## Reply requirements
20+
- The answer must be returned in JSON format.
21+
- Strictly ensure that the output is in a valid JSON format.
22+
- Do not add prefix ```json or suffix ```
23+
- The answer needs to include the following fields such as:
24+
{{
25+
"classificationId": 0,
26+
"reason": ""
27+
}}
28+
29+
## Limit
30+
- Please do not reply in text."""
Lines changed: 18 additions & 0 deletions
Loading

ui/src/enums/application.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@ export enum WorkflowType {
2424
SpeechToTextNode = 'speech-to-text-node',
2525
ImageGenerateNode = 'image-generate-node',
2626
McpNode = 'mcp-node',
27+
IntentNode = 'intent-node',
2728
}

ui/src/locales/lang/en-US/common.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ export default {
4747
noData: 'No data',
4848
result: 'Result',
4949
remove: 'Remove',
50+
classify: 'Classify',
51+
reason: 'Reason',
5052
removeSuccess: 'Successful',
5153
searchBar: {
5254
placeholder: 'Search by name',

0 commit comments

Comments
 (0)