Skip to content

Commit f511422

Browse files
committed
feat: Extract reason using regex
1 parent 584850b commit f511422

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

apps/application/flow/step_node/intent_node/impl/base_intent_node.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def execute(self, model_id, dialogue_number, history_chat_record, user_input, br
8989
try:
9090
r = chat_model.invoke(message_list)
9191
classification_result = r.content.strip()
92-
9392
# 解析分类结果获取分支信息
9493
matched_branch = self.parse_classification_result(classification_result, branch)
9594

@@ -101,7 +100,7 @@ def execute(self, model_id, dialogue_number, history_chat_record, user_input, br
101100
'history_message': history_message,
102101
'user_input': user_input,
103102
'branch_id': matched_branch['id'],
104-
'reason': json.loads(r.content).get('reason'),
103+
'reason': self.parse_result_reason(r.content),
105104
'category': matched_branch.get('content', matched_branch['id'])
106105
}, {}, _write_context=write_context)
107106

@@ -191,7 +190,7 @@ def get_branch_by_id(category_id: int):
191190

192191
try:
193192
result_json = json.loads(result)
194-
classification_id = result_json.get('classificationId', 0) # 0 兜底
193+
classification_id = result_json.get('classificationId')
195194
# 如果是 0 ,返回其他分支
196195
matched_branch = get_branch_by_id(classification_id)
197196
if matched_branch:
@@ -210,6 +209,26 @@ def get_branch_by_id(category_id: int):
210209
# 如果都解析失败,返回“other”
211210
return other_branch or (normal_intents[0] if normal_intents else {'id': 'unknown', 'content': 'unknown'})
212211

212+
def parse_result_reason(self, result: str):
213+
"""解析分类的原因"""
214+
try:
215+
result_json = json.loads(result)
216+
return result_json.get('reason', '')
217+
except Exception as e:
218+
reason_patterns = [
219+
r'"reason":\s*"([^"]*)"', # 标准格式
220+
r'"reason":\s*"([^"]*)', # 缺少结束引号
221+
r'"reason":\s*([^,}\n]*)', # 没有引号包围的内容
222+
]
223+
for pattern in reason_patterns:
224+
match = re.search(pattern, result, re.DOTALL)
225+
if match:
226+
reason = match.group(1).strip()
227+
# 清理可能的尾部字符
228+
reason = re.sub(r'["\s]*$', '', reason)
229+
return reason
230+
231+
return ''
213232

214233
def find_other_branch(self, branch: List[Dict]) -> Dict[str, Any] | None:
215234
"""查找其他分支"""

0 commit comments

Comments
 (0)