1+ # coding=utf-8
2+ import json
3+ import re
4+ import time
5+ from typing import List , Dict , Any
6+ from functools import reduce
7+
8+ from django .db .models import QuerySet
9+ from langchain .schema import HumanMessage , SystemMessage
10+
11+ from application .flow .i_step_node import INode , NodeResult
12+ from application .flow .step_node .intent_node .i_intent_node import IIntentNode
13+ from models_provider .models import Model
14+ from models_provider .tools import get_model_instance_by_model_workspace_id , get_model_credential
15+ from .prompt_template import PROMPT_TEMPLATE
16+
17+ def get_default_model_params_setting (model_id ):
18+
19+ model = QuerySet (Model ).filter (id = model_id ).first ()
20+ credential = get_model_credential (model .provider , model .model_type , model .model_name )
21+ model_params_setting = credential .get_model_params_setting_form (
22+ model .model_name ).get_default_form_data ()
23+ return model_params_setting
24+
25+
26+ def _write_context (node_variable : Dict , workflow_variable : Dict , node : INode , workflow , answer : str ):
27+
28+ chat_model = node_variable .get ('chat_model' )
29+ message_tokens = chat_model .get_num_tokens_from_messages (node_variable .get ('message_list' ))
30+ answer_tokens = chat_model .get_num_tokens (answer )
31+
32+ node .context ['message_tokens' ] = message_tokens
33+ node .context ['answer_tokens' ] = answer_tokens
34+ node .context ['answer' ] = answer
35+ node .context ['history_message' ] = node_variable ['history_message' ]
36+ node .context ['user_input' ] = node_variable ['user_input' ]
37+ node .context ['branch_id' ] = node_variable .get ('branch_id' )
38+ node .context ['reason' ] = node_variable .get ('reason' )
39+ node .context ['category' ] = node_variable .get ('category' )
40+ node .context ['run_time' ] = time .time () - node .context ['start_time' ]
41+
42+
43+ def write_context (node_variable : Dict , workflow_variable : Dict , node : INode , workflow ):
44+
45+ response = node_variable .get ('result' )
46+ answer = response .content
47+ _write_context (node_variable , workflow_variable , node , workflow , answer )
48+
49+
50+ class BaseIntentNode (IIntentNode ):
51+
52+
53+ def save_context (self , details , workflow_manage ):
54+
55+ self .context ['branch_id' ] = details .get ('branch_id' )
56+ self .context ['category' ] = details .get ('category' )
57+
58+
59+ def execute (self , model_id , dialogue_number , history_chat_record , user_input , branch ,
60+ model_params_setting = None , ** kwargs ) -> NodeResult :
61+
62+ # 设置默认模型参数
63+ if model_params_setting is None :
64+ model_params_setting = get_default_model_params_setting (model_id )
65+
66+ # 获取模型实例
67+ workspace_id = self .workflow_manage .get_body ().get ('workspace_id' )
68+ chat_model = get_model_instance_by_model_workspace_id (
69+ model_id , workspace_id , ** model_params_setting
70+ )
71+
72+ # 获取历史对话
73+ history_message = self .get_history_message (history_chat_record , dialogue_number )
74+ self .context ['history_message' ] = history_message
75+
76+ # 保存问题到上下文
77+ self .context ['user_input' ] = user_input
78+
79+ # 构建分类提示词
80+ prompt = self .build_classification_prompt (user_input , branch )
81+
82+
83+ # 生成消息列表
84+ system = self .build_system_prompt ()
85+ message_list = self .generate_message_list (system , prompt , history_message )
86+ self .context ['message_list' ] = message_list
87+
88+ # 调用模型进行分类
89+ try :
90+ r = chat_model .invoke (message_list )
91+ classification_result = r .content .strip ()
92+
93+ # 解析分类结果获取分支信息
94+ matched_branch = self .parse_classification_result (classification_result , branch )
95+
96+ # 返回结果
97+ return NodeResult ({
98+ 'result' : r ,
99+ 'chat_model' : chat_model ,
100+ 'message_list' : message_list ,
101+ 'history_message' : history_message ,
102+ 'user_input' : user_input ,
103+ 'branch_id' : matched_branch ['id' ],
104+ 'reason' : json .loads (r .content ).get ('reason' ),
105+ 'category' : matched_branch .get ('content' , matched_branch ['id' ])
106+ }, {}, _write_context = write_context )
107+
108+ except Exception as e :
109+ # 错误处理:返回"其他"分支
110+ other_branch = self .find_other_branch (branch )
111+ if other_branch :
112+ return NodeResult ({
113+ 'branch_id' : other_branch ['id' ],
114+ 'category' : other_branch .get ('content' , other_branch ['id' ]),
115+ 'error' : str (e )
116+ }, {})
117+ else :
118+ raise Exception (f"error: { str (e )} " )
119+
120+ @staticmethod
121+ def get_history_message (history_chat_record , dialogue_number ):
122+ """获取历史消息"""
123+ start_index = len (history_chat_record ) - dialogue_number
124+ history_message = reduce (lambda x , y : [* x , * y ], [
125+ [history_chat_record [index ].get_human_message (), history_chat_record [index ].get_ai_message ()]
126+ for index in
127+ range (start_index if start_index > 0 else 0 , len (history_chat_record ))], [])
128+
129+ for message in history_message :
130+ if isinstance (message .content , str ):
131+ message .content = re .sub ('<form_rander>[\d\D]*?<\/form_rander>' , '' , message .content )
132+ return history_message
133+
134+
135+ def build_system_prompt (self ) -> str :
136+ """构建系统提示词"""
137+ return "你是一个专业的意图识别助手,请根据用户输入和意图选项,准确识别用户的真实意图。"
138+
139+ def build_classification_prompt (self , user_input : str , branch : List [Dict ]) -> str :
140+ """构建分类提示词"""
141+
142+ classification_list = []
143+
144+ other_branch = self .find_other_branch (branch )
145+ # 添加其他分支
146+ if other_branch :
147+ classification_list .append ({
148+ "classificationId" : 0 ,
149+ "content" : other_branch .get ('content' )
150+ })
151+ # 添加正常分支
152+ classification_id = 1
153+ for b in branch :
154+ if not b .get ('isOther' ):
155+ classification_list .append ({
156+ "classificationId" : classification_id ,
157+ "content" : b ['content' ]
158+ })
159+ classification_id += 1
160+
161+ return PROMPT_TEMPLATE .format (
162+ classification_list = classification_list ,
163+ user_input = user_input
164+ )
165+
166+
167+ def generate_message_list (self , system : str , prompt : str , history_message ):
168+ """生成消息列表"""
169+ if system is None or len (system ) == 0 :
170+ return [* history_message , HumanMessage (self .workflow_manage .generate_prompt (prompt ))]
171+ else :
172+ return [SystemMessage (self .workflow_manage .generate_prompt (system )), * history_message ,
173+ HumanMessage (self .workflow_manage .generate_prompt (prompt ))]
174+
175+ def parse_classification_result (self , result : str , branch : List [Dict ]) -> Dict [str , Any ]:
176+ """解析分类结果"""
177+
178+ other_branch = self .find_other_branch (branch )
179+ normal_intents = [
180+ b
181+ for b in branch
182+ if not b .get ('isOther' )
183+ ]
184+
185+ def get_branch_by_id (category_id : int ):
186+ if category_id == 0 :
187+ return other_branch
188+ elif 1 <= category_id <= len (normal_intents ):
189+ return normal_intents [category_id - 1 ]
190+ return None
191+
192+ try :
193+ result_json = json .loads (result )
194+ classification_id = result_json .get ('classificationId' , 0 ) # 0 兜底
195+ # 如果是 0 ,返回其他分支
196+ matched_branch = get_branch_by_id (classification_id )
197+ if matched_branch :
198+ return matched_branch
199+
200+ except Exception as e :
201+ # json 解析失败,re 提取
202+ numbers = re .findall (r'"classificationId":\s*(\d+)' , result )
203+ if numbers :
204+ classification_id = int (numbers [0 ])
205+
206+ matched_branch = get_branch_by_id (classification_id )
207+ if matched_branch :
208+ return matched_branch
209+
210+ # 如果都解析失败,返回“other”
211+ return other_branch or (normal_intents [0 ] if normal_intents else {'id' : 'unknown' , 'content' : 'unknown' })
212+
213+
214+ def find_other_branch (self , branch : List [Dict ]) -> Dict [str , Any ] | None :
215+ """查找其他分支"""
216+ for b in branch :
217+ if b .get ('isOther' ):
218+ return b
219+ return None
220+
221+
222+ def get_details (self , index : int , ** kwargs ):
223+ """获取节点执行详情"""
224+ return {
225+ 'name' : self .node .properties .get ('stepName' ),
226+ 'index' : index ,
227+ 'run_time' : self .context .get ('run_time' ),
228+ 'system' : self .context .get ('system' ),
229+ 'history_message' : [
230+ {'content' : message .content , 'role' : message .type }
231+ for message in (self .context .get ('history_message' ) or [])
232+ ],
233+ 'user_input' : self .context .get ('user_input' ),
234+ 'answer' : self .context .get ('answer' ),
235+ 'branch_id' : self .context .get ('branch_id' ),
236+ 'category' : self .context .get ('category' ),
237+ 'type' : self .node .type ,
238+ 'message_tokens' : self .context .get ('message_tokens' ),
239+ 'answer_tokens' : self .context .get ('answer_tokens' ),
240+ 'status' : self .status ,
241+ 'err_message' : self .err_message
242+ }
0 commit comments