15
15
import json
16
16
import sys
17
17
from typing import Any , AsyncGenerator , List , Literal , Optional , Tuple , Union
18
- from litellm . secret_managers . main import get_secret_str
18
+
19
19
import httpx
20
20
from fastapi import HTTPException
21
21
32
32
httpxSpecialProvider ,
33
33
)
34
34
from litellm .proxy ._types import UserAPIKeyAuth
35
+ from litellm .secret_managers .main import get_secret_str
35
36
from litellm .types .guardrails import GuardrailEventHooks
36
37
from litellm .types .llms .openai import AllMessageValues
37
38
from litellm .types .proxy .guardrails .guardrail_hooks .bedrock_guardrails import (
@@ -118,18 +119,17 @@ def __init__(
118
119
"""
119
120
If True, will not raise an exception when the guardrail is blocked.
120
121
"""
121
-
122
122
123
123
# Set supported event hooks to include MCP hooks
124
- if ' supported_event_hooks' not in kwargs :
125
- kwargs [' supported_event_hooks' ] = [
124
+ if " supported_event_hooks" not in kwargs :
125
+ kwargs [" supported_event_hooks" ] = [
126
126
GuardrailEventHooks .pre_call ,
127
127
GuardrailEventHooks .post_call ,
128
128
GuardrailEventHooks .during_call ,
129
129
GuardrailEventHooks .pre_mcp_call ,
130
130
GuardrailEventHooks .during_mcp_call ,
131
131
]
132
-
132
+
133
133
super ().__init__ (** kwargs )
134
134
BaseAWSLLM .__init__ (self )
135
135
@@ -138,9 +138,10 @@ def __init__(
138
138
self .guardrailIdentifier ,
139
139
self .guardrailVersion ,
140
140
)
141
-
142
141
143
- def _create_bedrock_input_content_request (self , messages : Optional [List [AllMessageValues ]]) -> BedrockRequest :
142
+ def _create_bedrock_input_content_request (
143
+ self , messages : Optional [List [AllMessageValues ]]
144
+ ) -> BedrockRequest :
144
145
"""
145
146
Create a bedrock request for the input content - the LLM request.
146
147
"""
@@ -149,8 +150,8 @@ def _create_bedrock_input_content_request(self, messages: Optional[List[AllMessa
149
150
if messages is None :
150
151
return bedrock_request
151
152
for message in messages :
152
- message_text_content : Optional [List [str ]] = (
153
- self . get_content_for_message ( message = message )
153
+ message_text_content : Optional [List [str ]] = self . get_content_for_message (
154
+ message = message
154
155
)
155
156
if message_text_content is None :
156
157
continue
@@ -163,7 +164,9 @@ def _create_bedrock_input_content_request(self, messages: Optional[List[AllMessa
163
164
bedrock_request ["content" ] = bedrock_request_content
164
165
return bedrock_request
165
166
166
- def _create_bedrock_output_content_request (self , response : Union [Any , ModelResponse ]) -> BedrockRequest :
167
+ def _create_bedrock_output_content_request (
168
+ self , response : Union [Any , ModelResponse ]
169
+ ) -> BedrockRequest :
167
170
"""
168
171
Create a bedrock request for the output content - the LLM response.
169
172
"""
@@ -199,9 +202,13 @@ def convert_to_bedrock_format(
199
202
"""
200
203
bedrock_request : BedrockRequest = BedrockRequest (source = source )
201
204
if source == "INPUT" :
202
- bedrock_request = self ._create_bedrock_input_content_request (messages = messages )
205
+ bedrock_request = self ._create_bedrock_input_content_request (
206
+ messages = messages
207
+ )
203
208
elif source == "OUTPUT" :
204
- bedrock_request = self ._create_bedrock_output_content_request (response = response )
209
+ bedrock_request = self ._create_bedrock_output_content_request (
210
+ response = response
211
+ )
205
212
return bedrock_request
206
213
207
214
#### CALL HOOKS - proxy only ####
@@ -255,9 +262,19 @@ def _prepare_request(
255
262
headers = {"Content-Type" : "application/json" }
256
263
if extra_headers is not None :
257
264
headers = {"Content-Type" : "application/json" , ** extra_headers }
258
- api_base = f"https://bedrock-runtime.{ aws_region_name } .amazonaws.com/guardrail/{ self .guardrailIdentifier } /version/{ self .guardrailVersion } /apply"
265
+
266
+ aws_bedrock_runtime_endpoint = self .optional_params .get (
267
+ "aws_bedrock_runtime_endpoint" , None
268
+ )
269
+ _ , proxy_endpoint_url = self .get_runtime_endpoint (
270
+ api_base = None ,
271
+ aws_bedrock_runtime_endpoint = aws_bedrock_runtime_endpoint ,
272
+ aws_region_name = aws_region_name ,
273
+ )
274
+ proxy_endpoint_url = f"{ proxy_endpoint_url } /guardrail/{ self .guardrailIdentifier } /version/{ self .guardrailVersion } /apply"
275
+ # api_base = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com/guardrail/{self.guardrailIdentifier}/version/{self.guardrailVersion}/apply"
259
276
encoded_data = json .dumps (data ).encode ("utf-8" )
260
-
277
+
261
278
# first check api-key, if none, fall back to sigV4
262
279
if api_key is not None :
263
280
aws_bearer_token : Optional [str ] = api_key
@@ -268,21 +285,31 @@ def _prepare_request(
268
285
try :
269
286
from botocore .awsrequest import AWSRequest
270
287
except ImportError :
271
- raise ImportError ("Missing boto3 to call bedrock. Run 'pip install boto3'." )
288
+ raise ImportError (
289
+ "Missing boto3 to call bedrock. Run 'pip install boto3'."
290
+ )
272
291
headers ["Authorization" ] = f"Bearer { aws_bearer_token } "
273
292
request = AWSRequest (
274
- method = "POST" , url = api_base , data = encoded_data , headers = headers
293
+ method = "POST" ,
294
+ url = proxy_endpoint_url ,
295
+ data = encoded_data ,
296
+ headers = headers ,
275
297
)
276
298
else :
277
299
try :
278
300
from botocore .auth import SigV4Auth
279
301
from botocore .awsrequest import AWSRequest
280
302
except ImportError :
281
- raise ImportError ("Missing boto3 to call bedrock. Run 'pip install boto3'." )
303
+ raise ImportError (
304
+ "Missing boto3 to call bedrock. Run 'pip install boto3'."
305
+ )
282
306
283
307
sigv4 = SigV4Auth (credentials , "bedrock" , aws_region_name )
284
308
request = AWSRequest (
285
- method = "POST" , url = api_base , data = encoded_data , headers = headers
309
+ method = "POST" ,
310
+ url = proxy_endpoint_url ,
311
+ data = encoded_data ,
312
+ headers = headers ,
286
313
)
287
314
sigv4 .add_auth (request )
288
315
if (
@@ -294,20 +321,19 @@ def _prepare_request(
294
321
return prepped_request
295
322
296
323
async def make_bedrock_api_request (
297
- self ,
324
+ self ,
298
325
source : Literal ["INPUT" , "OUTPUT" ],
299
326
messages : Optional [List [AllMessageValues ]] = None ,
300
327
response : Optional [Union [Any , litellm .ModelResponse ]] = None ,
301
- request_data : Optional [dict ] = None
328
+ request_data : Optional [dict ] = None ,
302
329
) -> BedrockGuardrailResponse :
303
330
from datetime import datetime
331
+
304
332
start_time = datetime .now ()
305
333
credentials , aws_region_name = self ._load_credentials ()
306
334
bedrock_request_data : dict = dict (
307
335
self .convert_to_bedrock_format (
308
- source = source ,
309
- messages = messages ,
310
- response = response
336
+ source = source , messages = messages , response = response
311
337
)
312
338
)
313
339
bedrock_guardrail_response : BedrockGuardrailResponse = (
@@ -316,11 +342,13 @@ async def make_bedrock_api_request(
316
342
api_key : Optional [str ] = None
317
343
if request_data :
318
344
bedrock_request_data .update (
319
- self .get_guardrail_dynamic_request_body_params (request_data = request_data )
345
+ self .get_guardrail_dynamic_request_body_params (
346
+ request_data = request_data
347
+ )
320
348
)
321
349
if request_data .get ("api_key" ) is not None :
322
350
api_key = request_data ["api_key" ]
323
-
351
+
324
352
prepared_request = self ._prepare_request (
325
353
credentials = credentials ,
326
354
data = bedrock_request_data ,
@@ -346,7 +374,9 @@ async def make_bedrock_api_request(
346
374
self .add_standard_logging_guardrail_information_to_request_data (
347
375
guardrail_json_response = response .json (),
348
376
request_data = request_data or {},
349
- guardrail_status = self ._get_bedrock_guardrail_response_status (response = response ),
377
+ guardrail_status = self ._get_bedrock_guardrail_response_status (
378
+ response = response
379
+ ),
350
380
start_time = start_time .timestamp (),
351
381
end_time = datetime .now ().timestamp (),
352
382
duration = (datetime .now () - start_time ).total_seconds (),
@@ -372,16 +402,20 @@ async def make_bedrock_api_request(
372
402
)
373
403
374
404
return bedrock_guardrail_response
375
-
376
- def _get_bedrock_guardrail_response_status (self , response : httpx .Response ) -> Literal ["success" , "failure" ]:
405
+
406
+ def _get_bedrock_guardrail_response_status (
407
+ self , response : httpx .Response
408
+ ) -> Literal ["success" , "failure" ]:
377
409
"""
378
410
Get the status of the bedrock guardrail response.
379
411
"""
380
412
if response .status_code == 200 :
381
413
return "success"
382
414
return "failure"
383
415
384
- def _get_http_exception_for_blocked_guardrail (self , response : BedrockGuardrailResponse ) -> HTTPException :
416
+ def _get_http_exception_for_blocked_guardrail (
417
+ self , response : BedrockGuardrailResponse
418
+ ) -> HTTPException :
385
419
"""
386
420
Get the HTTP exception for a blocked guardrail.
387
421
"""
@@ -393,17 +427,15 @@ def _get_http_exception_for_blocked_guardrail(self, response: BedrockGuardrailRe
393
427
for output in outputs :
394
428
if output .get ("text" ):
395
429
bedrock_guardrail_output_text += output .get ("text" ) or ""
396
-
397
-
430
+
398
431
return HTTPException (
399
432
status_code = 400 ,
400
433
detail = {
401
- "error" : "Violated guardrail policy" ,
434
+ "error" : "Violated guardrail policy" ,
402
435
"bedrock_guardrail_response" : bedrock_guardrail_output_text ,
403
- }
436
+ },
404
437
)
405
438
406
-
407
439
def _should_raise_guardrail_blocked_exception (
408
440
self , response : BedrockGuardrailResponse
409
441
) -> bool :
@@ -416,7 +448,7 @@ def _should_raise_guardrail_blocked_exception(
416
448
# If the user opted into masking, return False, since we'll use the masked output from the guardrail
417
449
if self .mask_request_content or self .mask_response_content :
418
450
return False
419
-
451
+
420
452
if self .disable_exception_on_block is True :
421
453
return False
422
454
@@ -631,9 +663,7 @@ async def async_post_call_success_hook(
631
663
########## 1. Make Bedrock API request (response only) ##########
632
664
#########################################################
633
665
output_content_bedrock = await self .make_bedrock_api_request (
634
- source = "OUTPUT" ,
635
- response = response ,
636
- request_data = data
666
+ source = "OUTPUT" , response = response , request_data = data
637
667
) # Only response
638
668
639
669
#########################################################
@@ -729,16 +759,16 @@ async def async_post_call_streaming_iterator_hook(
729
759
###################################################################
730
760
# Create tasks for parallel execution
731
761
input_task = self .make_bedrock_api_request (
732
- source = "INPUT" , messages = request_data .get ("messages" ), request_data = request_data
762
+ source = "INPUT" ,
763
+ messages = request_data .get ("messages" ),
764
+ request_data = request_data ,
733
765
) # Only input messages
734
766
output_task = self .make_bedrock_api_request (
735
767
source = "OUTPUT" , response = assembled_model_response
736
768
) # Only response
737
769
738
770
# Execute both requests in parallel
739
- _ , output_guardrail_response = await asyncio .gather (
740
- input_task , output_task
741
- )
771
+ _ , output_guardrail_response = await asyncio .gather (input_task , output_task )
742
772
743
773
#########################################################################
744
774
########## 2. Apply masking to response with output guardrail response ##########
@@ -891,7 +921,7 @@ def _apply_masking_to_response(
891
921
) -> None :
892
922
"""
893
923
Apply masked content from bedrock guardrail to the response object.
894
-
924
+
895
925
Args:
896
926
response: The response object to modify
897
927
bedrock_guardrail_response: Response from Bedrock guardrail containing masked content
@@ -902,7 +932,9 @@ def _apply_masking_to_response(
902
932
)
903
933
904
934
if not masked_texts :
905
- verbose_proxy_logger .debug ("No masked outputs found, skipping response masking" )
935
+ verbose_proxy_logger .debug (
936
+ "No masked outputs found, skipping response masking"
937
+ )
906
938
return
907
939
908
940
verbose_proxy_logger .debug (
@@ -922,13 +954,13 @@ def _apply_masking_to_model_response(
922
954
) -> None :
923
955
"""
924
956
Apply masked texts to a ModelResponse object.
925
-
957
+
926
958
Args:
927
959
response: The ModelResponse object to modify in-place
928
960
masked_texts: List of masked text strings from guardrail
929
961
"""
930
962
masking_index = 0
931
-
963
+
932
964
for choice in response .choices :
933
965
if isinstance (choice , Choices ):
934
966
# For chat completions
0 commit comments