Skip to content

Commit 4142ddc

Browse files
committed
Revert "Ensure basic detection format proper"
This reverts commit 5070f5d, which was intended for a different branch and was accidentally pulled in.
1 parent f266fa2 commit 4142ddc

File tree

3 files changed

+50
-115
lines changed

3 files changed

+50
-115
lines changed

git_model_armor.py

Whitespace-only changes.

litellm/proxy/guardrails/guardrail_hooks/model_armor/model_armor.py

Lines changed: 50 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,11 @@ def _get_api_endpoint(self) -> str:
8888
def _create_sanitize_request(
8989
self, content: str, source: Literal["user_prompt", "model_response"]
9090
) -> dict:
91-
"""Create request body for Model Armor API with correct camelCase field names."""
91+
"""Create request body for Model Armor API."""
9292
if source == "user_prompt":
93-
return {"userPromptData": {"text": content}}
93+
return {"user_prompt_data": {"text": content}}
9494
else:
95-
return {"modelResponseData": {"text": content}}
95+
return {"model_response_data": {"text": content}}
9696

9797
def _extract_content_from_response(
9898
self, response: Union[Any, ModelResponse]
@@ -119,16 +119,11 @@ def _extract_content_from_response(
119119

120120
async def make_model_armor_request(
121121
self,
122-
content: Optional[str] = None,
123-
source: Literal["user_prompt", "model_response"] = "user_prompt",
122+
content: str,
123+
source: Literal["user_prompt", "model_response"],
124124
request_data: Optional[dict] = None,
125-
file_bytes: Optional[bytes] = None,
126-
file_type: Optional[str] = None,
127125
) -> dict:
128-
"""
129-
Make request to Model Armor API. Supports both text and file prompt sanitization.
130-
If file_bytes and file_type are provided, file prompt sanitization is performed.
131-
"""
126+
"""Make request to Model Armor API."""
132127
# Get access token using VertexBase auth
133128
access_token, resolved_project_id = await self._ensure_access_token_async(
134129
credentials=self.credentials,
@@ -148,14 +143,7 @@ async def make_model_armor_request(
148143
url = f"{endpoint}/v1/projects/{self.project_id}/locations/{self.location}/templates/{self.template_id}:sanitizeModelResponse"
149144

150145
# Create request body
151-
if file_bytes is not None and file_type is not None:
152-
body = self.sanitize_file_prompt(file_bytes, file_type, source)
153-
elif content is not None:
154-
body = self._create_sanitize_request(content, source)
155-
else:
156-
raise ValueError(
157-
"Either content or file_bytes and file_type must be provided."
158-
)
146+
body = self._create_sanitize_request(content, source)
159147

160148
# Set headers
161149
headers = {
@@ -201,110 +189,57 @@ async def make_model_armor_request(
201189
return await json_response
202190
return json_response
203191

204-
def sanitize_file_prompt(
205-
self, file_bytes: bytes, file_type: str, source: str = "user_prompt"
206-
) -> dict:
207-
"""
208-
Helper to build the request body for file prompt sanitization for Model Armor.
209-
file_type should be one of: PLAINTEXT_UTF8, PDF, WORD_DOCUMENT, EXCEL_DOCUMENT, POWERPOINT_DOCUMENT, TXT, CSV
210-
Returns the request body dict.
211-
"""
212-
import base64
213-
214-
base64_data = base64.b64encode(file_bytes).decode("utf-8")
215-
if source == "user_prompt":
216-
return {
217-
"userPromptData": {
218-
"byteItem": {"byteDataType": file_type, "byteData": base64_data}
219-
}
220-
}
221-
else:
222-
return {
223-
"modelResponseData": {
224-
"byteItem": {"byteDataType": file_type, "byteData": base64_data}
225-
}
226-
}
227-
228192
def _should_block_content(self, armor_response: dict) -> bool:
229-
"""Check if Model Armor response indicates content should be blocked, including both inspectResult and deidentifyResult."""
193+
"""Check if Model Armor response indicates content should be blocked."""
194+
# Check the sanitizationResult from Model Armor API
230195
sanitization_result = armor_response.get("sanitizationResult", {})
231196
filter_results = sanitization_result.get("filterResults", {})
232197

233-
# filterResults can be a dict (named keys) or a list (array of filter result dicts)
234-
filter_result_items = []
235-
if isinstance(filter_results, dict):
236-
filter_result_items = [filter_results]
237-
elif isinstance(filter_results, list):
238-
filter_result_items = filter_results
239-
240-
for filt in filter_result_items:
241-
# Check RAI, PI/Jailbreak, Malicious URI, CSAM, Virus scan as before
242-
if filt.get("raiFilterResult", {}).get("matchState") == "MATCH_FOUND":
243-
return True
244-
if (
245-
filt.get("piAndJailbreakFilterResult", {}).get("matchState")
246-
== "MATCH_FOUND"
247-
):
248-
return True
249-
if (
250-
filt.get("maliciousUriFilterResult", {}).get("matchState")
251-
== "MATCH_FOUND"
252-
):
253-
return True
254-
if (
255-
filt.get("csamFilterFilterResult", {}).get("matchState")
256-
== "MATCH_FOUND"
257-
):
258-
return True
259-
if filt.get("virusScanFilterResult", {}).get("matchState") == "MATCH_FOUND":
260-
return True
261-
# Check sdpFilterResult for both inspectResult and deidentifyResult
262-
sdp = filt.get("sdpFilterResult")
263-
if sdp:
264-
if sdp.get("inspectResult", {}).get("matchState") == "MATCH_FOUND":
265-
return True
266-
if sdp.get("deidentifyResult", {}).get("matchState") == "MATCH_FOUND":
267-
return True
268-
# Fallback dict code removed; all cases handled above
198+
# Check blocking filters (these should cause the request to be blocked)
199+
# RAI (Responsible AI) filters
200+
rai_results = filter_results.get("rai", {}).get("raiFilterResult", {})
201+
if rai_results.get("matchState") == "MATCH_FOUND":
202+
return True
203+
204+
# Prompt injection and jailbreak filters
205+
pi_jailbreak = filter_results.get("piAndJailbreakFilterResult", {})
206+
if pi_jailbreak.get("matchState") == "MATCH_FOUND":
207+
return True
208+
209+
# Malicious URI filters
210+
malicious_uri = filter_results.get("maliciousUriFilterResult", {})
211+
if malicious_uri.get("matchState") == "MATCH_FOUND":
212+
return True
213+
214+
# CSAM filters
215+
csam = filter_results.get("csamFilterFilterResult", {})
216+
if csam.get("matchState") == "MATCH_FOUND":
217+
return True
218+
219+
# Virus scan filters
220+
virus_scan = filter_results.get("virusScanFilterResult", {})
221+
if virus_scan.get("matchState") == "MATCH_FOUND":
222+
return True
223+
269224
return False
270225

271226
def _get_sanitized_content(self, armor_response: dict) -> Optional[str]:
272-
"""
273-
Get the sanitized content from a Model Armor response, if available.
274-
Looks for sanitized text in deidentifyResult, and falls back to root-level fields if not found.
275-
"""
276-
result = armor_response.get("sanitizationResult", {})
277-
filter_results = result.get("filterResults", {})
278-
279-
# filterResults can be a dict (single filter) or a list (multiple filters)
280-
filters = (
281-
[filter_results]
282-
if isinstance(filter_results, dict)
283-
else filter_results
284-
if isinstance(filter_results, list)
285-
else []
286-
)
227+
"""Extract sanitized content from Model Armor response."""
228+
# Model Armor returns sanitized content in the sanitizationResult
229+
sanitization_result = armor_response.get("sanitizationResult", {})
230+
231+
# Check for sdp structure (for deidentification)
232+
filter_results = sanitization_result.get("filterResults", {})
233+
sdp = filter_results.get("sdp", {}).get("sdpFilterResult")
234+
235+
if sdp is not None:
236+
# Model Armor returns sanitized text under deidentifyResult in sdp
237+
deidentify_result = sdp.get("deidentifyResult", {})
238+
sanitized_text = deidentify_result.get("data", {}).get("text", "")
239+
if deidentify_result.get("matchState") == "MATCH_FOUND" and sanitized_text:
240+
return sanitized_text
287241

288-
# Prefer sanitized text from deidentifyResult if present
289-
for filter_entry in filters:
290-
sdp = filter_entry.get("sdpFilterResult")
291-
if sdp:
292-
deid = sdp.get("deidentifyResult", {})
293-
sanitized = deid.get("data", {}).get("text", "")
294-
# If Model Armor found something and returned a sanitized version, use it
295-
if deid.get("matchState") == "MATCH_FOUND" and sanitized:
296-
return sanitized
297-
298-
# If no deidentifyResult, optionally check for inspectResult (rare, but could have findings)
299-
for filter_entry in filters:
300-
sdp = filter_entry.get("sdpFilterResult")
301-
if sdp:
302-
inspect = sdp.get("inspectResult", {})
303-
# If Model Armor flagged something but didn't sanitize, return None
304-
if inspect.get("matchState") == "MATCH_FOUND":
305-
return None
306-
307-
# Fallback: if Model Armor put sanitized text at the root, use it
242+
# Fallback to checking root level
308243
return armor_response.get("sanitizedText") or armor_response.get("text")
309244

310245
def _process_response(

test_model_armor.py

Whitespace-only changes.

0 commit comments

Comments
 (0)