Skip to content

Commit e4d2c2f

Browse files
chore: migrate the AI Guard SDK to V2 of the message format (#14916)
## Description This pull request migrates the AI Guard SDK to version 2 of the REST API message format, aligning with the latest service specification. The update simplifies the SDK by dropping legacy V1 support, which had minimal adoption. ## Testing <!-- Describe your testing strategy or note what tests are included --> ## Risks ## Additional Notes [APPSEC-59206](https://datadoghq.atlassian.net/browse/APPSEC-59206) [APPSEC-59205](https://datadoghq.atlassian.net/browse/APPSEC-59205) <!-- Any other information that would be helpful for reviewers --> [APPSEC-59206]: https://datadoghq.atlassian.net/browse/APPSEC-59206?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ [APPSEC-59205]: https://datadoghq.atlassian.net/browse/APPSEC-59205?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ
1 parent 5fd1df2 commit e4d2c2f

File tree

13 files changed

+611
-616
lines changed

13 files changed

+611
-616
lines changed

ddtrace/appsec/_ai_guard/_langchain.py

Lines changed: 86 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import json
22
from typing import Any
3-
from typing import Dict
43
from typing import List
5-
from typing import Optional
64
from typing import Sequence
5+
import uuid
76

87
from ddtrace.appsec.ai_guard import AIGuardAbortError
98
from ddtrace.appsec.ai_guard import AIGuardClient
10-
from ddtrace.appsec.ai_guard import Prompt
9+
from ddtrace.appsec.ai_guard import Function
10+
from ddtrace.appsec.ai_guard import Message
11+
from ddtrace.appsec.ai_guard import Options
1112
from ddtrace.appsec.ai_guard import ToolCall
12-
from ddtrace.appsec.ai_guard._api_client import Evaluation
1313
from ddtrace.contrib.internal.trace_utils import unwrap
1414
from ddtrace.contrib.internal.trace_utils import wrap
1515
import ddtrace.internal.logger as ddlogger
@@ -61,12 +61,12 @@ def _langchain_unpatch():
6161

6262
def _langchain_agent_plan(client: AIGuardClient, func, instance, args, kwargs):
6363
action = func(*args, **kwargs)
64-
return _handle_agent_action_result(client, action, kwargs)
64+
return _handle_agent_action_result(client, action, args, kwargs)
6565

6666

6767
async def _langchain_agent_aplan(client: AIGuardClient, func, instance, args, kwargs):
6868
action = await func(*args, **kwargs)
69-
return _handle_agent_action_result(client, action, kwargs)
69+
return _handle_agent_action_result(client, action, args, kwargs)
7070

7171

7272
def _try_parse_json(value: dict, attribute: str) -> Any:
@@ -79,6 +79,15 @@ def _try_parse_json(value: dict, attribute: str) -> Any:
7979
return {attribute: json_str}
8080

8181

82+
def _try_format_json(value: Any) -> str:
83+
if not value:
84+
return ""
85+
try:
86+
return json.dumps(value)
87+
except Exception:
88+
return str(value)
89+
90+
8291
def _get_message_text(msg: Any) -> str:
8392
if isinstance(msg.content, str):
8493
return msg.content
@@ -91,71 +100,102 @@ def _get_message_text(msg: Any) -> str:
91100
return "".join(block if isinstance(block, str) else block["text"] for block in blocks)
92101

93102

94-
def _convert_messages(messages: list[Any]) -> list[Evaluation]:
103+
def _convert_messages(messages: list[Any]) -> list[Message]:
95104
from langchain_core.messages import ChatMessage
96105
from langchain_core.messages import HumanMessage
97106
from langchain_core.messages import SystemMessage
98107
from langchain_core.messages.ai import AIMessage
99108
from langchain_core.messages.function import FunctionMessage
100109
from langchain_core.messages.tool import ToolMessage
101110

102-
result: List[Evaluation] = []
103-
tool_calls: Dict[str, ToolCall] = dict()
104-
function_call: Optional[ToolCall] = None
111+
result: List[Message] = []
105112
for message in messages:
106113
try:
107114
if isinstance(message, HumanMessage):
108-
result.append(Prompt(role="user", content=_get_message_text(message)))
115+
result.append(Message(role="user", content=_get_message_text(message)))
109116
elif isinstance(message, SystemMessage):
110-
result.append(Prompt(role="system", content=_get_message_text(message)))
117+
result.append(Message(role="system", content=_get_message_text(message)))
111118
elif isinstance(message, ChatMessage):
112-
result.append(Prompt(role=message.role, content=_get_message_text(message)))
119+
result.append(Message(role=message.role, content=_get_message_text(message)))
113120
elif isinstance(message, AIMessage):
114-
for call in message.tool_calls:
115-
tool_call = ToolCall(tool_name=call["name"], tool_args=call["args"])
116-
result.append(tool_call)
117-
if call["id"]:
118-
tool_calls[call["id"]] = tool_call
121+
if len(message.tool_calls) > 0:
122+
tool_calls = [
123+
ToolCall(
124+
id=call.get("id", ""),
125+
function=Function(
126+
name=call.get("name", ""), arguments=_try_format_json(call.get("args", {}))
127+
),
128+
)
129+
for call in message.tool_calls
130+
]
131+
result.append(Message(role="assistant", tool_calls=tool_calls))
119132
if "function_call" in message.additional_kwargs:
120-
call = message.additional_kwargs["function_call"]
121-
function_call = ToolCall(tool_name=call.get("name"), tool_args=_try_parse_json(call, "arguments"))
122-
result.append(function_call)
133+
function_call = message.additional_kwargs["function_call"]
134+
tool_call = ToolCall(
135+
id="",
136+
function=Function(name=function_call.get("name"), arguments=function_call.get("arguments")),
137+
)
138+
result.append(Message(role="assistant", tool_calls=[tool_call]))
123139
if message.content:
124-
result.append(Prompt(role="assistant", content=_get_message_text(message)))
140+
result.append(Message(role="assistant", content=_get_message_text(message)))
125141
elif isinstance(message, ToolMessage):
126-
current_call = tool_calls.get(message.tool_call_id)
127-
if current_call:
128-
current_call["output"] = _get_message_text(message)
142+
result.append(
143+
Message(role="tool", tool_call_id=message.tool_call_id, content=_get_message_text(message))
144+
)
129145
elif isinstance(message, FunctionMessage):
130-
if function_call and function_call["tool_name"] == message.name:
131-
function_call["output"] = _get_message_text(message)
132-
function_call = None
146+
result.append(Message(role="tool", tool_call_id="", content=_get_message_text(message)))
133147
except Exception:
134148
logger.debug("Failed to convert message", exc_info=True)
135149

136150
return result
137151

138152

139-
def _handle_agent_action_result(client: AIGuardClient, result, kwargs):
153+
def _handle_agent_action_result(client: AIGuardClient, result, args, kwargs):
140154
try:
141155
from langchain_core.agents import AgentAction
142-
from langchain_core.agents import AgentFinish
156+
from langchain_core.agents import AgentActionMessageLog
143157
except ImportError:
144158
from langchain.agents import AgentAction
145-
from langchain.agents import AgentFinish
159+
from langchain.agents import AgentActionMessageLog
146160

147161
for action in result if isinstance(result, Sequence) else [result]:
148162
if isinstance(action, AgentAction) and action.tool:
149163
try:
150-
history = _convert_messages(kwargs["chat_history"]) if "chat_history" in kwargs else []
151-
if "input" in kwargs:
164+
chat_history = kwargs["chat_history"] if "chat_history" in kwargs else []
165+
messages = _convert_messages(chat_history)
166+
prompt = kwargs["input"] if "input" in kwargs else None
167+
if prompt:
152168
# TODO we are assuming user prompt
153-
history.append(Prompt(role="user", content=kwargs["input"]))
154-
tool_name = action.tool
155-
tool_input = action.tool_input
156-
if not client.evaluate_tool(tool_name, tool_input, history=history):
157-
blocked_message = f"Tool call '{tool_name}' was blocked due to security policies."
158-
return AgentFinish(return_values={"output": blocked_message}, log=blocked_message)
169+
messages.append(Message(role="user", content=prompt))
170+
intermediate_steps = get_argument_value(args, kwargs, 0, "intermediate_steps")
171+
if intermediate_steps:
172+
for intermediate_step, output in intermediate_steps:
173+
if isinstance(intermediate_step, AgentActionMessageLog):
174+
tool_call_id = str(uuid.uuid4())
175+
tool_call = ToolCall(
176+
id=tool_call_id,
177+
function=Function(
178+
name=intermediate_step.tool,
179+
arguments=_try_format_json(intermediate_step.tool_input),
180+
),
181+
)
182+
messages.append(Message(role="assistant", tool_calls=[tool_call]))
183+
184+
tool_output = str(output) if output else ""
185+
if tool_output:
186+
messages.append(Message(role="tool", tool_call_id=tool_call_id, content=tool_output))
187+
messages.append(
188+
Message(
189+
role="assistant",
190+
tool_calls=[
191+
ToolCall(
192+
id="",
193+
function=Function(name=action.tool, arguments=_try_format_json(action.tool_input)),
194+
)
195+
],
196+
)
197+
)
198+
client.evaluate(messages, Options(block=True))
159199
except AIGuardAbortError:
160200
raise
161201
except Exception:
@@ -173,8 +213,10 @@ def _langchain_chatmodel_generate_before(client: AIGuardClient, message_lists):
173213

174214

175215
def _langchain_llm_generate_before(client: AIGuardClient, prompts):
216+
from langchain_core.messages import HumanMessage
217+
176218
for prompt in prompts:
177-
result = _evaluate_langchain_prompt(client, prompt)
219+
result = _evaluate_langchain_messages(client, [HumanMessage(content=prompt)])
178220
if result:
179221
return result
180222
return None
@@ -187,35 +229,22 @@ def _langchain_chatmodel_stream_before(client: AIGuardClient, instance, args, kw
187229

188230

189231
def _langchain_llm_stream_before(client: AIGuardClient, instance, args, kwargs):
232+
from langchain_core.messages import HumanMessage
233+
190234
input_arg = get_argument_value(args, kwargs, 0, "input")
191235
prompt = instance._convert_input(input_arg).to_string()
192-
return _evaluate_langchain_prompt(client, prompt)
236+
return _evaluate_langchain_messages(client, [HumanMessage(content=prompt)])
193237

194238

195239
def _evaluate_langchain_messages(client: AIGuardClient, messages):
196240
from langchain_core.messages import HumanMessage
197241

198242
# only call evaluator when the last message is an actual user prompt
199243
if len(messages) > 0 and isinstance(messages[-1], HumanMessage):
200-
history = _convert_messages(messages)
201-
prompt = history.pop(-1)
202244
try:
203-
role, content = (prompt["role"], prompt["content"]) # type: ignore[typeddict-item]
204-
if not client.evaluate_prompt(role, content, history=history):
205-
return AIGuardAbortError()
245+
client.evaluate(_convert_messages(messages), Options(block=True))
206246
except AIGuardAbortError as e:
207247
return e
208248
except Exception:
209249
logger.debug("Failed to evaluate chat model prompt", exc_info=True)
210250
return None
211-
212-
213-
def _evaluate_langchain_prompt(client: AIGuardClient, prompt):
214-
try:
215-
if not client.evaluate_prompt("user", prompt):
216-
return AIGuardAbortError()
217-
except AIGuardAbortError as e:
218-
return e
219-
except Exception:
220-
logger.debug("Failed to evaluate llm prompt", exc_info=True)
221-
return None

ddtrace/appsec/ai_guard/__init__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,22 @@
55
from ._api_client import AIGuardAbortError
66
from ._api_client import AIGuardClient
77
from ._api_client import AIGuardClientError
8-
from ._api_client import AIGuardWorkflow
9-
from ._api_client import Prompt
8+
from ._api_client import Evaluation
9+
from ._api_client import Function
10+
from ._api_client import Message
11+
from ._api_client import Options
1012
from ._api_client import ToolCall
1113
from ._api_client import new_ai_guard_client
1214

1315

1416
__all__ = [
1517
"new_ai_guard_client",
16-
"AIGuardWorkflow",
1718
"AIGuardClient",
1819
"AIGuardClientError",
1920
"AIGuardAbortError",
20-
"Prompt",
21+
"Evaluation",
22+
"Function",
23+
"Message",
24+
"Options",
2125
"ToolCall",
2226
]

0 commit comments

Comments
 (0)