Skip to content

Commit 197e98b

Browse files
Merge pull request #14758 from TeddyAmkie/model-armor-restore
Resolve Model Armor gaps (PDFs, basic
2 parents 8628c26 + 9d18780 commit 197e98b

File tree

3 files changed

+289
-50
lines changed

3 files changed

+289
-50
lines changed

litellm/proxy/guardrails/guardrail_hooks/model_armor/model_armor.py

Lines changed: 115 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,11 @@ def _get_api_endpoint(self) -> str:
8888
def _create_sanitize_request(
8989
self, content: str, source: Literal["user_prompt", "model_response"]
9090
) -> dict:
91-
"""Create request body for Model Armor API."""
91+
"""Create request body for Model Armor API with correct camelCase field names."""
9292
if source == "user_prompt":
93-
return {"user_prompt_data": {"text": content}}
93+
return {"userPromptData": {"text": content}}
9494
else:
95-
return {"model_response_data": {"text": content}}
95+
return {"modelResponseData": {"text": content}}
9696

9797
def _extract_content_from_response(
9898
self, response: Union[Any, ModelResponse]
@@ -119,11 +119,16 @@ def _extract_content_from_response(
119119

120120
async def make_model_armor_request(
121121
self,
122-
content: str,
123-
source: Literal["user_prompt", "model_response"],
122+
content: Optional[str] = None,
123+
source: Literal["user_prompt", "model_response"] = "user_prompt",
124124
request_data: Optional[dict] = None,
125+
file_bytes: Optional[bytes] = None,
126+
file_type: Optional[str] = None,
125127
) -> dict:
126-
"""Make request to Model Armor API."""
128+
"""
129+
Make request to Model Armor API. Supports both text and file prompt sanitization.
130+
If file_bytes and file_type are provided, file prompt sanitization is performed.
131+
"""
127132
# Get access token using VertexBase auth
128133
access_token, resolved_project_id = await self._ensure_access_token_async(
129134
credentials=self.credentials,
@@ -143,7 +148,14 @@ async def make_model_armor_request(
143148
url = f"{endpoint}/v1/projects/{self.project_id}/locations/{self.location}/templates/{self.template_id}:sanitizeModelResponse"
144149

145150
# Create request body
146-
body = self._create_sanitize_request(content, source)
151+
if file_bytes is not None and file_type is not None:
152+
body = self.sanitize_file_prompt(file_bytes, file_type, source)
153+
elif content is not None:
154+
body = self._create_sanitize_request(content, source)
155+
else:
156+
raise ValueError(
157+
"Either content or file_bytes and file_type must be provided."
158+
)
147159

148160
# Set headers
149161
headers = {
@@ -189,57 +201,110 @@ async def make_model_armor_request(
189201
return await json_response
190202
return json_response
191203

204+
def sanitize_file_prompt(
    self, file_bytes: bytes, file_type: str, source: str = "user_prompt"
) -> dict:
    """
    Helper to build the request body for file prompt sanitization for Model Armor.

    file_type should be one of: PLAINTEXT_UTF8, PDF, WORD_DOCUMENT,
    EXCEL_DOCUMENT, POWERPOINT_DOCUMENT, TXT, CSV.
    Returns the request body dict.
    """
    import base64

    # Model Armor expects file content as base64-encoded text.
    encoded = base64.b64encode(file_bytes).decode("utf-8")
    data_key = "userPromptData" if source == "user_prompt" else "modelResponseData"
    return {
        data_key: {
            "byteItem": {"byteDataType": file_type, "byteData": encoded}
        }
    }
227+
192228
def _should_block_content(self, armor_response: dict) -> bool:
193-
"""Check if Model Armor response indicates content should be blocked."""
194-
# Check the sanitizationResult from Model Armor API
229+
"""Check if Model Armor response indicates content should be blocked, including both inspectResult and deidentifyResult."""
195230
sanitization_result = armor_response.get("sanitizationResult", {})
196231
filter_results = sanitization_result.get("filterResults", {})
197232

198-
# Check blocking filters (these should cause the request to be blocked)
199-
# RAI (Responsible AI) filters
200-
rai_results = filter_results.get("rai", {}).get("raiFilterResult", {})
201-
if rai_results.get("matchState") == "MATCH_FOUND":
202-
return True
203-
204-
# Prompt injection and jailbreak filters
205-
pi_jailbreak = filter_results.get("piAndJailbreakFilterResult", {})
206-
if pi_jailbreak.get("matchState") == "MATCH_FOUND":
207-
return True
208-
209-
# Malicious URI filters
210-
malicious_uri = filter_results.get("maliciousUriFilterResult", {})
211-
if malicious_uri.get("matchState") == "MATCH_FOUND":
212-
return True
213-
214-
# CSAM filters
215-
csam = filter_results.get("csamFilterFilterResult", {})
216-
if csam.get("matchState") == "MATCH_FOUND":
217-
return True
218-
219-
# Virus scan filters
220-
virus_scan = filter_results.get("virusScanFilterResult", {})
221-
if virus_scan.get("matchState") == "MATCH_FOUND":
222-
return True
223-
233+
# filterResults can be a dict (named keys) or a list (array of filter result dicts)
234+
filter_result_items = []
235+
if isinstance(filter_results, dict):
236+
filter_result_items = [filter_results]
237+
elif isinstance(filter_results, list):
238+
filter_result_items = filter_results
239+
240+
for filt in filter_result_items:
241+
# Check RAI, PI/Jailbreak, Malicious URI, CSAM, Virus scan as before
242+
if filt.get("raiFilterResult", {}).get("matchState") == "MATCH_FOUND":
243+
return True
244+
if (
245+
filt.get("piAndJailbreakFilterResult", {}).get("matchState")
246+
== "MATCH_FOUND"
247+
):
248+
return True
249+
if (
250+
filt.get("maliciousUriFilterResult", {}).get("matchState")
251+
== "MATCH_FOUND"
252+
):
253+
return True
254+
if (
255+
filt.get("csamFilterFilterResult", {}).get("matchState")
256+
== "MATCH_FOUND"
257+
):
258+
return True
259+
if filt.get("virusScanFilterResult", {}).get("matchState") == "MATCH_FOUND":
260+
return True
261+
# Check sdpFilterResult for both inspectResult and deidentifyResult
262+
sdp = filt.get("sdpFilterResult")
263+
if sdp:
264+
if sdp.get("inspectResult", {}).get("matchState") == "MATCH_FOUND":
265+
return True
266+
if sdp.get("deidentifyResult", {}).get("matchState") == "MATCH_FOUND":
267+
return True
268+
# Fallback dict code removed; all cases handled above
224269
return False
225270

226271
def _get_sanitized_content(self, armor_response: dict) -> Optional[str]:
227-
"""Extract sanitized content from Model Armor response."""
228-
# Model Armor returns sanitized content in the sanitizationResult
229-
sanitization_result = armor_response.get("sanitizationResult", {})
230-
231-
# Check for sdp structure (for deidentification)
232-
filter_results = sanitization_result.get("filterResults", {})
233-
sdp = filter_results.get("sdp", {}).get("sdpFilterResult")
234-
235-
if sdp is not None:
236-
# Model Armor returns sanitized text under deidentifyResult in sdp
237-
deidentify_result = sdp.get("deidentifyResult", {})
238-
sanitized_text = deidentify_result.get("data", {}).get("text", "")
239-
if deidentify_result.get("matchState") == "MATCH_FOUND" and sanitized_text:
240-
return sanitized_text
272+
"""
273+
Get the sanitized content from a Model Armor response, if available.
274+
Looks for sanitized text in deidentifyResult, and falls back to root-level fields if not found.
275+
"""
276+
result = armor_response.get("sanitizationResult", {})
277+
filter_results = result.get("filterResults", {})
278+
279+
# filterResults can be a dict (single filter) or a list (multiple filters)
280+
filters = (
281+
[filter_results]
282+
if isinstance(filter_results, dict)
283+
else filter_results
284+
if isinstance(filter_results, list)
285+
else []
286+
)
241287

242-
# Fallback to checking root level
288+
# Prefer sanitized text from deidentifyResult if present
289+
for filter_entry in filters:
290+
sdp = filter_entry.get("sdpFilterResult")
291+
if sdp:
292+
deid = sdp.get("deidentifyResult", {})
293+
sanitized = deid.get("data", {}).get("text", "")
294+
# If Model Armor found something and returned a sanitized version, use it
295+
if deid.get("matchState") == "MATCH_FOUND" and sanitized:
296+
return sanitized
297+
298+
# If no deidentifyResult, optionally check for inspectResult (rare, but could have findings)
299+
for filter_entry in filters:
300+
sdp = filter_entry.get("sdpFilterResult")
301+
if sdp:
302+
inspect = sdp.get("inspectResult", {})
303+
# If Model Armor flagged something but didn't sanitize, return None
304+
if inspect.get("matchState") == "MATCH_FOUND":
305+
return None
306+
307+
# Fallback: if Model Armor put sanitized text at the root, use it
243308
return armor_response.get("sanitizedText") or armor_response.get("text")
244309

245310
def _process_response(
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import sys
2+
import os
3+
import pytest
4+
from unittest.mock import AsyncMock
5+
from fastapi import HTTPException
6+
7+
sys.path.insert(0, os.path.abspath("../.."))
8+
9+
from litellm.proxy.guardrails.guardrail_hooks.model_armor.model_armor import ModelArmorGuardrail
10+
11+
def test_sanitize_file_prompt_builds_pdf_body():
    """sanitize_file_prompt must wrap PDF bytes as a base64 byteItem under userPromptData."""
    import base64

    armor = ModelArmorGuardrail(
        template_id="dummy-template",
        project_id="dummy-project",
        location="us-central1",
        credentials=None,
    )
    pdf_bytes = b"%PDF-1.4 some pdf content"

    body = armor.sanitize_file_prompt(pdf_bytes, "PDF", source="user_prompt")

    assert "userPromptData" in body
    byte_item = body["userPromptData"]["byteItem"]
    assert byte_item["byteDataType"] == "PDF"
    assert byte_item["byteData"] == base64.b64encode(pdf_bytes).decode("utf-8")
25+
26+
@pytest.mark.asyncio
async def test_make_model_armor_request_file_prompt():
    """make_model_armor_request should accept file bytes and return the armor response unchanged."""
    armor = ModelArmorGuardrail(
        template_id="dummy-template",
        project_id="dummy-project",
        location="us-central1",
        credentials=None,
    )
    file_bytes = b"My SSN is 123-45-6789."
    expected_response = {
        "sanitizationResult": {
            "filterResults": [
                {
                    "sdpFilterResult": {
                        "inspectResult": {
                            "executionState": "EXECUTION_SUCCESS",
                            "matchState": "MATCH_FOUND",
                            "findings": [
                                {
                                    "infoType": "US_SOCIAL_SECURITY_NUMBER",
                                    "likelihood": "LIKELY",
                                }
                            ],
                        },
                        "deidentifyResult": {
                            "executionState": "EXECUTION_SUCCESS",
                            "matchState": "MATCH_FOUND",
                            "data": {"text": "My SSN is [REDACTED]."},
                        },
                    }
                }
            ]
        }
    }

    class _StubResponse:
        """Minimal stand-in for the HTTP response object the guardrail reads."""

        def __init__(self, status_code, text, json_data):
            self.status_code = status_code
            self.text = text
            self._json = json_data

        def json(self):
            return self._json

    class _StubHandler:
        """Async HTTP handler stub returning a canned Model Armor response."""

        async def post(self, url, json, headers):
            return _StubResponse(200, str(expected_response), expected_response)

    armor.async_handler = _StubHandler()
    armor._ensure_access_token_async = AsyncMock(
        return_value=("dummy-token", "dummy-project")
    )

    result = await armor.make_model_armor_request(
        file_bytes=file_bytes,
        file_type="PLAINTEXT_UTF8",
        source="user_prompt",
    )

    deid = result["sanitizationResult"]["filterResults"][0]["sdpFilterResult"]["deidentifyResult"]
    assert deid["data"]["text"] == "My SSN is [REDACTED]."
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import sys
2+
import os
3+
import pytest
4+
from unittest.mock import AsyncMock, patch
5+
from fastapi import HTTPException
6+
7+
sys.path.insert(0, os.path.abspath("../.."))
8+
9+
from litellm.proxy.guardrails.guardrail_hooks.model_armor.model_armor import ModelArmorGuardrail
10+
from litellm.proxy._types import UserAPIKeyAuth
11+
from litellm.caching.caching import DualCache
12+
13+
@pytest.mark.asyncio
async def test_model_armor_pre_call_hook_inspect_and_deidentify():
    """
    Test Model Armor guardrail pre-call hook for both inspectResult and deidentifyResult handling.
    """
    armor = ModelArmorGuardrail(
        template_id="dummy-template",
        project_id="dummy-project",
        location="us-central1",
        credentials=None,
    )
    stub_response = {
        "sanitizationResult": {
            "filterResults": [
                {
                    "sdpFilterResult": {
                        "inspectResult": {
                            "executionState": "EXECUTION_SUCCESS",
                            "matchState": "NO_MATCH_FOUND",
                            "findings": [],
                        },
                        "deidentifyResult": {
                            "executionState": "EXECUTION_SUCCESS",
                            "matchState": "MATCH_FOUND",
                            "data": {"text": "sanitized text here"},
                        },
                    }
                }
            ]
        }
    }
    mocked_request = AsyncMock(return_value=stub_response)
    with patch.object(armor, "make_model_armor_request", mocked_request):
        request_data = {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "My SSN is 123-45-6789."},
            ],
            "model": "gpt-3.5-turbo",
            "metadata": {},
        }
        armor.mask_request_content = True
        with pytest.raises(HTTPException) as exc_info:
            await armor.async_pre_call_hook(
                user_api_key_dict=UserAPIKeyAuth(api_key="test_key"),
                cache=DualCache(),
                data=request_data,
                call_type="completion",
            )
        assert exc_info.value.status_code == 400
        assert "Content blocked by Model Armor" in str(exc_info.value.detail)
65+
66+
def test_model_armor_should_block_content():
    """_should_block_content must block on SDP inspect or deidentify matches, and pass otherwise."""
    armor = ModelArmorGuardrail(
        template_id="dummy-template",
        project_id="dummy-project",
        location="us-central1",
        credentials=None,
    )

    def sdp_response(sdp_result):
        # Wrap a single sdpFilterResult in the list-shaped filterResults envelope.
        return {
            "sanitizationResult": {"filterResults": [{"sdpFilterResult": sdp_result}]}
        }

    # Block on inspectResult
    assert armor._should_block_content(
        sdp_response({"inspectResult": {"matchState": "MATCH_FOUND"}})
    )
    # Block on deidentifyResult
    assert armor._should_block_content(
        sdp_response({"deidentifyResult": {"matchState": "MATCH_FOUND"}})
    )
    # No block if neither matched
    assert not armor._should_block_content(
        sdp_response(
            {
                "inspectResult": {"matchState": "NO_MATCH_FOUND"},
                "deidentifyResult": {"matchState": "NO_MATCH_FOUND"},
            }
        )
    )

0 commit comments

Comments
 (0)