@@ -88,11 +88,11 @@ def _get_api_endpoint(self) -> str:
    def _create_sanitize_request(
        self, content: str, source: Literal["user_prompt", "model_response"]
    ) -> dict:
-        """Create request body for Model Armor API with correct camelCase field names."""
+        """Create request body for Model Armor API."""
        if source == "user_prompt":
-            return {"userPromptData": {"text": content}}
+            return {"user_prompt_data": {"text": content}}
        else:
-            return {"modelResponseData": {"text": content}}
+            return {"model_response_data": {"text": content}}

    def _extract_content_from_response(
        self, response: Union[Any, ModelResponse]
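For reference, a standalone sketch of the request bodies the updated _create_sanitize_request now produces. The function below simply mirrors the diffed method outside its class so it can run on its own; it is an illustration, not part of the change itself.

from typing import Literal


def create_sanitize_request(
    content: str, source: Literal["user_prompt", "model_response"]
) -> dict:
    # Mirrors the diff: snake_case keys in the JSON body sent to Model Armor.
    if source == "user_prompt":
        return {"user_prompt_data": {"text": content}}
    return {"model_response_data": {"text": content}}


assert create_sanitize_request("hello", "user_prompt") == {
    "user_prompt_data": {"text": "hello"}
}
assert create_sanitize_request("hello", "model_response") == {
    "model_response_data": {"text": "hello"}
}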
@@ -119,16 +119,11 @@ def _extract_content_from_response(
    async def make_model_armor_request(
        self,
-        content: Optional[str] = None,
-        source: Literal["user_prompt", "model_response"] = "user_prompt",
+        content: str,
+        source: Literal["user_prompt", "model_response"],
        request_data: Optional[dict] = None,
-        file_bytes: Optional[bytes] = None,
-        file_type: Optional[str] = None,
    ) -> dict:
-        """
-        Make request to Model Armor API. Supports both text and file prompt sanitization.
-        If file_bytes and file_type are provided, file prompt sanitization is performed.
-        """
+        """Make request to Model Armor API."""
        # Get access token using VertexBase auth
        access_token, resolved_project_id = await self._ensure_access_token_async(
            credentials=self.credentials,
@@ -148,14 +143,7 @@ async def make_model_armor_request(
            url = f"{endpoint}/v1/projects/{self.project_id}/locations/{self.location}/templates/{self.template_id}:sanitizeModelResponse"

        # Create request body
-        if file_bytes is not None and file_type is not None:
-            body = self.sanitize_file_prompt(file_bytes, file_type, source)
-        elif content is not None:
-            body = self._create_sanitize_request(content, source)
-        else:
-            raise ValueError(
-                "Either content or file_bytes and file_type must be provided."
-            )
+        body = self._create_sanitize_request(content, source)

        # Set headers
        headers = {
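As a side note, a minimal sketch of how the sanitize URL above is assembled. The endpoint, project, location, and template values here are placeholders for illustration, not values from the source.

# Placeholder values, for illustration only.
endpoint = "https://modelarmor.us-central1.rep.googleapis.com"
project_id = "my-project"
location = "us-central1"
template_id = "my-template"

# Mirrors the f-string in the diff for the sanitizeModelResponse branch.
url = (
    f"{endpoint}/v1/projects/{project_id}/locations/{location}"
    f"/templates/{template_id}:sanitizeModelResponse"
)
print(url)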
@@ -201,110 +189,57 @@ async def make_model_armor_request(
            return await json_response
        return json_response

-    def sanitize_file_prompt(
-        self, file_bytes: bytes, file_type: str, source: str = "user_prompt"
-    ) -> dict:
-        """
-        Helper to build the request body for file prompt sanitization for Model Armor.
-        file_type should be one of: PLAINTEXT_UTF8, PDF, WORD_DOCUMENT, EXCEL_DOCUMENT, POWERPOINT_DOCUMENT, TXT, CSV
-        Returns the request body dict.
-        """
-        import base64
-
-        base64_data = base64.b64encode(file_bytes).decode("utf-8")
-        if source == "user_prompt":
-            return {
-                "userPromptData": {
-                    "byteItem": {"byteDataType": file_type, "byteData": base64_data}
-                }
-            }
-        else:
-            return {
-                "modelResponseData": {
-                    "byteItem": {"byteDataType": file_type, "byteData": base64_data}
-                }
-            }
-
    def _should_block_content(self, armor_response: dict) -> bool:
-        """Check if Model Armor response indicates content should be blocked, including both inspectResult and deidentifyResult."""
+        """Check if Model Armor response indicates content should be blocked."""
+        # Check the sanitizationResult from Model Armor API
        sanitization_result = armor_response.get("sanitizationResult", {})
        filter_results = sanitization_result.get("filterResults", {})

-        # filterResults can be a dict (named keys) or a list (array of filter result dicts)
-        filter_result_items = []
-        if isinstance(filter_results, dict):
-            filter_result_items = [filter_results]
-        elif isinstance(filter_results, list):
-            filter_result_items = filter_results
-
-        for filt in filter_result_items:
-            # Check RAI, PI/Jailbreak, Malicious URI, CSAM, Virus scan as before
-            if filt.get("raiFilterResult", {}).get("matchState") == "MATCH_FOUND":
-                return True
-            if (
-                filt.get("piAndJailbreakFilterResult", {}).get("matchState")
-                == "MATCH_FOUND"
-            ):
-                return True
-            if (
-                filt.get("maliciousUriFilterResult", {}).get("matchState")
-                == "MATCH_FOUND"
-            ):
-                return True
-            if (
-                filt.get("csamFilterFilterResult", {}).get("matchState")
-                == "MATCH_FOUND"
-            ):
-                return True
-            if filt.get("virusScanFilterResult", {}).get("matchState") == "MATCH_FOUND":
-                return True
-            # Check sdpFilterResult for both inspectResult and deidentifyResult
-            sdp = filt.get("sdpFilterResult")
-            if sdp:
-                if sdp.get("inspectResult", {}).get("matchState") == "MATCH_FOUND":
-                    return True
-                if sdp.get("deidentifyResult", {}).get("matchState") == "MATCH_FOUND":
-                    return True
-        # Fallback dict code removed; all cases handled above
+        # Check blocking filters (these should cause the request to be blocked)
+        # RAI (Responsible AI) filters
+        rai_results = filter_results.get("rai", {}).get("raiFilterResult", {})
+        if rai_results.get("matchState") == "MATCH_FOUND":
+            return True
+
+        # Prompt injection and jailbreak filters
+        pi_jailbreak = filter_results.get("piAndJailbreakFilterResult", {})
+        if pi_jailbreak.get("matchState") == "MATCH_FOUND":
+            return True
+
+        # Malicious URI filters
+        malicious_uri = filter_results.get("maliciousUriFilterResult", {})
+        if malicious_uri.get("matchState") == "MATCH_FOUND":
+            return True
+
+        # CSAM filters
+        csam = filter_results.get("csamFilterFilterResult", {})
+        if csam.get("matchState") == "MATCH_FOUND":
+            return True
+
+        # Virus scan filters
+        virus_scan = filter_results.get("virusScanFilterResult", {})
+        if virus_scan.get("matchState") == "MATCH_FOUND":
+            return True
+
        return False

    def _get_sanitized_content(self, armor_response: dict) -> Optional[str]:
-        """
-        Get the sanitized content from a Model Armor response, if available.
-        Looks for sanitized text in deidentifyResult, and falls back to root-level fields if not found.
-        """
-        result = armor_response.get("sanitizationResult", {})
-        filter_results = result.get("filterResults", {})
-
-        # filterResults can be a dict (single filter) or a list (multiple filters)
-        filters = (
-            [filter_results]
-            if isinstance(filter_results, dict)
-            else filter_results
-            if isinstance(filter_results, list)
-            else []
-        )
+        """Extract sanitized content from Model Armor response."""
+        # Model Armor returns sanitized content in the sanitizationResult
+        sanitization_result = armor_response.get("sanitizationResult", {})
+
+        # Check for sdp structure (for deidentification)
+        filter_results = sanitization_result.get("filterResults", {})
+        sdp = filter_results.get("sdp", {}).get("sdpFilterResult")
+
+        if sdp is not None:
+            # Model Armor returns sanitized text under deidentifyResult in sdp
+            deidentify_result = sdp.get("deidentifyResult", {})
+            sanitized_text = deidentify_result.get("data", {}).get("text", "")
+            if deidentify_result.get("matchState") == "MATCH_FOUND" and sanitized_text:
+                return sanitized_text

-        # Prefer sanitized text from deidentifyResult if present
-        for filter_entry in filters:
-            sdp = filter_entry.get("sdpFilterResult")
-            if sdp:
-                deid = sdp.get("deidentifyResult", {})
-                sanitized = deid.get("data", {}).get("text", "")
-                # If Model Armor found something and returned a sanitized version, use it
-                if deid.get("matchState") == "MATCH_FOUND" and sanitized:
-                    return sanitized
-
-        # If no deidentifyResult, optionally check for inspectResult (rare, but could have findings)
-        for filter_entry in filters:
-            sdp = filter_entry.get("sdpFilterResult")
-            if sdp:
-                inspect = sdp.get("inspectResult", {})
-                # If Model Armor flagged something but didn't sanitize, return None
-                if inspect.get("matchState") == "MATCH_FOUND":
-                    return None
-
-        # Fallback: if Model Armor put sanitized text at the root, use it
+        # Fallback to checking root level
        return armor_response.get("sanitizedText") or armor_response.get("text")

    def _process_response(
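To make the parsing logic above concrete, here is a small self-contained sketch that walks a sample response through the same lookups as _should_block_content and _get_sanitized_content. The response shape is assumed from the field names in this diff, not taken from the Model Armor documentation.

# Sample response shape, assumed from the field names used in the diff above.
sample_response = {
    "sanitizationResult": {
        "filterResults": {
            "rai": {"raiFilterResult": {"matchState": "NO_MATCH_FOUND"}},
            "sdp": {
                "sdpFilterResult": {
                    "deidentifyResult": {
                        "matchState": "MATCH_FOUND",
                        "data": {"text": "My SSN is [REDACTED]"},
                    }
                }
            },
        }
    }
}

filter_results = sample_response["sanitizationResult"]["filterResults"]

# _should_block_content-style lookup for one of the blocking filters (RAI).
blocked = (
    filter_results.get("rai", {}).get("raiFilterResult", {}).get("matchState")
    == "MATCH_FOUND"
)

# _get_sanitized_content-style lookup for de-identified text under sdp.
deidentify_result = (
    filter_results.get("sdp", {}).get("sdpFilterResult", {}).get("deidentifyResult", {})
)
sanitized = (
    deidentify_result.get("data", {}).get("text")
    if deidentify_result.get("matchState") == "MATCH_FOUND"
    else None
)

print(blocked)    # False
print(sanitized)  # My SSN is [REDACTED]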