11import json
22from typing import Any
3- from typing import Dict
43from typing import List
5- from typing import Optional
64from typing import Sequence
5+ import uuid
76
87from ddtrace .appsec .ai_guard import AIGuardAbortError
98from ddtrace .appsec .ai_guard import AIGuardClient
10- from ddtrace .appsec .ai_guard import Prompt
9+ from ddtrace .appsec .ai_guard import Function
10+ from ddtrace .appsec .ai_guard import Message
11+ from ddtrace .appsec .ai_guard import Options
1112from ddtrace .appsec .ai_guard import ToolCall
12- from ddtrace .appsec .ai_guard ._api_client import Evaluation
1313from ddtrace .contrib .internal .trace_utils import unwrap
1414from ddtrace .contrib .internal .trace_utils import wrap
1515import ddtrace .internal .logger as ddlogger
@@ -61,12 +61,12 @@ def _langchain_unpatch():
6161
def _langchain_agent_plan(client: AIGuardClient, func, instance, args, kwargs):
    """Invoke the wrapped synchronous agent ``plan`` and screen its result with AI Guard."""
    planned_action = func(*args, **kwargs)
    return _handle_agent_action_result(client, planned_action, args, kwargs)
6565
6666
async def _langchain_agent_aplan(client: AIGuardClient, func, instance, args, kwargs):
    """Invoke the wrapped asynchronous agent ``plan`` and screen its result with AI Guard."""
    planned_action = await func(*args, **kwargs)
    return _handle_agent_action_result(client, planned_action, args, kwargs)
7070
7171
7272def _try_parse_json (value : dict , attribute : str ) -> Any :
@@ -79,6 +79,15 @@ def _try_parse_json(value: dict, attribute: str) -> Any:
7979 return {attribute : json_str }
8080
8181
82+ def _try_format_json (value : Any ) -> str :
83+ if not value :
84+ return ""
85+ try :
86+ return json .dumps (value )
87+ except Exception :
88+ return str (value )
89+
90+
8291def _get_message_text (msg : Any ) -> str :
8392 if isinstance (msg .content , str ):
8493 return msg .content
@@ -91,71 +100,102 @@ def _get_message_text(msg: Any) -> str:
91100 return "" .join (block if isinstance (block , str ) else block ["text" ] for block in blocks )
92101
93102
def _convert_messages(messages: list[Any]) -> list[Message]:
    """Convert a list of LangChain messages into AI Guard ``Message`` objects.

    Human/system/chat messages map to plain text messages; AI messages map to
    assistant messages carrying tool calls and/or text; tool and function
    messages map to role ``"tool"`` outputs. A message of any other type is
    skipped, and a conversion error is logged at debug level and swallowed so
    a single malformed message cannot break the whole evaluation.
    """
    from langchain_core.messages import ChatMessage
    from langchain_core.messages import HumanMessage
    from langchain_core.messages import SystemMessage
    from langchain_core.messages.ai import AIMessage
    from langchain_core.messages.function import FunctionMessage
    from langchain_core.messages.tool import ToolMessage

    result: List[Message] = []
    for message in messages:
        try:
            if isinstance(message, HumanMessage):
                result.append(Message(role="user", content=_get_message_text(message)))
            elif isinstance(message, SystemMessage):
                result.append(Message(role="system", content=_get_message_text(message)))
            elif isinstance(message, ChatMessage):
                result.append(Message(role=message.role, content=_get_message_text(message)))
            elif isinstance(message, AIMessage):
                if message.tool_calls:
                    tool_calls = [
                        ToolCall(
                            # LangChain tool-call ids/names may be present but None:
                            # coalesce to "" instead of .get(key, "") which keeps None
                            id=call.get("id") or "",
                            function=Function(
                                name=call.get("name") or "",
                                arguments=_try_format_json(call.get("args", {})),
                            ),
                        )
                        for call in message.tool_calls
                    ]
                    result.append(Message(role="assistant", tool_calls=tool_calls))
                if "function_call" in message.additional_kwargs:
                    # Legacy function-calling style: a single call in additional_kwargs
                    function_call = message.additional_kwargs["function_call"]
                    tool_call = ToolCall(
                        id="",
                        function=Function(
                            name=function_call.get("name"),
                            arguments=function_call.get("arguments"),
                        ),
                    )
                    result.append(Message(role="assistant", tool_calls=[tool_call]))
                if message.content:
                    result.append(Message(role="assistant", content=_get_message_text(message)))
            elif isinstance(message, ToolMessage):
                result.append(
                    Message(role="tool", tool_call_id=message.tool_call_id, content=_get_message_text(message))
                )
            elif isinstance(message, FunctionMessage):
                # FunctionMessage carries no call id; emit an empty tool_call_id
                result.append(Message(role="tool", tool_call_id="", content=_get_message_text(message)))
        except Exception:
            logger.debug("Failed to convert message", exc_info=True)

    return result
