21
21
CustomGuardrail ,
22
22
log_guardrail_information ,
23
23
)
24
- from litellm .llms .vertex_ai .vertex_llm_base import VertexBase
25
24
from litellm .llms .custom_httpx .http_handler import (
26
25
get_async_httpx_client ,
27
26
httpxSpecialProvider ,
28
27
)
28
+ from litellm .llms .vertex_ai .vertex_llm_base import VertexBase
29
29
from litellm .proxy ._types import UserAPIKeyAuth
30
30
from litellm .types .guardrails import GuardrailEventHooks
31
31
from litellm .types .utils import (
@@ -225,15 +225,15 @@ def sanitize_file_prompt(
225
225
}
226
226
}
227
227
228
- def _should_block_content (self , armor_response : dict ) -> bool :
228
+ def _should_block_content (self , armor_response : dict , allow_sanitization : bool = False ) -> bool :
229
229
"""Check if Model Armor response indicates content should be blocked, including both inspectResult and deidentifyResult."""
230
230
sanitization_result = armor_response .get ("sanitizationResult" , {})
231
231
filter_results = sanitization_result .get ("filterResults" , {})
232
232
233
233
# filterResults can be a dict (named keys) or a list (array of filter result dicts)
234
234
filter_result_items = []
235
235
if isinstance (filter_results , dict ):
236
- filter_result_items = [ filter_results ]
236
+ filter_result_items = list ( filter_results . values ())
237
237
elif isinstance (filter_results , list ):
238
238
filter_result_items = filter_results
239
239
@@ -263,8 +263,10 @@ def _should_block_content(self, armor_response: dict) -> bool:
263
263
if sdp :
264
264
if sdp .get ("inspectResult" , {}).get ("matchState" ) == "MATCH_FOUND" :
265
265
return True
266
+ # Only block on deidentifyResult if sanitization is not allowed
266
267
if sdp .get ("deidentifyResult" , {}).get ("matchState" ) == "MATCH_FOUND" :
267
- return True
268
+ if not allow_sanitization :
269
+ return True
268
270
# Fallback dict code removed; all cases handled above
269
271
return False
270
272
@@ -278,7 +280,7 @@ def _get_sanitized_content(self, armor_response: dict) -> Optional[str]:
278
280
279
281
# filterResults can be a dict (single filter) or a list (multiple filters)
280
282
filters = (
281
- [ filter_results ]
283
+ list ( filter_results . values ())
282
284
if isinstance (filter_results , dict )
283
285
else filter_results
284
286
if isinstance (filter_results , list )
@@ -409,11 +411,11 @@ async def async_pre_call_hook(
409
411
# fail_on_error=False) we still want the correct status reflected.
410
412
metadata ["_model_armor_status" ] = (
411
413
"blocked"
412
- if self ._should_block_content (armor_response )
414
+ if self ._should_block_content (armor_response , allow_sanitization = self . mask_request_content )
413
415
else "success"
414
416
)
415
417
# Check if content should be blocked
416
- if self ._should_block_content (armor_response ):
418
+ if self ._should_block_content (armor_response , allow_sanitization = self . mask_request_content ):
417
419
raise HTTPException (
418
420
status_code = 400 ,
419
421
detail = {
@@ -494,12 +496,12 @@ async def async_post_call_success_hook(
494
496
metadata ["_model_armor_response" ] = armor_response
495
497
metadata ["_model_armor_status" ] = (
496
498
"blocked"
497
- if self ._should_block_content (armor_response )
499
+ if self ._should_block_content (armor_response , allow_sanitization = self . mask_response_content )
498
500
else "success"
499
501
)
500
502
501
503
# Check if content should be blocked
502
- if self ._should_block_content (armor_response ):
504
+ if self ._should_block_content (armor_response , allow_sanitization = self . mask_response_content ):
503
505
raise HTTPException (
504
506
status_code = 400 ,
505
507
detail = {
0 commit comments