Skip to content

Commit 24e5295

Browse files
Ensure messages are not modified before span serialization (#10116)
Ensure messages are not modified before span serialization
1 parent 87d3759 commit 24e5295

File tree

4 files changed

+99
-126
lines changed

4 files changed

+99
-126
lines changed

dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -134,19 +134,18 @@ private static List<Message> messagesForMetaStruct(List<Message> messages) {
134134
final List<Message> result = new ArrayList<>(size);
135135
final int maxContent = config.getAiGuardMaxContentSize();
136136
boolean contentTruncated = false;
137-
for (int i = 0; i < size; i++) {
138-
Message source = messages.get(i);
139-
final String content = source.getContent();
137+
for (int i = messages.size() - size; i < messages.size(); i++) {
138+
final Message source = messages.get(i);
139+
String content = source.getContent();
140140
if (content != null && content.length() > maxContent) {
141141
contentTruncated = true;
142-
source =
143-
new Message(
144-
source.getRole(),
145-
content.substring(0, maxContent),
146-
source.getToolCalls(),
147-
source.getToolCallId());
142+
content = content.substring(0, maxContent);
148143
}
149-
result.add(source);
144+
List<ToolCall> toolCalls = source.getToolCalls();
145+
if (toolCalls != null) {
146+
toolCalls = new ArrayList<>(toolCalls);
147+
}
148+
result.add(new Message(source.getRole(), content, toolCalls, source.getToolCallId()));
150149
}
151150
if (contentTruncated) {
152151
WafMetricCollector.get().aiGuardTruncated(CONTENT);
@@ -240,7 +239,7 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
240239
span.setTag(BLOCKED_TAG, true);
241240
throw new AIGuardAbortError(action, reason, tags);
242241
}
243-
return new Evaluation(action, reason);
242+
return new Evaluation(action, reason, tags);
244243
}
245244
} catch (AIGuardAbortError e) {
246245
span.addThrowable(e);

dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy

Lines changed: 74 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -198,19 +198,18 @@ class AIGuardInternalTests extends DDSpecification {
198198
1 * span.addThrowable(_ as AIGuard.AIGuardAbortError)
199199
}
200200

201-
receivedMeta.messages == suite.messages
202-
if (suite.tags) {
203-
receivedMeta.attack_categories == suite.tags
204-
}
201+
assertMeta(receivedMeta, suite)
205202
assertRequest(request, suite.messages)
206203
if (throwAbortError) {
207204
error instanceof AIGuard.AIGuardAbortError
208205
error.action == suite.action
209206
error.reason == suite.reason
207+
error.tags == suite.tags
210208
} else {
211209
error == null
212210
eval.action == suite.action
213211
eval.reason == suite.reason
212+
eval.tags == suite.tags
214213
}
215214
assertTelemetry('ai_guard.requests', "action:$suite.action", "block:$throwAbortError", 'error:false')
216215

@@ -221,19 +220,7 @@ class AIGuardInternalTests extends DDSpecification {
221220
void 'test evaluate with API errors'() {
222221
given:
223222
final errors = [[status: 400, title: 'Bad request']]
224-
Request request = null
225-
final call = Mock(Call) {
226-
execute() >> {
227-
return mockResponse(request, 404, [errors: errors])
228-
}
229-
}
230-
final client = Mock(OkHttpClient) {
231-
newCall(_ as Request) >> {
232-
request = (Request) it[0]
233-
return call
234-
}
235-
}
236-
final aiguard = new AIGuardInternal(URL, HEADERS, client)
223+
final aiguard = mockClient(404, [errors: errors])
237224

238225
when:
239226
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
@@ -247,19 +234,7 @@ class AIGuardInternalTests extends DDSpecification {
247234

248235
void 'test evaluate with invalid JSON'() {
249236
given:
250-
Request request = null
251-
final call = Mock(Call) {
252-
execute() >> {
253-
return mockResponse(request, 200, [bad: 'This is an invalid response'])
254-
}
255-
}
256-
final client = Mock(OkHttpClient) {
257-
newCall(_ as Request) >> {
258-
request = (Request) it[0]
259-
return call
260-
}
261-
}
262-
final aiguard = new AIGuardInternal(URL, HEADERS, client)
237+
final aiguard = mockClient(200, [bad: 'This is an invalid response'])
263238

264239
when:
265240
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
@@ -272,19 +247,7 @@ class AIGuardInternalTests extends DDSpecification {
272247

273248
void 'test evaluate with missing action'() {
274249
given:
275-
Request request = null
276-
final call = Mock(Call) {
277-
execute() >> {
278-
return mockResponse(request, 200, [data: [attributes: [reason: 'I miss something']]])
279-
}
280-
}
281-
final client = Mock(OkHttpClient) {
282-
newCall(_ as Request) >> {
283-
request = (Request) it[0]
284-
return call
285-
}
286-
}
287-
final aiguard = new AIGuardInternal(URL, HEADERS, client)
250+
final aiguard = mockClient(200, [data: [attributes: [reason: 'I miss something']]])
288251

289252
when:
290253
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
@@ -297,19 +260,7 @@ class AIGuardInternalTests extends DDSpecification {
297260

298261
void 'test evaluate with non JSON response'() {
299262
given:
300-
Request request = null
301-
final call = Mock(Call) {
302-
execute() >> {
303-
return mockResponse(request, 200, [data: [attributes: [reason: 'I miss something']]])
304-
}
305-
}
306-
final client = Mock(OkHttpClient) {
307-
newCall(_ as Request) >> {
308-
request = (Request) it[0]
309-
return call
310-
}
311-
}
312-
final aiguard = new AIGuardInternal(URL, HEADERS, client)
263+
final aiguard = mockClient(200, 'I am no JSON')
313264

314265
when:
315266
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
@@ -322,19 +273,7 @@ class AIGuardInternalTests extends DDSpecification {
322273

323274
void 'test evaluate with empty response'() {
324275
given:
325-
Request request = null
326-
final call = Mock(Call) {
327-
execute() >> {
328-
return mockResponse(request, 200, null)
329-
}
330-
}
331-
final client = Mock(OkHttpClient) {
332-
newCall(_ as Request) >> {
333-
request = (Request) it[0]
334-
return call
335-
}
336-
}
337-
final aiguard = new AIGuardInternal(URL, HEADERS, client)
276+
final aiguard = mockClient(200, null)
338277

339278
when:
340279
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
@@ -348,19 +287,7 @@ class AIGuardInternalTests extends DDSpecification {
348287
void 'test message length truncation'() {
349288
given:
350289
final maxMessages = Config.get().getAiGuardMaxMessagesLength()
351-
Request request = null
352-
final call = Mock(Call) {
353-
execute() >> {
354-
return mockResponse(request, 200, [data: [attributes: [action: ALLOW, reason: 'It is fine']]])
355-
}
356-
}
357-
final client = Mock(OkHttpClient) {
358-
newCall(_ as Request) >> {
359-
request = (Request) it[0]
360-
return call
361-
}
362-
}
363-
final aiguard = new AIGuardInternal(URL, HEADERS, client)
290+
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]])
364291
final messages = (0..maxMessages)
365292
.collect { AIGuard.Message.message('user', "This is a prompt: ${it}") }
366293
.toList()
@@ -380,19 +307,7 @@ class AIGuardInternalTests extends DDSpecification {
380307
void 'test message content truncation'() {
381308
given:
382309
final maxContent = Config.get().getAiGuardMaxContentSize()
383-
Request request = null
384-
final call = Mock(Call) {
385-
execute() >> {
386-
return mockResponse(request, 200, [data: [attributes: [action: ALLOW, reason: 'It is fine']]])
387-
}
388-
}
389-
final client = Mock(OkHttpClient) {
390-
newCall(_ as Request) >> {
391-
request = (Request) it[0]
392-
return call
393-
}
394-
}
395-
final aiguard = new AIGuardInternal(URL, HEADERS, client)
310+
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]])
396311
final message = AIGuard.Message.message("user", (0..maxContent).collect { 'A' }.join())
397312

398313
when:
@@ -426,23 +341,7 @@ class AIGuardInternalTests extends DDSpecification {
426341

427342
void 'test missing tool name'() {
428343
given:
429-
def request
430-
final call = Mock(Call) {
431-
execute() >> {
432-
return mockResponse(
433-
request,
434-
200,
435-
[data: [attributes: [action: 'ALLOW', reason: 'Just do it']]]
436-
)
437-
}
438-
}
439-
final client = Mock(OkHttpClient) {
440-
newCall(_ as Request) >> {
441-
request = (Request) it[0]
442-
return call
443-
}
444-
}
445-
final aiguard = new AIGuardInternal(URL, HEADERS, client)
344+
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Just do it']]])
446345

447346
when:
448347
aiguard.evaluate([AIGuard.Message.tool('call_1', 'Content')], AIGuard.Options.DEFAULT)
@@ -460,6 +359,57 @@ class AIGuardInternalTests extends DDSpecification {
460359
thrown(IllegalArgumentException)
461360
}
462361

362+
void 'test message immutability'() {
363+
given:
364+
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Just do it']]])
365+
final messages = [
366+
new AIGuard.Message(
367+
"assistant",
368+
null,
369+
[AIGuard.ToolCall.toolCall('call_1', 'execute_shell', '{"cmd": "ls -lah"}')],
370+
null
371+
)
372+
]
373+
Map<String, Object> receivedMeta
374+
375+
when:
376+
aiguard.evaluate(messages, AIGuard.Options.DEFAULT)
377+
378+
then:
379+
1 * span.finish() >> {
380+
// modify the messages before serialization
381+
messages.first().toolCalls.add(
382+
AIGuard.ToolCall.toolCall('call_2', 'execute_shell', '{"cmd": "rm -rf"}')
383+
)
384+
messages.add(AIGuard.Message.tool('call_1', 'dir1, dir2, dir3'))
385+
}
386+
1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _ as Map) >> {
387+
receivedMeta = it[1] as Map<String, Object>
388+
return span
389+
}
390+
final metaStructMessages = receivedMeta.messages as List<AIGuard.Message>
391+
metaStructMessages.size() != messages.size()
392+
metaStructMessages.size() == 1
393+
metaStructMessages.first().toolCalls.size() != messages.first().toolCalls.size()
394+
metaStructMessages.first().toolCalls.size() == 1
395+
}
396+
397+
private AIGuardInternal mockClient(final int status, final Object response) {
398+
def request
399+
final call = Stub(Call) {
400+
execute() >> {
401+
return mockResponse(request, status, response)
402+
}
403+
}
404+
final client = Stub(OkHttpClient) {
405+
newCall(_ as Request) >> {
406+
request = (Request) it[0]
407+
return call
408+
}
409+
}
410+
return new AIGuardInternal(URL, HEADERS, client)
411+
}
412+
463413
private static assertTelemetry(final String metric, final String...tags) {
464414
final metrics = WafMetricCollector.get().with {
465415
prepareMetrics()
@@ -475,6 +425,16 @@ class AIGuardInternalTests extends DDSpecification {
475425
return true
476426
}
477427

428+
private static assertMeta(final Map<String, Object> meta, final TestSuite suite) {
429+
if (suite.tags) {
430+
assert meta.attack_categories == suite.tags
431+
}
432+
final receivedMessages = snakeCaseJson(meta.messages)
433+
final expectedMessages = snakeCaseJson(suite.messages)
434+
JSONAssert.assertEquals(expectedMessages, receivedMessages, JSONCompareMode.NON_EXTENSIBLE)
435+
return true
436+
}
437+
478438
private static assertRequest(final Request request, final List<AIGuard.Message> messages) {
479439
assert request.url() == URL
480440
assert request.method() == 'POST'
@@ -556,7 +516,8 @@ class AIGuardInternalTests extends DDSpecification {
556516
", reason='" + reason + '\'' +
557517
", blocking=" + blocking +
558518
", target='" + target + '\'' +
559-
", messages=" + messages +
519+
", messages=" + messages + '\'' +
520+
", tags=" + tags +
560521
'}'
561522
}
562523
}

dd-trace-api/src/main/java/datadog/trace/api/aiguard/AIGuard.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,16 +143,19 @@ public static class Evaluation {
143143

144144
final Action action;
145145
final String reason;
146+
final List<String> tags;
146147

147148
/**
148149
* Creates a new evaluation result.
149150
*
150151
* @param action the recommended action for the evaluated content
151152
* @param reason human-readable explanation for the decision
153+
* @param tags list of tags associated with the evaluation (e.g. indirect-prompt-injection)
152154
*/
153-
public Evaluation(final Action action, final String reason) {
155+
public Evaluation(final Action action, final String reason, final List<String> tags) {
154156
this.action = action;
155157
this.reason = reason;
158+
this.tags = tags;
156159
}
157160

158161
/**
@@ -172,6 +175,15 @@ public Action getAction() {
172175
public String getReason() {
173176
return reason;
174177
}
178+
179+
/**
180+
* Returns the list of tags associated with the evaluation (e.g. indirect-prompt-injection)
181+
*
182+
* @return list of tags.
183+
*/
184+
public List<String> getTags() {
185+
return tags;
186+
}
175187
}
176188

177189
/**

dd-trace-api/src/main/java/datadog/trace/api/aiguard/noop/NoOpEvaluator.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package datadog.trace.api.aiguard.noop;
22

33
import static datadog.trace.api.aiguard.AIGuard.Action.ALLOW;
4+
import static java.util.Collections.emptyList;
45

56
import datadog.trace.api.aiguard.AIGuard.Evaluation;
67
import datadog.trace.api.aiguard.AIGuard.Message;
@@ -12,6 +13,6 @@ public final class NoOpEvaluator implements Evaluator {
1213

1314
@Override
1415
public Evaluation evaluate(final List<Message> messages, final Options options) {
15-
return new Evaluation(ALLOW, "AI Guard is not enabled");
16+
return new Evaluation(ALLOW, "AI Guard is not enabled", emptyList());
1617
}
1718
}

0 commit comments

Comments
 (0)