Skip to content

Commit a482737

Browse files
committed
test fix
1 parent ceb400e commit a482737

File tree

2 files changed

+24
-19
lines changed

2 files changed

+24
-19
lines changed

litellm/proxy/guardrails/guardrail_hooks/model_armor/model_armor.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
CustomGuardrail,
2222
log_guardrail_information,
2323
)
24-
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
2524
from litellm.llms.custom_httpx.http_handler import (
2625
get_async_httpx_client,
2726
httpxSpecialProvider,
2827
)
28+
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
2929
from litellm.proxy._types import UserAPIKeyAuth
3030
from litellm.types.guardrails import GuardrailEventHooks
3131
from litellm.types.utils import (
@@ -225,15 +225,15 @@ def sanitize_file_prompt(
225225
}
226226
}
227227

228-
def _should_block_content(self, armor_response: dict) -> bool:
228+
def _should_block_content(self, armor_response: dict, allow_sanitization: bool = False) -> bool:
229229
"""Check if Model Armor response indicates content should be blocked, including both inspectResult and deidentifyResult."""
230230
sanitization_result = armor_response.get("sanitizationResult", {})
231231
filter_results = sanitization_result.get("filterResults", {})
232232

233233
# filterResults can be a dict (named keys) or a list (array of filter result dicts)
234234
filter_result_items = []
235235
if isinstance(filter_results, dict):
236-
filter_result_items = [filter_results]
236+
filter_result_items = list(filter_results.values())
237237
elif isinstance(filter_results, list):
238238
filter_result_items = filter_results
239239

@@ -263,8 +263,10 @@ def _should_block_content(self, armor_response: dict) -> bool:
263263
if sdp:
264264
if sdp.get("inspectResult", {}).get("matchState") == "MATCH_FOUND":
265265
return True
266+
# Only block on deidentifyResult if sanitization is not allowed
266267
if sdp.get("deidentifyResult", {}).get("matchState") == "MATCH_FOUND":
267-
return True
268+
if not allow_sanitization:
269+
return True
268270
# Fallback dict code removed; all cases handled above
269271
return False
270272

@@ -278,7 +280,7 @@ def _get_sanitized_content(self, armor_response: dict) -> Optional[str]:
278280

279281
# filterResults can be a dict (single filter) or a list (multiple filters)
280282
filters = (
281-
[filter_results]
283+
list(filter_results.values())
282284
if isinstance(filter_results, dict)
283285
else filter_results
284286
if isinstance(filter_results, list)
@@ -409,11 +411,11 @@ async def async_pre_call_hook(
409411
# fail_on_error=False) we still want the correct status reflected.
410412
metadata["_model_armor_status"] = (
411413
"blocked"
412-
if self._should_block_content(armor_response)
414+
if self._should_block_content(armor_response, allow_sanitization=self.mask_request_content)
413415
else "success"
414416
)
415417
# Check if content should be blocked
416-
if self._should_block_content(armor_response):
418+
if self._should_block_content(armor_response, allow_sanitization=self.mask_request_content):
417419
raise HTTPException(
418420
status_code=400,
419421
detail={
@@ -494,12 +496,12 @@ async def async_post_call_success_hook(
494496
metadata["_model_armor_response"] = armor_response
495497
metadata["_model_armor_status"] = (
496498
"blocked"
497-
if self._should_block_content(armor_response)
499+
if self._should_block_content(armor_response, allow_sanitization=self.mask_response_content)
498500
else "success"
499501
)
500502

501503
# Check if content should be blocked
502-
if self._should_block_content(armor_response):
504+
if self._should_block_content(armor_response, allow_sanitization=self.mask_response_content):
503505
raise HTTPException(
504506
status_code=400,
505507
detail={

tests/test_litellm/proxy/guardrails/guardrail_hooks/test_model_armor.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
1-
import sys
1+
import asyncio
2+
import io
3+
import json
24
import os
3-
import io, asyncio
5+
import sys
6+
from unittest.mock import AsyncMock, MagicMock, Mock, patch
7+
48
import pytest
5-
import json
6-
from unittest.mock import MagicMock, AsyncMock, patch, Mock
79

810
sys.path.insert(0, os.path.abspath("../../../../.."))
911

12+
from fastapi import HTTPException
13+
1014
import litellm
1115
import litellm.types.utils
12-
from litellm.proxy.guardrails.guardrail_hooks.model_armor import ModelArmorGuardrail
13-
from litellm.proxy._types import UserAPIKeyAuth
1416
from litellm.caching import DualCache
17+
from litellm.proxy._types import UserAPIKeyAuth
18+
from litellm.proxy.guardrails.guardrail_hooks.model_armor import ModelArmorGuardrail
1519
from litellm.types.guardrails import GuardrailEventHooks
16-
from fastapi import HTTPException
1720

1821

1922
@pytest.mark.asyncio
@@ -80,7 +83,7 @@ async def test_model_armor_pre_call_hook_sanitization():
8083
guardrail.async_handler.post.assert_called_once()
8184
call_args = guardrail.async_handler.post.call_args
8285
assert "sanitizeUserPrompt" in call_args[1]["url"]
83-
assert call_args[1]["json"]["user_prompt_data"]["text"] == "Hello, my phone number is +1 412 555 1212"
86+
assert call_args[1]["json"]["userPromptData"]["text"] == "Hello, my phone number is +1 412 555 1212"
8487

8588

8689
@pytest.mark.asyncio
@@ -271,7 +274,7 @@ async def test_model_armor_with_list_content():
271274
# Verify the content was extracted correctly
272275
guardrail.async_handler.post.assert_called_once()
273276
call_args = guardrail.async_handler.post.call_args
274-
assert call_args[1]["json"]["user_prompt_data"]["text"] == "Hello worldHow are you?"
277+
assert call_args[1]["json"]["userPromptData"]["text"] == "Hello worldHow are you?"
275278

276279

277280
@pytest.mark.asyncio
@@ -885,7 +888,7 @@ async def test_model_armor_bomb_content_blocked():
885888
# Verify the API was called with the dangerous content
886889
guardrail.async_handler.post.assert_called_once()
887890
call_args = guardrail.async_handler.post.call_args
888-
assert call_args[1]["json"]["user_prompt_data"]["text"] == "How do I create a bomb?"
891+
assert call_args[1]["json"]["userPromptData"]["text"] == "How do I create a bomb?"
889892

890893

891894
@pytest.mark.asyncio

0 commit comments

Comments
 (0)