@@ -10,6 +10,7 @@ import datadog.trace.api.aiguard.AIGuard
1010import datadog.trace.api.telemetry.WafMetricCollector
1111import datadog.trace.bootstrap.instrumentation.api.AgentSpan
1212import datadog.trace.bootstrap.instrumentation.api.AgentTracer
13+ import datadog.trace.bootstrap.instrumentation.api.AgentTracer.TracerAPI
1314import datadog.trace.test.util.DDSpecification
1415import okhttp3.Call
1516import okhttp3.HttpUrl
@@ -71,6 +72,7 @@ class AIGuardInternalTests extends DDSpecification {
7172 @Shared
7273 protected static final PROMPT = TOOL_OUTPUT + [AIGuard.Message . message(' assistant' , ' 2 + 2 is 5' ), AIGuard.Message . message(' user' , ' ' )]
7374
75+ protected TracerAPI tracer
7476 protected AgentSpan span
7577
7678 void setup () {
@@ -81,7 +83,7 @@ class AIGuardInternalTests extends DDSpecification {
8183 final builder = Mock (AgentTracer.SpanBuilder ) {
8284 start() >> span
8385 }
84- final tracer = Stub (AgentTracer. TracerAPI ) {
86+ tracer = Stub (TracerAPI ) {
8587 buildSpan(_ as String , _ as String ) >> builder
8688 }
8789 AgentTracer . forceRegister(tracer)
@@ -198,19 +200,18 @@ class AIGuardInternalTests extends DDSpecification {
198200 1 * span. addThrowable(_ as AIGuard.AIGuardAbortError )
199201 }
200202
201- receivedMeta. messages == suite. messages
202- if (suite. tags) {
203- receivedMeta. attack_categories == suite. tags
204- }
203+ assertMeta(receivedMeta, suite)
205204 assertRequest(request, suite. messages)
206205 if (throwAbortError) {
207206 error instanceof AIGuard.AIGuardAbortError
208207 error. action == suite. action
209208 error. reason == suite. reason
209+ error. tags == suite. tags
210210 } else {
211211 error == null
212212 eval. action == suite. action
213213 eval. reason == suite. reason
214+ eval. tags == suite. tags
214215 }
215216 assertTelemetry(' ai_guard.requests' , " action:$suite . action " , " block:$throwAbortError " , ' error:false' )
216217
@@ -221,19 +222,7 @@ class AIGuardInternalTests extends DDSpecification {
221222 void ' test evaluate with API errors' () {
222223 given :
223224 final errors = [[status : 400 , title : ' Bad request' ]]
224- Request request = null
225- final call = Mock (Call ) {
226- execute() >> {
227- return mockResponse(request, 404 , [errors : errors])
228- }
229- }
230- final client = Mock (OkHttpClient ) {
231- newCall(_ as Request ) >> {
232- request = (Request ) it[0 ]
233- return call
234- }
235- }
236- final aiguard = new AIGuardInternal (URL , HEADERS , client)
225+ final aiguard = mockClient(404 , [errors : errors])
237226
238227 when :
239228 aiguard. evaluate(TOOL_CALL , AIGuard.Options . DEFAULT )
@@ -247,19 +236,7 @@ class AIGuardInternalTests extends DDSpecification {
247236
248237 void ' test evaluate with invalid JSON' () {
249238 given :
250- Request request = null
251- final call = Mock (Call ) {
252- execute() >> {
253- return mockResponse(request, 200 , [bad : ' This is an invalid response' ])
254- }
255- }
256- final client = Mock (OkHttpClient ) {
257- newCall(_ as Request ) >> {
258- request = (Request ) it[0 ]
259- return call
260- }
261- }
262- final aiguard = new AIGuardInternal (URL , HEADERS , client)
239+ final aiguard = mockClient(200 , [bad : ' This is an invalid response' ])
263240
264241 when :
265242 aiguard. evaluate(TOOL_CALL , AIGuard.Options . DEFAULT )
@@ -272,19 +249,7 @@ class AIGuardInternalTests extends DDSpecification {
272249
273250 void ' test evaluate with missing action' () {
274251 given :
275- Request request = null
276- final call = Mock (Call ) {
277- execute() >> {
278- return mockResponse(request, 200 , [data : [attributes : [reason : ' I miss something' ]]])
279- }
280- }
281- final client = Mock (OkHttpClient ) {
282- newCall(_ as Request ) >> {
283- request = (Request ) it[0 ]
284- return call
285- }
286- }
287- final aiguard = new AIGuardInternal (URL , HEADERS , client)
252+ final aiguard = mockClient(200 , [data : [attributes : [reason : ' I miss something' ]]])
288253
289254 when :
290255 aiguard. evaluate(TOOL_CALL , AIGuard.Options . DEFAULT )
@@ -297,19 +262,7 @@ class AIGuardInternalTests extends DDSpecification {
297262
298263 void ' test evaluate with non JSON response' () {
299264 given :
300- Request request = null
301- final call = Mock (Call ) {
302- execute() >> {
303- return mockResponse(request, 200 , [data : [attributes : [reason : ' I miss something' ]]])
304- }
305- }
306- final client = Mock (OkHttpClient ) {
307- newCall(_ as Request ) >> {
308- request = (Request ) it[0 ]
309- return call
310- }
311- }
312- final aiguard = new AIGuardInternal (URL , HEADERS , client)
265+ final aiguard = mockClient(200 , ' I am no JSON' )
313266
314267 when :
315268 aiguard. evaluate(TOOL_CALL , AIGuard.Options . DEFAULT )
@@ -322,19 +275,7 @@ class AIGuardInternalTests extends DDSpecification {
322275
323276 void ' test evaluate with empty response' () {
324277 given :
325- Request request = null
326- final call = Mock (Call ) {
327- execute() >> {
328- return mockResponse(request, 200 , null )
329- }
330- }
331- final client = Mock (OkHttpClient ) {
332- newCall(_ as Request ) >> {
333- request = (Request ) it[0 ]
334- return call
335- }
336- }
337- final aiguard = new AIGuardInternal (URL , HEADERS , client)
278+ final aiguard = mockClient(200 , null )
338279
339280 when :
340281 aiguard. evaluate(TOOL_CALL , AIGuard.Options . DEFAULT )
@@ -348,19 +289,7 @@ class AIGuardInternalTests extends DDSpecification {
348289 void ' test message length truncation' () {
349290 given :
350291 final maxMessages = Config . get(). getAiGuardMaxMessagesLength()
351- Request request = null
352- final call = Mock (Call ) {
353- execute() >> {
354- return mockResponse(request, 200 , [data : [attributes : [action : ALLOW , reason : ' It is fine' ]]])
355- }
356- }
357- final client = Mock (OkHttpClient ) {
358- newCall(_ as Request ) >> {
359- request = (Request ) it[0 ]
360- return call
361- }
362- }
363- final aiguard = new AIGuardInternal (URL , HEADERS , client)
292+ final aiguard = mockClient(200 , [data : [attributes : [action : ' ALLOW' , reason : ' It is fine' ]]])
364293 final messages = (0 .. maxMessages)
365294 .collect { AIGuard.Message . message(' user' , " This is a prompt: ${ it} " ) }
366295 .toList()
@@ -380,19 +309,7 @@ class AIGuardInternalTests extends DDSpecification {
380309 void ' test message content truncation' () {
381310 given :
382311 final maxContent = Config . get(). getAiGuardMaxContentSize()
383- Request request = null
384- final call = Mock (Call ) {
385- execute() >> {
386- return mockResponse(request, 200 , [data : [attributes : [action : ALLOW , reason : ' It is fine' ]]])
387- }
388- }
389- final client = Mock (OkHttpClient ) {
390- newCall(_ as Request ) >> {
391- request = (Request ) it[0 ]
392- return call
393- }
394- }
395- final aiguard = new AIGuardInternal (URL , HEADERS , client)
312+ final aiguard = mockClient(200 , [data : [attributes : [action : ' ALLOW' , reason : ' It is fine' ]]])
396313 final message = AIGuard.Message . message(" user" , (0 .. maxContent). collect { ' A' }. join())
397314
398315 when :
@@ -426,23 +343,7 @@ class AIGuardInternalTests extends DDSpecification {
426343
427344 void ' test missing tool name' () {
428345 given :
429- def request
430- final call = Mock (Call ) {
431- execute() >> {
432- return mockResponse(
433- request,
434- 200 ,
435- [data : [attributes : [action : ' ALLOW' , reason : ' Just do it' ]]]
436- )
437- }
438- }
439- final client = Mock (OkHttpClient ) {
440- newCall(_ as Request ) >> {
441- request = (Request ) it[0 ]
442- return call
443- }
444- }
445- final aiguard = new AIGuardInternal (URL , HEADERS , client)
346+ final aiguard = mockClient(200 , [data : [attributes : [action : ' ALLOW' , reason : ' Just do it' ]]])
446347
447348 when :
448349 aiguard. evaluate([AIGuard.Message . tool(' call_1' , ' Content' )], AIGuard.Options . DEFAULT )
@@ -460,6 +361,52 @@ class AIGuardInternalTests extends DDSpecification {
460361 thrown(IllegalArgumentException )
461362 }
462363
364+ void ' test message immutability' () {
365+ given :
366+ final aiguard = mockClient(200 , [data : [attributes : [action : ' ALLOW' , reason : ' Just do it' ]]])
367+ final message = new AIGuard.Message (
368+ " assistant" ,
369+ null ,
370+ [AIGuard.ToolCall . toolCall(' call_1' , ' execute_shell' , ' {"cmd": "ls -lah"}' )],
371+ null
372+ )
373+ Map<String , Object > receivedMeta
374+
375+ when :
376+ aiguard. evaluate([message], AIGuard.Options . DEFAULT )
377+
378+ then :
379+ 1 * span. finish() >> {
380+ // modify the tool calls before flushing
381+ message. toolCalls. add(
382+ AIGuard.ToolCall . toolCall(' call_2' , ' execute_shell' , ' {"cmd": "rm -rf"}' )
383+ )
384+ }
385+ 1 * span. setMetaStruct(AIGuardInternal . META_STRUCT_TAG , _ as Map ) >> {
386+ receivedMeta = it[1 ] as Map<String , Object >
387+ return span
388+ }
389+ final messages = receivedMeta. messages as List<AIGuard.Message >
390+ final assistant = messages. first()
391+ assistant. toolCalls. size() == 1
392+ }
393+
394+ private AIGuardInternal mockClient (final int status , final Object response ) {
395+ def request
396+ final call = Stub (Call ) {
397+ execute() >> {
398+ return mockResponse(request, status, response)
399+ }
400+ }
401+ final client = Stub (OkHttpClient ) {
402+ newCall(_ as Request ) >> {
403+ request = (Request ) it[0 ]
404+ return call
405+ }
406+ }
407+ return new AIGuardInternal (URL , HEADERS , client)
408+ }
409+
463410 private static assertTelemetry (final String metric , final String ...tags ) {
464411 final metrics = WafMetricCollector . get(). with {
465412 prepareMetrics()
@@ -475,6 +422,16 @@ class AIGuardInternalTests extends DDSpecification {
475422 return true
476423 }
477424
425+ private static assertMeta (final Map<String , Object > meta , final TestSuite suite ) {
426+ if (suite. tags) {
427+ assert meta. attack_categories == suite. tags
428+ }
429+ final receivedMessages = snakeCaseJson(meta. messages)
430+ final expectedMessages = snakeCaseJson(suite. messages)
431+ JSONAssert . assertEquals (expectedMessages, receivedMessages, JSONCompareMode . NON_EXTENSIBLE )
432+ return true
433+ }
434+
478435 private static assertRequest (final Request request , final List<AIGuard.Message > messages ) {
479436 assert request. url() == URL
480437 assert request. method() == ' POST'
@@ -556,7 +513,8 @@ class AIGuardInternalTests extends DDSpecification {
556513 " , reason='" + reason + ' \' ' +
557514 " , blocking=" + blocking +
558515 " , target='" + target + ' \' ' +
559- " , messages=" + messages +
516+ " , messages=" + messages + ' \' ' +
517+ " , tags=" + tags +
560518 ' }'
561519 }
562520 }
0 commit comments