11from collections .abc import Generator
22from typing import Any
3+ import logging
34
45from dify_plugin import Tool
56from dify_plugin .entities .tool import ToolInvokeMessage
67from openguardrails import OpenGuardrails
78
9+ logger = logging .getLogger (__name__ )
10+
811class CheckPromptTool (Tool ):
912 def _invoke (self , tool_parameters : dict [str , Any ]) -> Generator [ToolInvokeMessage ]:
1013 try :
@@ -24,39 +27,57 @@ def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessag
2427 yield self .create_text_message ("Error: API key is required." )
2528 return
2629
30+ # Get base_url from credentials
31+ base_url = self .runtime .credentials .get ("base_url" )
32+
2733 # Create OpenGuardrails client and check user input
28- client = OpenGuardrails (api_key )
29- result = client .check_prompt (prompt , user_id = user_id )
34+ try :
35+ if base_url :
36+ client = OpenGuardrails (api_key , base_url = base_url )
37+ else :
38+ client = OpenGuardrails (api_key )
39+
40+ result = client .check_prompt (prompt , user_id = user_id )
41+ except Exception as client_error :
42+ logger .error (f"OpenGuardrails client error: { str (client_error )} " , exc_info = True )
43+ raise
3044
3145 # Extract categories field: from compliance and security not equal to "no_risk" categories list first item
32- categories = []
33- if result .result .compliance .risk_level != "no_risk" and result .result .compliance .categories :
34- categories .append (result .result .compliance .categories [0 ])
35- elif result .result .security .risk_level != "no_risk" and result .result .security .categories :
36- categories .append (result .result .security .categories [0 ])
37- elif result .result .data .risk_level != "no_risk" and result .result .data .categories :
38- categories .append (result .result .data .categories [0 ])
39-
40- categories_str = ", " .join (categories )
41- if categories_str :
42- categories_str = f"{ categories_str } "
43- else :
44- categories_str = ""
45-
46- # Process suggest_answer field, if not exist, set to empty string
47- suggest_answer = ""
48- if result .suggest_answer :
49- suggest_answer = result .suggest_answer
50-
51- # Use custom variable to return result
52- yield self .create_variable_message ("id" , result .id )
53- yield self .create_variable_message ("overall_risk_level" , result .overall_risk_level )
54- yield self .create_variable_message ("suggest_action" , result .suggest_action )
55- yield self .create_variable_message ("suggest_answer" , suggest_answer )
56- yield self .create_variable_message ("categories" , categories_str )
57- yield self .create_variable_message ("score" , result .score )
46+ try :
47+ categories = []
48+ if result .result .compliance .risk_level != "no_risk" and result .result .compliance .categories :
49+ categories .append (result .result .compliance .categories [0 ])
50+ elif result .result .security .risk_level != "no_risk" and result .result .security .categories :
51+ categories .append (result .result .security .categories [0 ])
52+ elif result .result .data .risk_level != "no_risk" and result .result .data .categories :
53+ categories .append (result .result .data .categories [0 ])
54+
55+ categories_str = ", " .join (categories )
56+ if categories_str :
57+ categories_str = f"{ categories_str } "
58+ else :
59+ categories_str = ""
60+
61+ # Process suggest_answer field, if not exist, set to empty string
62+ suggest_answer = ""
63+ if result .suggest_answer :
64+ suggest_answer = result .suggest_answer
65+
66+ # Use custom variable to return result
67+ yield self .create_variable_message ("id" , result .id )
68+ yield self .create_variable_message ("overall_risk_level" , result .overall_risk_level )
69+ yield self .create_variable_message ("suggest_action" , result .suggest_action )
70+ yield self .create_variable_message ("suggest_answer" , suggest_answer )
71+ yield self .create_variable_message ("categories" , categories_str )
72+ # Ensure score is a number, not None
73+ score_value = result .score if result .score is not None else 0.0
74+ yield self .create_variable_message ("score" , score_value )
75+ except Exception as result_error :
76+ logger .error (f"Error processing result: { str (result_error )} " , exc_info = True )
77+ raise
5878
5979 except Exception as e :
6080 # 错误处理
81+ logger .error (f"CheckPromptTool error: { str (e )} " , exc_info = True )
6182 yield self .create_text_message (f"Error: { str (e )} " )
6283 yield self .create_json_message ({"error" : str (e )})
0 commit comments