Skip to content

Commit 9d18780

Browse files
committed
Added model armor testing files basic
1 parent 5070f5d commit 9d18780

File tree

4 files changed

+174
-0
lines changed

4 files changed

+174
-0
lines changed

git_model_armor.py

Whitespace-only changes.

test_model_armor.py

Whitespace-only changes.
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import sys
2+
import os
3+
import pytest
4+
from unittest.mock import AsyncMock
5+
from fastapi import HTTPException
6+
7+
sys.path.insert(0, os.path.abspath("../.."))
8+
9+
from litellm.proxy.guardrails.guardrail_hooks.model_armor.model_armor import ModelArmorGuardrail
10+
11+
def test_sanitize_file_prompt_builds_pdf_body():
12+
guardrail = ModelArmorGuardrail(
13+
template_id="dummy-template",
14+
project_id="dummy-project",
15+
location="us-central1",
16+
credentials=None,
17+
)
18+
file_bytes = b"%PDF-1.4 some pdf content"
19+
file_type = "PDF"
20+
body = guardrail.sanitize_file_prompt(file_bytes, file_type, source="user_prompt")
21+
assert "userPromptData" in body
22+
assert body["userPromptData"]["byteItem"]["byteDataType"] == "PDF"
23+
import base64
24+
assert body["userPromptData"]["byteItem"]["byteData"] == base64.b64encode(file_bytes).decode("utf-8")
25+
26+
@pytest.mark.asyncio
27+
async def test_make_model_armor_request_file_prompt():
28+
guardrail = ModelArmorGuardrail(
29+
template_id="dummy-template",
30+
project_id="dummy-project",
31+
location="us-central1",
32+
credentials=None,
33+
)
34+
file_bytes = b"My SSN is 123-45-6789."
35+
file_type = "PLAINTEXT_UTF8"
36+
armor_response = {
37+
"sanitizationResult": {
38+
"filterResults": [
39+
{
40+
"sdpFilterResult": {
41+
"inspectResult": {
42+
"executionState": "EXECUTION_SUCCESS",
43+
"matchState": "MATCH_FOUND",
44+
"findings": [
45+
{"infoType": "US_SOCIAL_SECURITY_NUMBER", "likelihood": "LIKELY"}
46+
]
47+
},
48+
"deidentifyResult": {
49+
"executionState": "EXECUTION_SUCCESS",
50+
"matchState": "MATCH_FOUND",
51+
"data": {"text": "My SSN is [REDACTED]."}
52+
}
53+
}
54+
}
55+
]
56+
}
57+
}
58+
class MockResponse:
59+
def __init__(self, status_code, text, json_data):
60+
self.status_code = status_code
61+
self.text = text
62+
self._json = json_data
63+
def json(self):
64+
return self._json
65+
class MockHandler:
66+
async def post(self, url, json, headers):
67+
return MockResponse(200, str(armor_response), armor_response)
68+
guardrail.async_handler = MockHandler()
69+
guardrail._ensure_access_token_async = AsyncMock(return_value=("dummy-token", "dummy-project"))
70+
result = await guardrail.make_model_armor_request(
71+
file_bytes=file_bytes,
72+
file_type=file_type,
73+
source="user_prompt"
74+
)
75+
assert result["sanitizationResult"]["filterResults"][0]["sdpFilterResult"]["deidentifyResult"]["data"]["text"] == "My SSN is [REDACTED]."
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import sys
2+
import os
3+
import pytest
4+
from unittest.mock import AsyncMock, patch
5+
from fastapi import HTTPException
6+
7+
sys.path.insert(0, os.path.abspath("../.."))
8+
9+
from litellm.proxy.guardrails.guardrail_hooks.model_armor.model_armor import ModelArmorGuardrail
10+
from litellm.proxy._types import UserAPIKeyAuth
11+
from litellm.caching.caching import DualCache
12+
13+
@pytest.mark.asyncio
14+
async def test_model_armor_pre_call_hook_inspect_and_deidentify():
15+
"""
16+
Test Model Armor guardrail pre-call hook for both inspectResult and deidentifyResult handling.
17+
"""
18+
guardrail = ModelArmorGuardrail(
19+
template_id="dummy-template",
20+
project_id="dummy-project",
21+
location="us-central1",
22+
credentials=None,
23+
)
24+
armor_response = {
25+
"sanitizationResult": {
26+
"filterResults": [
27+
{
28+
"sdpFilterResult": {
29+
"inspectResult": {
30+
"executionState": "EXECUTION_SUCCESS",
31+
"matchState": "NO_MATCH_FOUND",
32+
"findings": []
33+
},
34+
"deidentifyResult": {
35+
"executionState": "EXECUTION_SUCCESS",
36+
"matchState": "MATCH_FOUND",
37+
"data": {"text": "sanitized text here"}
38+
}
39+
}
40+
}
41+
]
42+
}
43+
}
44+
with patch.object(guardrail, "make_model_armor_request", AsyncMock(return_value=armor_response)):
45+
user_api_key_dict = UserAPIKeyAuth(api_key="test_key")
46+
cache = DualCache()
47+
data = {
48+
"messages": [
49+
{"role": "system", "content": "You are a helpful assistant."},
50+
{"role": "user", "content": "My SSN is 123-45-6789."}
51+
],
52+
"model": "gpt-3.5-turbo",
53+
"metadata": {}
54+
}
55+
guardrail.mask_request_content = True
56+
with pytest.raises(HTTPException) as exc_info:
57+
await guardrail.async_pre_call_hook(
58+
user_api_key_dict=user_api_key_dict,
59+
cache=cache,
60+
data=data,
61+
call_type="completion"
62+
)
63+
assert exc_info.value.status_code == 400
64+
assert "Content blocked by Model Armor" in str(exc_info.value.detail)
65+
66+
def test_model_armor_should_block_content():
67+
guardrail = ModelArmorGuardrail(
68+
template_id="dummy-template",
69+
project_id="dummy-project",
70+
location="us-central1",
71+
credentials=None,
72+
)
73+
# Block on inspectResult
74+
armor_response_inspect = {
75+
"sanitizationResult": {
76+
"filterResults": [
77+
{"sdpFilterResult": {"inspectResult": {"matchState": "MATCH_FOUND"}}}
78+
]
79+
}
80+
}
81+
assert guardrail._should_block_content(armor_response_inspect)
82+
# Block on deidentifyResult
83+
armor_response_deidentify = {
84+
"sanitizationResult": {
85+
"filterResults": [
86+
{"sdpFilterResult": {"deidentifyResult": {"matchState": "MATCH_FOUND"}}}
87+
]
88+
}
89+
}
90+
assert guardrail._should_block_content(armor_response_deidentify)
91+
# No block if neither
92+
armor_response_none = {
93+
"sanitizationResult": {
94+
"filterResults": [
95+
{"sdpFilterResult": {"inspectResult": {"matchState": "NO_MATCH_FOUND"}, "deidentifyResult": {"matchState": "NO_MATCH_FOUND"}}}
96+
]
97+
}
98+
}
99+
assert not guardrail._should_block_content(armor_response_none)

0 commit comments

Comments
 (0)