Skip to content

Commit 06f37ec

Browse files
Ensure messages are not modified before span serialization
1 parent 85ac4c8 commit 06f37ec

File tree

4 files changed

+96
-126
lines changed

4 files changed

+96
-126
lines changed

dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -135,18 +135,17 @@ private static List<Message> messagesForMetaStruct(List<Message> messages) {
135135
final int maxContent = config.getAiGuardMaxContentSize();
136136
boolean contentTruncated = false;
137137
for (int i = 0; i < size; i++) {
138-
Message source = messages.get(i);
139-
final String content = source.getContent();
138+
final Message source = messages.get(i);
139+
String content = source.getContent();
140140
if (content != null && content.length() > maxContent) {
141141
contentTruncated = true;
142-
source =
143-
new Message(
144-
source.getRole(),
145-
content.substring(0, maxContent),
146-
source.getToolCalls(),
147-
source.getToolCallId());
142+
content = content.substring(0, maxContent);
148143
}
149-
result.add(source);
144+
List<ToolCall> toolCalls = source.getToolCalls();
145+
if (toolCalls != null) {
146+
toolCalls = new ArrayList<>(toolCalls);
147+
}
148+
result.add(new Message(source.getRole(), content, toolCalls, source.getToolCallId()));
150149
}
151150
if (contentTruncated) {
152151
WafMetricCollector.get().aiGuardTruncated(CONTENT);
@@ -240,7 +239,7 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
240239
span.setTag(BLOCKED_TAG, true);
241240
throw new AIGuardAbortError(action, reason, tags);
242241
}
243-
return new Evaluation(action, reason);
242+
return new Evaluation(action, reason, tags);
244243
}
245244
} catch (AIGuardAbortError e) {
246245
span.addThrowable(e);

dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy

Lines changed: 72 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import datadog.trace.api.aiguard.AIGuard
1010
import datadog.trace.api.telemetry.WafMetricCollector
1111
import datadog.trace.bootstrap.instrumentation.api.AgentSpan
1212
import datadog.trace.bootstrap.instrumentation.api.AgentTracer
13+
import datadog.trace.bootstrap.instrumentation.api.AgentTracer.TracerAPI
1314
import datadog.trace.test.util.DDSpecification
1415
import okhttp3.Call
1516
import okhttp3.HttpUrl
@@ -71,6 +72,7 @@ class AIGuardInternalTests extends DDSpecification {
7172
@Shared
7273
protected static final PROMPT = TOOL_OUTPUT + [AIGuard.Message.message('assistant', '2 + 2 is 5'), AIGuard.Message.message('user', '')]
7374

75+
protected TracerAPI tracer
7476
protected AgentSpan span
7577

7678
void setup() {
@@ -81,7 +83,7 @@ class AIGuardInternalTests extends DDSpecification {
8183
final builder = Mock(AgentTracer.SpanBuilder) {
8284
start() >> span
8385
}
84-
final tracer = Stub(AgentTracer.TracerAPI) {
86+
tracer = Stub(TracerAPI) {
8587
buildSpan(_ as String, _ as String) >> builder
8688
}
8789
AgentTracer.forceRegister(tracer)
@@ -198,19 +200,18 @@ class AIGuardInternalTests extends DDSpecification {
198200
1 * span.addThrowable(_ as AIGuard.AIGuardAbortError)
199201
}
200202

201-
receivedMeta.messages == suite.messages
202-
if (suite.tags) {
203-
receivedMeta.attack_categories == suite.tags
204-
}
203+
assertMeta(receivedMeta, suite)
205204
assertRequest(request, suite.messages)
206205
if (throwAbortError) {
207206
error instanceof AIGuard.AIGuardAbortError
208207
error.action == suite.action
209208
error.reason == suite.reason
209+
error.tags == suite.tags
210210
} else {
211211
error == null
212212
eval.action == suite.action
213213
eval.reason == suite.reason
214+
eval.tags == suite.tags
214215
}
215216
assertTelemetry('ai_guard.requests', "action:$suite.action", "block:$throwAbortError", 'error:false')
216217

@@ -221,19 +222,7 @@ class AIGuardInternalTests extends DDSpecification {
221222
void 'test evaluate with API errors'() {
222223
given:
223224
final errors = [[status: 400, title: 'Bad request']]
224-
Request request = null
225-
final call = Mock(Call) {
226-
execute() >> {
227-
return mockResponse(request, 404, [errors: errors])
228-
}
229-
}
230-
final client = Mock(OkHttpClient) {
231-
newCall(_ as Request) >> {
232-
request = (Request) it[0]
233-
return call
234-
}
235-
}
236-
final aiguard = new AIGuardInternal(URL, HEADERS, client)
225+
final aiguard = mockClient(404, [errors: errors])
237226

238227
when:
239228
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
@@ -247,19 +236,7 @@ class AIGuardInternalTests extends DDSpecification {
247236

248237
void 'test evaluate with invalid JSON'() {
249238
given:
250-
Request request = null
251-
final call = Mock(Call) {
252-
execute() >> {
253-
return mockResponse(request, 200, [bad: 'This is an invalid response'])
254-
}
255-
}
256-
final client = Mock(OkHttpClient) {
257-
newCall(_ as Request) >> {
258-
request = (Request) it[0]
259-
return call
260-
}
261-
}
262-
final aiguard = new AIGuardInternal(URL, HEADERS, client)
239+
final aiguard = mockClient(200, [bad: 'This is an invalid response'])
263240

264241
when:
265242
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
@@ -272,19 +249,7 @@ class AIGuardInternalTests extends DDSpecification {
272249

273250
void 'test evaluate with missing action'() {
274251
given:
275-
Request request = null
276-
final call = Mock(Call) {
277-
execute() >> {
278-
return mockResponse(request, 200, [data: [attributes: [reason: 'I miss something']]])
279-
}
280-
}
281-
final client = Mock(OkHttpClient) {
282-
newCall(_ as Request) >> {
283-
request = (Request) it[0]
284-
return call
285-
}
286-
}
287-
final aiguard = new AIGuardInternal(URL, HEADERS, client)
252+
final aiguard = mockClient(200, [data: [attributes: [reason: 'I miss something']]])
288253

289254
when:
290255
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
@@ -297,19 +262,7 @@ class AIGuardInternalTests extends DDSpecification {
297262

298263
void 'test evaluate with non JSON response'() {
299264
given:
300-
Request request = null
301-
final call = Mock(Call) {
302-
execute() >> {
303-
return mockResponse(request, 200, [data: [attributes: [reason: 'I miss something']]])
304-
}
305-
}
306-
final client = Mock(OkHttpClient) {
307-
newCall(_ as Request) >> {
308-
request = (Request) it[0]
309-
return call
310-
}
311-
}
312-
final aiguard = new AIGuardInternal(URL, HEADERS, client)
265+
final aiguard = mockClient(200, 'I am no JSON')
313266

314267
when:
315268
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
@@ -322,19 +275,7 @@ class AIGuardInternalTests extends DDSpecification {
322275

323276
void 'test evaluate with empty response'() {
324277
given:
325-
Request request = null
326-
final call = Mock(Call) {
327-
execute() >> {
328-
return mockResponse(request, 200, null)
329-
}
330-
}
331-
final client = Mock(OkHttpClient) {
332-
newCall(_ as Request) >> {
333-
request = (Request) it[0]
334-
return call
335-
}
336-
}
337-
final aiguard = new AIGuardInternal(URL, HEADERS, client)
278+
final aiguard = mockClient(200, null)
338279

339280
when:
340281
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
@@ -348,19 +289,7 @@ class AIGuardInternalTests extends DDSpecification {
348289
void 'test message length truncation'() {
349290
given:
350291
final maxMessages = Config.get().getAiGuardMaxMessagesLength()
351-
Request request = null
352-
final call = Mock(Call) {
353-
execute() >> {
354-
return mockResponse(request, 200, [data: [attributes: [action: ALLOW, reason: 'It is fine']]])
355-
}
356-
}
357-
final client = Mock(OkHttpClient) {
358-
newCall(_ as Request) >> {
359-
request = (Request) it[0]
360-
return call
361-
}
362-
}
363-
final aiguard = new AIGuardInternal(URL, HEADERS, client)
292+
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]])
364293
final messages = (0..maxMessages)
365294
.collect { AIGuard.Message.message('user', "This is a prompt: ${it}") }
366295
.toList()
@@ -380,19 +309,7 @@ class AIGuardInternalTests extends DDSpecification {
380309
void 'test message content truncation'() {
381310
given:
382311
final maxContent = Config.get().getAiGuardMaxContentSize()
383-
Request request = null
384-
final call = Mock(Call) {
385-
execute() >> {
386-
return mockResponse(request, 200, [data: [attributes: [action: ALLOW, reason: 'It is fine']]])
387-
}
388-
}
389-
final client = Mock(OkHttpClient) {
390-
newCall(_ as Request) >> {
391-
request = (Request) it[0]
392-
return call
393-
}
394-
}
395-
final aiguard = new AIGuardInternal(URL, HEADERS, client)
312+
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]])
396313
final message = AIGuard.Message.message("user", (0..maxContent).collect { 'A' }.join())
397314

398315
when:
@@ -426,23 +343,7 @@ class AIGuardInternalTests extends DDSpecification {
426343

427344
void 'test missing tool name'() {
428345
given:
429-
def request
430-
final call = Mock(Call) {
431-
execute() >> {
432-
return mockResponse(
433-
request,
434-
200,
435-
[data: [attributes: [action: 'ALLOW', reason: 'Just do it']]]
436-
)
437-
}
438-
}
439-
final client = Mock(OkHttpClient) {
440-
newCall(_ as Request) >> {
441-
request = (Request) it[0]
442-
return call
443-
}
444-
}
445-
final aiguard = new AIGuardInternal(URL, HEADERS, client)
346+
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Just do it']]])
446347

447348
when:
448349
aiguard.evaluate([AIGuard.Message.tool('call_1', 'Content')], AIGuard.Options.DEFAULT)
@@ -460,6 +361,52 @@ class AIGuardInternalTests extends DDSpecification {
460361
thrown(IllegalArgumentException)
461362
}
462363

364+
void 'test message immutability'() {
365+
given:
366+
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Just do it']]])
367+
final message = new AIGuard.Message(
368+
"assistant",
369+
null,
370+
[AIGuard.ToolCall.toolCall('call_1', 'execute_shell', '{"cmd": "ls -lah"}')],
371+
null
372+
)
373+
Map<String, Object> receivedMeta
374+
375+
when:
376+
aiguard.evaluate([message], AIGuard.Options.DEFAULT)
377+
378+
then:
379+
1 * span.finish() >> {
380+
// modify the tool calls before flushing
381+
message.toolCalls.add(
382+
AIGuard.ToolCall.toolCall('call_2', 'execute_shell', '{"cmd": "rm -rf"}')
383+
)
384+
}
385+
1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _ as Map) >> {
386+
receivedMeta = it[1] as Map<String, Object>
387+
return span
388+
}
389+
final messages = receivedMeta.messages as List<AIGuard.Message>
390+
final assistant = messages.first()
391+
assistant.toolCalls.size() == 1
392+
}
393+
394+
private AIGuardInternal mockClient(final int status, final Object response) {
395+
def request
396+
final call = Stub(Call) {
397+
execute() >> {
398+
return mockResponse(request, status, response)
399+
}
400+
}
401+
final client = Stub(OkHttpClient) {
402+
newCall(_ as Request) >> {
403+
request = (Request) it[0]
404+
return call
405+
}
406+
}
407+
return new AIGuardInternal(URL, HEADERS, client)
408+
}
409+
463410
private static assertTelemetry(final String metric, final String...tags) {
464411
final metrics = WafMetricCollector.get().with {
465412
prepareMetrics()
@@ -475,6 +422,16 @@ class AIGuardInternalTests extends DDSpecification {
475422
return true
476423
}
477424

425+
private static assertMeta(final Map<String, Object> meta, final TestSuite suite) {
426+
if (suite.tags) {
427+
assert meta.attack_categories == suite.tags
428+
}
429+
final receivedMessages = snakeCaseJson(meta.messages)
430+
final expectedMessages = snakeCaseJson(suite.messages)
431+
JSONAssert.assertEquals(expectedMessages, receivedMessages, JSONCompareMode.NON_EXTENSIBLE)
432+
return true
433+
}
434+
478435
private static assertRequest(final Request request, final List<AIGuard.Message> messages) {
479436
assert request.url() == URL
480437
assert request.method() == 'POST'
@@ -556,7 +513,8 @@ class AIGuardInternalTests extends DDSpecification {
556513
", reason='" + reason + '\'' +
557514
", blocking=" + blocking +
558515
", target='" + target + '\'' +
559-
", messages=" + messages +
516+
", messages=" + messages + '\'' +
517+
", tags=" + tags +
560518
'}'
561519
}
562520
}

dd-trace-api/src/main/java/datadog/trace/api/aiguard/AIGuard.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,16 +143,19 @@ public static class Evaluation {
143143

144144
final Action action;
145145
final String reason;
146+
final List<String> tags;
146147

147148
/**
148149
* Creates a new evaluation result.
149150
*
150151
* @param action the recommended action for the evaluated content
151152
* @param reason human-readable explanation for the decision
153+
* @param tags list of tags associated with the evaluation (e.g. indirect-prompt-injection)
152154
*/
153-
public Evaluation(final Action action, final String reason) {
155+
public Evaluation(final Action action, final String reason, final List<String> tags) {
154156
this.action = action;
155157
this.reason = reason;
158+
this.tags = tags;
156159
}
157160

158161
/**
@@ -172,6 +175,15 @@ public Action getAction() {
172175
public String getReason() {
173176
return reason;
174177
}
178+
179+
/**
180+
* Returns the list of tags associated with the evaluation (e.g. indirect-prompt-injection)
181+
*
182+
* @return list of tags.
183+
*/
184+
public List<String> getTags() {
185+
return tags;
186+
}
175187
}
176188

177189
/**

dd-trace-api/src/main/java/datadog/trace/api/aiguard/noop/NoOpEvaluator.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package datadog.trace.api.aiguard.noop;
22

33
import static datadog.trace.api.aiguard.AIGuard.Action.ALLOW;
4+
import static java.util.Collections.emptyList;
45

56
import datadog.trace.api.aiguard.AIGuard.Evaluation;
67
import datadog.trace.api.aiguard.AIGuard.Message;
@@ -12,6 +13,6 @@ public final class NoOpEvaluator implements Evaluator {
1213

1314
@Override
1415
public Evaluation evaluate(final List<Message> messages, final Options options) {
15-
return new Evaluation(ALLOW, "AI Guard is not enabled");
16+
return new Evaluation(ALLOW, "AI Guard is not enabled", emptyList());
1617
}
1718
}

0 commit comments

Comments
 (0)