@@ -88,11 +88,11 @@ def _get_api_endpoint(self) -> str:
88
88
def _create_sanitize_request (
89
89
self , content : str , source : Literal ["user_prompt" , "model_response" ]
90
90
) -> dict :
91
- """Create request body for Model Armor API."""
91
+ """Create request body for Model Armor API with correct camelCase field names ."""
92
92
if source == "user_prompt" :
93
- return {"user_prompt_data " : {"text" : content }}
93
+ return {"userPromptData " : {"text" : content }}
94
94
else :
95
- return {"model_response_data " : {"text" : content }}
95
+ return {"modelResponseData " : {"text" : content }}
96
96
97
97
def _extract_content_from_response (
98
98
self , response : Union [Any , ModelResponse ]
@@ -119,11 +119,16 @@ def _extract_content_from_response(
119
119
120
120
async def make_model_armor_request (
121
121
self ,
122
- content : str ,
123
- source : Literal ["user_prompt" , "model_response" ],
122
+ content : Optional [ str ] = None ,
123
+ source : Literal ["user_prompt" , "model_response" ] = "user_prompt" ,
124
124
request_data : Optional [dict ] = None ,
125
+ file_bytes : Optional [bytes ] = None ,
126
+ file_type : Optional [str ] = None ,
125
127
) -> dict :
126
- """Make request to Model Armor API."""
128
+ """
129
+ Make request to Model Armor API. Supports both text and file prompt sanitization.
130
+ If file_bytes and file_type are provided, file prompt sanitization is performed.
131
+ """
127
132
# Get access token using VertexBase auth
128
133
access_token , resolved_project_id = await self ._ensure_access_token_async (
129
134
credentials = self .credentials ,
@@ -143,7 +148,14 @@ async def make_model_armor_request(
143
148
url = f"{ endpoint } /v1/projects/{ self .project_id } /locations/{ self .location } /templates/{ self .template_id } :sanitizeModelResponse"
144
149
145
150
# Create request body
146
- body = self ._create_sanitize_request (content , source )
151
+ if file_bytes is not None and file_type is not None :
152
+ body = self .sanitize_file_prompt (file_bytes , file_type , source )
153
+ elif content is not None :
154
+ body = self ._create_sanitize_request (content , source )
155
+ else :
156
+ raise ValueError (
157
+ "Either content or file_bytes and file_type must be provided."
158
+ )
147
159
148
160
# Set headers
149
161
headers = {
@@ -189,57 +201,110 @@ async def make_model_armor_request(
189
201
return await json_response
190
202
return json_response
191
203
204
def sanitize_file_prompt(
    self, file_bytes: bytes, file_type: str, source: str = "user_prompt"
) -> dict:
    """Build the request body for file-prompt sanitization in Model Armor.

    Args:
        file_bytes: Raw bytes of the file to sanitize.
        file_type: Model Armor byte-data type, one of: PLAINTEXT_UTF8, PDF,
            WORD_DOCUMENT, EXCEL_DOCUMENT, POWERPOINT_DOCUMENT, TXT, CSV.
        source: ``"user_prompt"`` (default) or ``"model_response"``;
            selects the top-level field name the API expects.

    Returns:
        The request body dict with the file contents base64-encoded.
    """
    import base64  # local import keeps this helper self-contained

    encoded = base64.b64encode(file_bytes).decode("utf-8")
    # Both sources share an identical payload; only the wrapper key differs,
    # so select the key instead of duplicating the whole structure.
    field = "userPromptData" if source == "user_prompt" else "modelResponseData"
    return {field: {"byteItem": {"byteDataType": file_type, "byteData": encoded}}}
192
228
def _should_block_content (self , armor_response : dict ) -> bool :
193
- """Check if Model Armor response indicates content should be blocked."""
194
- # Check the sanitizationResult from Model Armor API
229
+ """Check if Model Armor response indicates content should be blocked, including both inspectResult and deidentifyResult."""
195
230
sanitization_result = armor_response .get ("sanitizationResult" , {})
196
231
filter_results = sanitization_result .get ("filterResults" , {})
197
232
198
- # Check blocking filters (these should cause the request to be blocked)
199
- # RAI (Responsible AI) filters
200
- rai_results = filter_results .get ("rai" , {}).get ("raiFilterResult" , {})
201
- if rai_results .get ("matchState" ) == "MATCH_FOUND" :
202
- return True
203
-
204
- # Prompt injection and jailbreak filters
205
- pi_jailbreak = filter_results .get ("piAndJailbreakFilterResult" , {})
206
- if pi_jailbreak .get ("matchState" ) == "MATCH_FOUND" :
207
- return True
208
-
209
- # Malicious URI filters
210
- malicious_uri = filter_results .get ("maliciousUriFilterResult" , {})
211
- if malicious_uri .get ("matchState" ) == "MATCH_FOUND" :
212
- return True
213
-
214
- # CSAM filters
215
- csam = filter_results .get ("csamFilterFilterResult" , {})
216
- if csam .get ("matchState" ) == "MATCH_FOUND" :
217
- return True
218
-
219
- # Virus scan filters
220
- virus_scan = filter_results .get ("virusScanFilterResult" , {})
221
- if virus_scan .get ("matchState" ) == "MATCH_FOUND" :
222
- return True
223
-
233
+ # filterResults can be a dict (named keys) or a list (array of filter result dicts)
234
+ filter_result_items = []
235
+ if isinstance (filter_results , dict ):
236
+ filter_result_items = [filter_results ]
237
+ elif isinstance (filter_results , list ):
238
+ filter_result_items = filter_results
239
+
240
+ for filt in filter_result_items :
241
+ # Check RAI, PI/Jailbreak, Malicious URI, CSAM, Virus scan as before
242
+ if filt .get ("raiFilterResult" , {}).get ("matchState" ) == "MATCH_FOUND" :
243
+ return True
244
+ if (
245
+ filt .get ("piAndJailbreakFilterResult" , {}).get ("matchState" )
246
+ == "MATCH_FOUND"
247
+ ):
248
+ return True
249
+ if (
250
+ filt .get ("maliciousUriFilterResult" , {}).get ("matchState" )
251
+ == "MATCH_FOUND"
252
+ ):
253
+ return True
254
+ if (
255
+ filt .get ("csamFilterFilterResult" , {}).get ("matchState" )
256
+ == "MATCH_FOUND"
257
+ ):
258
+ return True
259
+ if filt .get ("virusScanFilterResult" , {}).get ("matchState" ) == "MATCH_FOUND" :
260
+ return True
261
+ # Check sdpFilterResult for both inspectResult and deidentifyResult
262
+ sdp = filt .get ("sdpFilterResult" )
263
+ if sdp :
264
+ if sdp .get ("inspectResult" , {}).get ("matchState" ) == "MATCH_FOUND" :
265
+ return True
266
+ if sdp .get ("deidentifyResult" , {}).get ("matchState" ) == "MATCH_FOUND" :
267
+ return True
268
+ # Fallback dict code removed; all cases handled above
224
269
return False
225
270
226
271
def _get_sanitized_content (self , armor_response : dict ) -> Optional [str ]:
227
- """Extract sanitized content from Model Armor response."""
228
- # Model Armor returns sanitized content in the sanitizationResult
229
- sanitization_result = armor_response .get ("sanitizationResult" , {})
230
-
231
- # Check for sdp structure (for deidentification)
232
- filter_results = sanitization_result .get ("filterResults" , {})
233
- sdp = filter_results .get ("sdp" , {}).get ("sdpFilterResult" )
234
-
235
- if sdp is not None :
236
- # Model Armor returns sanitized text under deidentifyResult in sdp
237
- deidentify_result = sdp .get ("deidentifyResult" , {})
238
- sanitized_text = deidentify_result .get ("data" , {}).get ("text" , "" )
239
- if deidentify_result .get ("matchState" ) == "MATCH_FOUND" and sanitized_text :
240
- return sanitized_text
272
+ """
273
+ Get the sanitized content from a Model Armor response, if available.
274
+ Looks for sanitized text in deidentifyResult, and falls back to root-level fields if not found.
275
+ """
276
+ result = armor_response .get ("sanitizationResult" , {})
277
+ filter_results = result .get ("filterResults" , {})
278
+
279
+ # filterResults can be a dict (single filter) or a list (multiple filters)
280
+ filters = (
281
+ [filter_results ]
282
+ if isinstance (filter_results , dict )
283
+ else filter_results
284
+ if isinstance (filter_results , list )
285
+ else []
286
+ )
241
287
242
- # Fallback to checking root level
288
+ # Prefer sanitized text from deidentifyResult if present
289
+ for filter_entry in filters :
290
+ sdp = filter_entry .get ("sdpFilterResult" )
291
+ if sdp :
292
+ deid = sdp .get ("deidentifyResult" , {})
293
+ sanitized = deid .get ("data" , {}).get ("text" , "" )
294
+ # If Model Armor found something and returned a sanitized version, use it
295
+ if deid .get ("matchState" ) == "MATCH_FOUND" and sanitized :
296
+ return sanitized
297
+
298
+ # If no deidentifyResult, optionally check for inspectResult (rare, but could have findings)
299
+ for filter_entry in filters :
300
+ sdp = filter_entry .get ("sdpFilterResult" )
301
+ if sdp :
302
+ inspect = sdp .get ("inspectResult" , {})
303
+ # If Model Armor flagged something but didn't sanitize, return None
304
+ if inspect .get ("matchState" ) == "MATCH_FOUND" :
305
+ return None
306
+
307
+ # Fallback: if Model Armor put sanitized text at the root, use it
243
308
return armor_response .get ("sanitizedText" ) or armor_response .get ("text" )
244
309
245
310
def _process_response (
0 commit comments