137151
138152
139- def _handle_agent_action_result (client : AIGuardClient , result , kwargs ):
153+ def _handle_agent_action_result (client : AIGuardClient , result , args , kwargs ):
140154 try :
141155 from langchain_core .agents import AgentAction
142- from langchain_core .agents import AgentFinish
156+ from langchain_core .agents import AgentActionMessageLog
143157 except ImportError :
144158 from langchain .agents import AgentAction
145- from langchain .agents import AgentFinish
159+ from langchain .agents import AgentActionMessageLog
146160
147161 for action in result if isinstance (result , Sequence ) else [result ]:
148162 if isinstance (action , AgentAction ) and action .tool :
149163 try :
150- history = _convert_messages (kwargs ["chat_history" ]) if "chat_history" in kwargs else []
151- if "input" in kwargs :
164+ chat_history = kwargs ["chat_history" ] if "chat_history" in kwargs else []
165+ messages = _convert_messages (chat_history )
166+ prompt = kwargs ["input" ] if "input" in kwargs else None
167+ if prompt :
152168 # TODO we are assuming user prompt
153- history .append (Prompt (role = "user" , content = kwargs ["input" ]))
154- tool_name = action .tool
155- tool_input = action .tool_input
156- if not client .evaluate_tool (tool_name , tool_input , history = history ):
157- blocked_message = f"Tool call '{ tool_name } ' was blocked due to security policies."
158- return AgentFinish (return_values = {"output" : blocked_message }, log = blocked_message )
169+ messages .append (Message (role = "user" , content = prompt ))
170+ intermediate_steps = get_argument_value (args , kwargs , 0 , "intermediate_steps" )
171+ if intermediate_steps :
172+ for intermediate_step , output in intermediate_steps :
173+ if isinstance (intermediate_step , AgentActionMessageLog ):
174+ tool_call_id = str (uuid .uuid4 ())
175+ tool_call = ToolCall (
176+ id = tool_call_id ,
177+ function = Function (
178+ name = intermediate_step .tool ,
179+ arguments = _try_format_json (intermediate_step .tool_input ),
180+ ),
181+ )
182+ messages .append (Message (role = "assistant" , tool_calls = [tool_call ]))
183+
184+ tool_output = str (output ) if output else ""
185+ if tool_output :
186+ messages .append (Message (role = "tool" , tool_call_id = tool_call_id , content = tool_output ))
187+ messages .append (
188+ Message (
189+ role = "assistant" ,
190+ tool_calls = [
191+ ToolCall (
192+ id = "" ,
193+ function = Function (name = action .tool , arguments = _try_format_json (action .tool_input )),
194+ )
195+ ],
196+ )
197+ )
198+ client .evaluate (messages , Options (block = True ))
159199 except AIGuardAbortError :
160200 raise
161201 except Exception :
@@ -173,8 +213,10 @@ def _langchain_chatmodel_generate_before(client: AIGuardClient, message_lists):
173213
174214
def _langchain_llm_generate_before(client: AIGuardClient, prompts):
    """Evaluate each raw LLM prompt as a single user message before generation.

    Returns the first truthy evaluation result, or None when every prompt passes.
    """
    from langchain_core.messages import HumanMessage

    evaluations = (
        _evaluate_langchain_messages(client, [HumanMessage(content=text)]) for text in prompts
    )
    return next((outcome for outcome in evaluations if outcome), None)
@@ -187,35 +229,22 @@ def _langchain_chatmodel_stream_before(client: AIGuardClient, instance, args, kw
187229
188230
def _langchain_llm_stream_before(client: AIGuardClient, instance, args, kwargs):
    """Normalize the stream input to a prompt string and evaluate it as a user message."""
    from langchain_core.messages import HumanMessage

    raw_input = get_argument_value(args, kwargs, 0, "input")
    prompt_text = instance._convert_input(raw_input).to_string()
    user_message = HumanMessage(content=prompt_text)
    return _evaluate_langchain_messages(client, [user_message])
194238
def _evaluate_langchain_messages(client: AIGuardClient, messages):
    """Run AI Guard evaluation over *messages* when the last one is a user prompt.

    Returns the ``AIGuardAbortError`` instance when the evaluation blocked the
    prompt (the caller decides how to surface it), ``None`` otherwise. Any
    other evaluation failure is logged at debug level and ignored.
    """
    from langchain_core.messages import HumanMessage

    # only call the evaluator when the last message is an actual user prompt
    if messages and isinstance(messages[-1], HumanMessage):
        try:
            client.evaluate(_convert_messages(messages), Options(block=True))
        except AIGuardAbortError as e:
            return e
        except Exception:
            logger.debug("Failed to evaluate chat model prompt", exc_info=True)
    return None
211-
212-
213- def _evaluate_langchain_prompt (client : AIGuardClient , prompt ):
214- try :
215- if not client .evaluate_prompt ("user" , prompt ):
216- return AIGuardAbortError ()
217- except AIGuardAbortError as e :
218- return e
219- except Exception :
220- logger .debug ("Failed to evaluate llm prompt" , exc_info = True )
221- return None
0 commit comments