@@ -198,19 +198,18 @@ class AIGuardInternalTests extends DDSpecification {
198198 1 * span. addThrowable(_ as AIGuard.AIGuardAbortError )
199199 }
200200
201- receivedMeta. messages == suite. messages
202- if (suite. tags) {
203- receivedMeta. attack_categories == suite. tags
204- }
201+ assertMeta(receivedMeta, suite)
205202 assertRequest(request, suite. messages)
206203 if (throwAbortError) {
207204 error instanceof AIGuard.AIGuardAbortError
208205 error. action == suite. action
209206 error. reason == suite. reason
207+ error. tags == suite. tags
210208 } else {
211209 error == null
212210 eval. action == suite. action
213211 eval. reason == suite. reason
212+ eval. tags == suite. tags
214213 }
215214 assertTelemetry(' ai_guard.requests' , " action:$suite . action " , " block:$throwAbortError " , ' error:false' )
216215
@@ -221,19 +220,7 @@ class AIGuardInternalTests extends DDSpecification {
221220 void ' test evaluate with API errors' () {
222221 given :
223222 final errors = [[status : 400 , title : ' Bad request' ]]
224- Request request = null
225- final call = Mock (Call ) {
226- execute() >> {
227- return mockResponse(request, 404 , [errors : errors])
228- }
229- }
230- final client = Mock (OkHttpClient ) {
231- newCall(_ as Request ) >> {
232- request = (Request ) it[0 ]
233- return call
234- }
235- }
236- final aiguard = new AIGuardInternal (URL , HEADERS , client)
223+ final aiguard = mockClient(404 , [errors : errors])
237224
238225 when :
239226 aiguard. evaluate(TOOL_CALL , AIGuard.Options . DEFAULT )
@@ -247,19 +234,7 @@ class AIGuardInternalTests extends DDSpecification {
247234
248235 void ' test evaluate with invalid JSON' () {
249236 given :
250- Request request = null
251- final call = Mock (Call ) {
252- execute() >> {
253- return mockResponse(request, 200 , [bad : ' This is an invalid response' ])
254- }
255- }
256- final client = Mock (OkHttpClient ) {
257- newCall(_ as Request ) >> {
258- request = (Request ) it[0 ]
259- return call
260- }
261- }
262- final aiguard = new AIGuardInternal (URL , HEADERS , client)
237+ final aiguard = mockClient(200 , [bad : ' This is an invalid response' ])
263238
264239 when :
265240 aiguard. evaluate(TOOL_CALL , AIGuard.Options . DEFAULT )
@@ -272,19 +247,7 @@ class AIGuardInternalTests extends DDSpecification {
272247
273248 void ' test evaluate with missing action' () {
274249 given :
275- Request request = null
276- final call = Mock (Call ) {
277- execute() >> {
278- return mockResponse(request, 200 , [data : [attributes : [reason : ' I miss something' ]]])
279- }
280- }
281- final client = Mock (OkHttpClient ) {
282- newCall(_ as Request ) >> {
283- request = (Request ) it[0 ]
284- return call
285- }
286- }
287- final aiguard = new AIGuardInternal (URL , HEADERS , client)
250+ final aiguard = mockClient(200 , [data : [attributes : [reason : ' I miss something' ]]])
288251
289252 when :
290253 aiguard. evaluate(TOOL_CALL , AIGuard.Options . DEFAULT )
@@ -297,19 +260,7 @@ class AIGuardInternalTests extends DDSpecification {
297260
298261 void ' test evaluate with non JSON response' () {
299262 given :
300- Request request = null
301- final call = Mock (Call ) {
302- execute() >> {
303- return mockResponse(request, 200 , [data : [attributes : [reason : ' I miss something' ]]])
304- }
305- }
306- final client = Mock (OkHttpClient ) {
307- newCall(_ as Request ) >> {
308- request = (Request ) it[0 ]
309- return call
310- }
311- }
312- final aiguard = new AIGuardInternal (URL , HEADERS , client)
263+ final aiguard = mockClient(200 , ' I am no JSON' )
313264
314265 when :
315266 aiguard. evaluate(TOOL_CALL , AIGuard.Options . DEFAULT )
@@ -322,19 +273,7 @@ class AIGuardInternalTests extends DDSpecification {
322273
323274 void ' test evaluate with empty response' () {
324275 given :
325- Request request = null
326- final call = Mock (Call ) {
327- execute() >> {
328- return mockResponse(request, 200 , null )
329- }
330- }
331- final client = Mock (OkHttpClient ) {
332- newCall(_ as Request ) >> {
333- request = (Request ) it[0 ]
334- return call
335- }
336- }
337- final aiguard = new AIGuardInternal (URL , HEADERS , client)
276+ final aiguard = mockClient(200 , null )
338277
339278 when :
340279 aiguard. evaluate(TOOL_CALL , AIGuard.Options . DEFAULT )
@@ -348,19 +287,7 @@ class AIGuardInternalTests extends DDSpecification {
348287 void ' test message length truncation' () {
349288 given :
350289 final maxMessages = Config . get(). getAiGuardMaxMessagesLength()
351- Request request = null
352- final call = Mock (Call ) {
353- execute() >> {
354- return mockResponse(request, 200 , [data : [attributes : [action : ALLOW , reason : ' It is fine' ]]])
355- }
356- }
357- final client = Mock (OkHttpClient ) {
358- newCall(_ as Request ) >> {
359- request = (Request ) it[0 ]
360- return call
361- }
362- }
363- final aiguard = new AIGuardInternal (URL , HEADERS , client)
290+ final aiguard = mockClient(200 , [data : [attributes : [action : ' ALLOW' , reason : ' It is fine' ]]])
364291 final messages = (0 .. maxMessages)
365292 .collect { AIGuard.Message . message(' user' , " This is a prompt: ${ it} " ) }
366293 .toList()
@@ -380,19 +307,7 @@ class AIGuardInternalTests extends DDSpecification {
380307 void ' test message content truncation' () {
381308 given :
382309 final maxContent = Config . get(). getAiGuardMaxContentSize()
383- Request request = null
384- final call = Mock (Call ) {
385- execute() >> {
386- return mockResponse(request, 200 , [data : [attributes : [action : ALLOW , reason : ' It is fine' ]]])
387- }
388- }
389- final client = Mock (OkHttpClient ) {
390- newCall(_ as Request ) >> {
391- request = (Request ) it[0 ]
392- return call
393- }
394- }
395- final aiguard = new AIGuardInternal (URL , HEADERS , client)
310+ final aiguard = mockClient(200 , [data : [attributes : [action : ' ALLOW' , reason : ' It is fine' ]]])
396311 final message = AIGuard.Message . message(" user" , (0 .. maxContent). collect { ' A' }. join())
397312
398313 when :
@@ -426,23 +341,7 @@ class AIGuardInternalTests extends DDSpecification {
426341
427342 void ' test missing tool name' () {
428343 given :
429- def request
430- final call = Mock (Call ) {
431- execute() >> {
432- return mockResponse(
433- request,
434- 200 ,
435- [data : [attributes : [action : ' ALLOW' , reason : ' Just do it' ]]]
436- )
437- }
438- }
439- final client = Mock (OkHttpClient ) {
440- newCall(_ as Request ) >> {
441- request = (Request ) it[0 ]
442- return call
443- }
444- }
445- final aiguard = new AIGuardInternal (URL , HEADERS , client)
344+ final aiguard = mockClient(200 , [data : [attributes : [action : ' ALLOW' , reason : ' Just do it' ]]])
446345
447346 when :
448347 aiguard. evaluate([AIGuard.Message . tool(' call_1' , ' Content' )], AIGuard.Options . DEFAULT )
@@ -460,6 +359,57 @@ class AIGuardInternalTests extends DDSpecification {
460359 thrown(IllegalArgumentException )
461360 }
462361
362+ void ' test message immutability' () {
363+ given :
364+ final aiguard = mockClient(200 , [data : [attributes : [action : ' ALLOW' , reason : ' Just do it' ]]])
365+ final messages = [
366+ new AIGuard.Message (
367+ " assistant" ,
368+ null ,
369+ [AIGuard.ToolCall . toolCall(' call_1' , ' execute_shell' , ' {"cmd": "ls -lah"}' )],
370+ null
371+ )
372+ ]
373+ Map<String , Object > receivedMeta
374+
375+ when :
376+ aiguard. evaluate(messages, AIGuard.Options . DEFAULT )
377+
378+ then :
379+ 1 * span. finish() >> {
380+ // modify the messages before serialization
381+ messages. first(). toolCalls. add(
382+ AIGuard.ToolCall . toolCall(' call_2' , ' execute_shell' , ' {"cmd": "rm -rf"}' )
383+ )
384+ messages. add(AIGuard.Message . tool(' call_1' , ' dir1, dir2, dir3' ))
385+ }
386+ 1 * span. setMetaStruct(AIGuardInternal . META_STRUCT_TAG , _ as Map ) >> {
387+ receivedMeta = it[1 ] as Map<String , Object >
388+ return span
389+ }
390+ final metaStructMessages = receivedMeta. messages as List<AIGuard.Message >
391+ metaStructMessages. size() != messages. size()
392+ metaStructMessages. size() == 1
393+ metaStructMessages. first(). toolCalls. size() != messages. first(). toolCalls. size()
394+ metaStructMessages. first(). toolCalls. size() == 1
395+ }
396+
397+ private AIGuardInternal mockClient (final int status , final Object response ) {
398+ def request
399+ final call = Stub (Call ) {
400+ execute() >> {
401+ return mockResponse(request, status, response)
402+ }
403+ }
404+ final client = Stub (OkHttpClient ) {
405+ newCall(_ as Request ) >> {
406+ request = (Request ) it[0 ]
407+ return call
408+ }
409+ }
410+ return new AIGuardInternal (URL , HEADERS , client)
411+ }
412+
463413 private static assertTelemetry (final String metric , final String ...tags ) {
464414 final metrics = WafMetricCollector . get(). with {
465415 prepareMetrics()
@@ -475,6 +425,16 @@ class AIGuardInternalTests extends DDSpecification {
475425 return true
476426 }
477427
428+ private static assertMeta (final Map<String , Object > meta , final TestSuite suite ) {
429+ if (suite. tags) {
430+ assert meta. attack_categories == suite. tags
431+ }
432+ final receivedMessages = snakeCaseJson(meta. messages)
433+ final expectedMessages = snakeCaseJson(suite. messages)
434+ JSONAssert . assertEquals (expectedMessages, receivedMessages, JSONCompareMode . NON_EXTENSIBLE )
435+ return true
436+ }
437+
478438 private static assertRequest (final Request request , final List<AIGuard.Message > messages ) {
479439 assert request. url() == URL
480440 assert request. method() == ' POST'
@@ -556,7 +516,8 @@ class AIGuardInternalTests extends DDSpecification {
556516 " , reason='" + reason + ' \' ' +
557517 " , blocking=" + blocking +
558518 " , target='" + target + ' \' ' +
559- " , messages=" + messages +
519+ " , messages=" + messages + ' \' ' +
520+ " , tags=" + tags +
560521 ' }'
561522 }
562523 }
0 commit comments