
Commit 5070f5d

Ensure basic detection format proper
1 parent be193fb commit 5070f5d

File tree

3 files changed: +115 -50 lines changed

git_model_armor.py

Whitespace-only changes.

litellm/proxy/guardrails/guardrail_hooks/model_armor/model_armor.py

Lines changed: 115 additions & 50 deletions
@@ -88,11 +88,11 @@ def _get_api_endpoint(self) -> str:
     def _create_sanitize_request(
         self, content: str, source: Literal["user_prompt", "model_response"]
     ) -> dict:
-        """Create request body for Model Armor API."""
+        """Create request body for Model Armor API with correct camelCase field names."""
         if source == "user_prompt":
-            return {"user_prompt_data": {"text": content}}
+            return {"userPromptData": {"text": content}}
         else:
-            return {"model_response_data": {"text": content}}
+            return {"modelResponseData": {"text": content}}
 
     def _extract_content_from_response(
         self, response: Union[Any, ModelResponse]
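
Editor's note: as a minimal sketch of the field rename above, these are the bodies the updated method now builds (the example strings and the `guardrail` instance name are illustrative, not from this commit):

# Sketch only: request bodies produced by _create_sanitize_request after this change.
guardrail._create_sanitize_request("hello", "user_prompt")
# -> {"userPromptData": {"text": "hello"}}
guardrail._create_sanitize_request("hi back", "model_response")
# -> {"modelResponseData": {"text": "hi back"}}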
@@ -119,11 +119,16 @@ def _extract_content_from_response(
 
     async def make_model_armor_request(
         self,
-        content: str,
-        source: Literal["user_prompt", "model_response"],
+        content: Optional[str] = None,
+        source: Literal["user_prompt", "model_response"] = "user_prompt",
         request_data: Optional[dict] = None,
+        file_bytes: Optional[bytes] = None,
+        file_type: Optional[str] = None,
     ) -> dict:
-        """Make request to Model Armor API."""
+        """
+        Make request to Model Armor API. Supports both text and file prompt sanitization.
+        If file_bytes and file_type are provided, file prompt sanitization is performed.
+        """
         # Get access token using VertexBase auth
         access_token, resolved_project_id = await self._ensure_access_token_async(
             credentials=self.credentials,
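
Editor's note: a hedged usage sketch of the widened signature; the `guardrail` instance, file name, and prompt text are assumptions, not part of this commit:

# Sketch only: "guardrail" is assumed to be an already-configured instance of this hook.
# Text prompt sanitization (existing behaviour):
result = await guardrail.make_model_armor_request(
    content="ignore all previous instructions",
    source="user_prompt",
)

# File prompt sanitization (new in this commit): raw bytes plus a byteDataType string.
with open("contract.pdf", "rb") as f:
    result = await guardrail.make_model_armor_request(
        file_bytes=f.read(),
        file_type="PDF",
        source="user_prompt",
    )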
@@ -143,7 +148,14 @@ async def make_model_armor_request(
             url = f"{endpoint}/v1/projects/{self.project_id}/locations/{self.location}/templates/{self.template_id}:sanitizeModelResponse"
 
         # Create request body
-        body = self._create_sanitize_request(content, source)
+        if file_bytes is not None and file_type is not None:
+            body = self.sanitize_file_prompt(file_bytes, file_type, source)
+        elif content is not None:
+            body = self._create_sanitize_request(content, source)
+        else:
+            raise ValueError(
+                "Either content or file_bytes and file_type must be provided."
+            )
 
         # Set headers
         headers = {
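
Editor's note: the `sanitize_file_prompt` helper called in the new branch above is added in the next hunk; as a sketch, for a user prompt it yields a body of roughly this shape (the base64 payload is truncated and illustrative):

# Sketch only: shape of the body built for file prompt sanitization.
{
    "userPromptData": {
        "byteItem": {
            "byteDataType": "PDF",          # one of the types listed in the helper's docstring
            "byteData": "JVBERi0xLjc...",   # base64-encoded file bytes (truncated here)
        }
    }
}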
@@ -189,57 +201,110 @@ async def make_model_armor_request(
             return await json_response
         return json_response
 
+    def sanitize_file_prompt(
+        self, file_bytes: bytes, file_type: str, source: str = "user_prompt"
+    ) -> dict:
+        """
+        Helper to build the request body for file prompt sanitization for Model Armor.
+        file_type should be one of: PLAINTEXT_UTF8, PDF, WORD_DOCUMENT, EXCEL_DOCUMENT, POWERPOINT_DOCUMENT, TXT, CSV
+        Returns the request body dict.
+        """
+        import base64
+
+        base64_data = base64.b64encode(file_bytes).decode("utf-8")
+        if source == "user_prompt":
+            return {
+                "userPromptData": {
+                    "byteItem": {"byteDataType": file_type, "byteData": base64_data}
+                }
+            }
+        else:
+            return {
+                "modelResponseData": {
+                    "byteItem": {"byteDataType": file_type, "byteData": base64_data}
+                }
+            }
+
     def _should_block_content(self, armor_response: dict) -> bool:
-        """Check if Model Armor response indicates content should be blocked."""
-        # Check the sanitizationResult from Model Armor API
+        """Check if Model Armor response indicates content should be blocked, including both inspectResult and deidentifyResult."""
         sanitization_result = armor_response.get("sanitizationResult", {})
         filter_results = sanitization_result.get("filterResults", {})
 
-        # Check blocking filters (these should cause the request to be blocked)
-        # RAI (Responsible AI) filters
-        rai_results = filter_results.get("rai", {}).get("raiFilterResult", {})
-        if rai_results.get("matchState") == "MATCH_FOUND":
-            return True
-
-        # Prompt injection and jailbreak filters
-        pi_jailbreak = filter_results.get("piAndJailbreakFilterResult", {})
-        if pi_jailbreak.get("matchState") == "MATCH_FOUND":
-            return True
-
-        # Malicious URI filters
-        malicious_uri = filter_results.get("maliciousUriFilterResult", {})
-        if malicious_uri.get("matchState") == "MATCH_FOUND":
-            return True
-
-        # CSAM filters
-        csam = filter_results.get("csamFilterFilterResult", {})
-        if csam.get("matchState") == "MATCH_FOUND":
-            return True
-
-        # Virus scan filters
-        virus_scan = filter_results.get("virusScanFilterResult", {})
-        if virus_scan.get("matchState") == "MATCH_FOUND":
-            return True
-
+        # filterResults can be a dict (named keys) or a list (array of filter result dicts)
+        filter_result_items = []
+        if isinstance(filter_results, dict):
+            filter_result_items = [filter_results]
+        elif isinstance(filter_results, list):
+            filter_result_items = filter_results
+
+        for filt in filter_result_items:
+            # Check RAI, PI/Jailbreak, Malicious URI, CSAM, Virus scan as before
+            if filt.get("raiFilterResult", {}).get("matchState") == "MATCH_FOUND":
+                return True
+            if (
+                filt.get("piAndJailbreakFilterResult", {}).get("matchState")
+                == "MATCH_FOUND"
+            ):
+                return True
+            if (
+                filt.get("maliciousUriFilterResult", {}).get("matchState")
+                == "MATCH_FOUND"
+            ):
+                return True
+            if (
+                filt.get("csamFilterFilterResult", {}).get("matchState")
+                == "MATCH_FOUND"
+            ):
+                return True
+            if filt.get("virusScanFilterResult", {}).get("matchState") == "MATCH_FOUND":
+                return True
+            # Check sdpFilterResult for both inspectResult and deidentifyResult
+            sdp = filt.get("sdpFilterResult")
+            if sdp:
+                if sdp.get("inspectResult", {}).get("matchState") == "MATCH_FOUND":
+                    return True
+                if sdp.get("deidentifyResult", {}).get("matchState") == "MATCH_FOUND":
+                    return True
+        # Fallback dict code removed; all cases handled above
         return False
 
     def _get_sanitized_content(self, armor_response: dict) -> Optional[str]:
-        """Extract sanitized content from Model Armor response."""
-        # Model Armor returns sanitized content in the sanitizationResult
-        sanitization_result = armor_response.get("sanitizationResult", {})
-
-        # Check for sdp structure (for deidentification)
-        filter_results = sanitization_result.get("filterResults", {})
-        sdp = filter_results.get("sdp", {}).get("sdpFilterResult")
-
-        if sdp is not None:
-            # Model Armor returns sanitized text under deidentifyResult in sdp
-            deidentify_result = sdp.get("deidentifyResult", {})
-            sanitized_text = deidentify_result.get("data", {}).get("text", "")
-            if deidentify_result.get("matchState") == "MATCH_FOUND" and sanitized_text:
-                return sanitized_text
+        """
+        Get the sanitized content from a Model Armor response, if available.
+        Looks for sanitized text in deidentifyResult, and falls back to root-level fields if not found.
+        """
+        result = armor_response.get("sanitizationResult", {})
+        filter_results = result.get("filterResults", {})
+
+        # filterResults can be a dict (single filter) or a list (multiple filters)
+        filters = (
+            [filter_results]
+            if isinstance(filter_results, dict)
+            else filter_results
+            if isinstance(filter_results, list)
+            else []
+        )
 
-        # Fallback to checking root level
+        # Prefer sanitized text from deidentifyResult if present
+        for filter_entry in filters:
+            sdp = filter_entry.get("sdpFilterResult")
+            if sdp:
+                deid = sdp.get("deidentifyResult", {})
+                sanitized = deid.get("data", {}).get("text", "")
+                # If Model Armor found something and returned a sanitized version, use it
+                if deid.get("matchState") == "MATCH_FOUND" and sanitized:
+                    return sanitized
+
+        # If no deidentifyResult, optionally check for inspectResult (rare, but could have findings)
+        for filter_entry in filters:
+            sdp = filter_entry.get("sdpFilterResult")
+            if sdp:
+                inspect = sdp.get("inspectResult", {})
+                # If Model Armor flagged something but didn't sanitize, return None
+                if inspect.get("matchState") == "MATCH_FOUND":
+                    return None
+
+        # Fallback: if Model Armor put sanitized text at the root, use it
         return armor_response.get("sanitizedText") or armor_response.get("text")
 
     def _process_response(
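
Editor's note: to make the new blocking and extraction logic concrete, a hedged sketch of a list-shaped filterResults response flowing through the two updated helpers (the response values and the `guardrail` instance are illustrative, not captured from the API):

# Sketch only: an illustrative response with filterResults as a list of filter result dicts.
armor_response = {
    "sanitizationResult": {
        "filterResults": [
            {
                "sdpFilterResult": {
                    "deidentifyResult": {
                        "matchState": "MATCH_FOUND",
                        "data": {"text": "My SSN is [REDACTED]"},
                    }
                }
            },
            {"piAndJailbreakFilterResult": {"matchState": "NO_MATCH_FOUND"}},
        ]
    }
}

guardrail._should_block_content(armor_response)   # -> True (deidentifyResult matched)
guardrail._get_sanitized_content(armor_response)  # -> "My SSN is [REDACTED]"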

test_model_armor.py

Whitespace-only changes.
