Skip to content

Commit 1598d3e

Browse files
fix(bedrock_guardrails.py): respect bedrock runtime endpoint when using guardrails
Closes LIT-983
1 parent 4276d78 commit 1598d3e

File tree

2 files changed

+90
-46
lines changed

2 files changed

+90
-46
lines changed

litellm/proxy/_new_secret_config.yaml

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,3 +18,15 @@ model_list:
1818

1919

2020

21+
22+
guardrails:
23+
- guardrail_name: "intel-bedrock-guard-cfg"
24+
litellm_params:
25+
guardrail: bedrock
26+
mode: [pre_call, post_call]
27+
guardrailIdentifier: "1234"
28+
guardrailVersion: "1"
29+
aws_access_key_id: "os.environ/AWS_ACCESS_KEY_ID"
30+
aws_secret_access_key: "os.environ/AWS_SECRET_ACCESS_KEY"
31+
aws_bedrock_runtime_endpoint: "os.environ/AWS_BEDROCK_RUNTIME_ENDPOINT"
32+
default_on: true

litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py

Lines changed: 78 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
import json
1616
import sys
1717
from typing import Any, AsyncGenerator, List, Literal, Optional, Tuple, Union
18-
from litellm.secret_managers.main import get_secret_str
18+
1919
import httpx
2020
from fastapi import HTTPException
2121

@@ -32,6 +32,7 @@
3232
httpxSpecialProvider,
3333
)
3434
from litellm.proxy._types import UserAPIKeyAuth
35+
from litellm.secret_managers.main import get_secret_str
3536
from litellm.types.guardrails import GuardrailEventHooks
3637
from litellm.types.llms.openai import AllMessageValues
3738
from litellm.types.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
@@ -118,18 +119,17 @@ def __init__(
118119
"""
119120
If True, will not raise an exception when the guardrail is blocked.
120121
"""
121-
122122

123123
# Set supported event hooks to include MCP hooks
124-
if 'supported_event_hooks' not in kwargs:
125-
kwargs['supported_event_hooks'] = [
124+
if "supported_event_hooks" not in kwargs:
125+
kwargs["supported_event_hooks"] = [
126126
GuardrailEventHooks.pre_call,
127127
GuardrailEventHooks.post_call,
128128
GuardrailEventHooks.during_call,
129129
GuardrailEventHooks.pre_mcp_call,
130130
GuardrailEventHooks.during_mcp_call,
131131
]
132-
132+
133133
super().__init__(**kwargs)
134134
BaseAWSLLM.__init__(self)
135135

@@ -138,9 +138,10 @@ def __init__(
138138
self.guardrailIdentifier,
139139
self.guardrailVersion,
140140
)
141-
142141

143-
def _create_bedrock_input_content_request(self, messages: Optional[List[AllMessageValues]]) -> BedrockRequest:
142+
def _create_bedrock_input_content_request(
143+
self, messages: Optional[List[AllMessageValues]]
144+
) -> BedrockRequest:
144145
"""
145146
Create a bedrock request for the input content - the LLM request.
146147
"""
@@ -149,8 +150,8 @@ def _create_bedrock_input_content_request(self, messages: Optional[List[AllMessa
149150
if messages is None:
150151
return bedrock_request
151152
for message in messages:
152-
message_text_content: Optional[List[str]] = (
153-
self.get_content_for_message(message=message)
153+
message_text_content: Optional[List[str]] = self.get_content_for_message(
154+
message=message
154155
)
155156
if message_text_content is None:
156157
continue
@@ -163,7 +164,9 @@ def _create_bedrock_input_content_request(self, messages: Optional[List[AllMessa
163164
bedrock_request["content"] = bedrock_request_content
164165
return bedrock_request
165166

166-
def _create_bedrock_output_content_request(self, response: Union[Any, ModelResponse]) -> BedrockRequest:
167+
def _create_bedrock_output_content_request(
168+
self, response: Union[Any, ModelResponse]
169+
) -> BedrockRequest:
167170
"""
168171
Create a bedrock request for the output content - the LLM response.
169172
"""
@@ -199,9 +202,13 @@ def convert_to_bedrock_format(
199202
"""
200203
bedrock_request: BedrockRequest = BedrockRequest(source=source)
201204
if source == "INPUT":
202-
bedrock_request = self._create_bedrock_input_content_request(messages=messages)
205+
bedrock_request = self._create_bedrock_input_content_request(
206+
messages=messages
207+
)
203208
elif source == "OUTPUT":
204-
bedrock_request = self._create_bedrock_output_content_request(response=response)
209+
bedrock_request = self._create_bedrock_output_content_request(
210+
response=response
211+
)
205212
return bedrock_request
206213

207214
#### CALL HOOKS - proxy only ####
@@ -255,9 +262,19 @@ def _prepare_request(
255262
headers = {"Content-Type": "application/json"}
256263
if extra_headers is not None:
257264
headers = {"Content-Type": "application/json", **extra_headers}
258-
api_base = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com/guardrail/{self.guardrailIdentifier}/version/{self.guardrailVersion}/apply"
265+
266+
aws_bedrock_runtime_endpoint = self.optional_params.get(
267+
"aws_bedrock_runtime_endpoint", None
268+
)
269+
_, proxy_endpoint_url = self.get_runtime_endpoint(
270+
api_base=None,
271+
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
272+
aws_region_name=aws_region_name,
273+
)
274+
proxy_endpoint_url = f"{proxy_endpoint_url}/guardrail/{self.guardrailIdentifier}/version/{self.guardrailVersion}/apply"
275+
# api_base = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com/guardrail/{self.guardrailIdentifier}/version/{self.guardrailVersion}/apply"
259276
encoded_data = json.dumps(data).encode("utf-8")
260-
277+
261278
# first check api-key, if none, fall back to sigV4
262279
if api_key is not None:
263280
aws_bearer_token: Optional[str] = api_key
@@ -268,21 +285,31 @@ def _prepare_request(
268285
try:
269286
from botocore.awsrequest import AWSRequest
270287
except ImportError:
271-
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
288+
raise ImportError(
289+
"Missing boto3 to call bedrock. Run 'pip install boto3'."
290+
)
272291
headers["Authorization"] = f"Bearer {aws_bearer_token}"
273292
request = AWSRequest(
274-
method="POST", url=api_base, data=encoded_data, headers=headers
293+
method="POST",
294+
url=proxy_endpoint_url,
295+
data=encoded_data,
296+
headers=headers,
275297
)
276298
else:
277299
try:
278300
from botocore.auth import SigV4Auth
279301
from botocore.awsrequest import AWSRequest
280302
except ImportError:
281-
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
303+
raise ImportError(
304+
"Missing boto3 to call bedrock. Run 'pip install boto3'."
305+
)
282306

283307
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
284308
request = AWSRequest(
285-
method="POST", url=api_base, data=encoded_data, headers=headers
309+
method="POST",
310+
url=proxy_endpoint_url,
311+
data=encoded_data,
312+
headers=headers,
286313
)
287314
sigv4.add_auth(request)
288315
if (
@@ -294,20 +321,19 @@ def _prepare_request(
294321
return prepped_request
295322

296323
async def make_bedrock_api_request(
297-
self,
324+
self,
298325
source: Literal["INPUT", "OUTPUT"],
299326
messages: Optional[List[AllMessageValues]] = None,
300327
response: Optional[Union[Any, litellm.ModelResponse]] = None,
301-
request_data: Optional[dict] = None
328+
request_data: Optional[dict] = None,
302329
) -> BedrockGuardrailResponse:
303330
from datetime import datetime
331+
304332
start_time = datetime.now()
305333
credentials, aws_region_name = self._load_credentials()
306334
bedrock_request_data: dict = dict(
307335
self.convert_to_bedrock_format(
308-
source=source,
309-
messages=messages,
310-
response=response
336+
source=source, messages=messages, response=response
311337
)
312338
)
313339
bedrock_guardrail_response: BedrockGuardrailResponse = (
@@ -316,11 +342,13 @@ async def make_bedrock_api_request(
316342
api_key: Optional[str] = None
317343
if request_data:
318344
bedrock_request_data.update(
319-
self.get_guardrail_dynamic_request_body_params(request_data=request_data)
345+
self.get_guardrail_dynamic_request_body_params(
346+
request_data=request_data
347+
)
320348
)
321349
if request_data.get("api_key") is not None:
322350
api_key = request_data["api_key"]
323-
351+
324352
prepared_request = self._prepare_request(
325353
credentials=credentials,
326354
data=bedrock_request_data,
@@ -346,7 +374,9 @@ async def make_bedrock_api_request(
346374
self.add_standard_logging_guardrail_information_to_request_data(
347375
guardrail_json_response=response.json(),
348376
request_data=request_data or {},
349-
guardrail_status=self._get_bedrock_guardrail_response_status(response=response),
377+
guardrail_status=self._get_bedrock_guardrail_response_status(
378+
response=response
379+
),
350380
start_time=start_time.timestamp(),
351381
end_time=datetime.now().timestamp(),
352382
duration=(datetime.now() - start_time).total_seconds(),
@@ -372,16 +402,20 @@ async def make_bedrock_api_request(
372402
)
373403

374404
return bedrock_guardrail_response
375-
376-
def _get_bedrock_guardrail_response_status(self, response: httpx.Response) -> Literal["success", "failure"]:
405+
406+
def _get_bedrock_guardrail_response_status(
407+
self, response: httpx.Response
408+
) -> Literal["success", "failure"]:
377409
"""
378410
Get the status of the bedrock guardrail response.
379411
"""
380412
if response.status_code == 200:
381413
return "success"
382414
return "failure"
383415

384-
def _get_http_exception_for_blocked_guardrail(self, response: BedrockGuardrailResponse) -> HTTPException:
416+
def _get_http_exception_for_blocked_guardrail(
417+
self, response: BedrockGuardrailResponse
418+
) -> HTTPException:
385419
"""
386420
Get the HTTP exception for a blocked guardrail.
387421
"""
@@ -393,17 +427,15 @@ def _get_http_exception_for_blocked_guardrail(self, response: BedrockGuardrailRe
393427
for output in outputs:
394428
if output.get("text"):
395429
bedrock_guardrail_output_text += output.get("text") or ""
396-
397-
430+
398431
return HTTPException(
399432
status_code=400,
400433
detail={
401-
"error": "Violated guardrail policy",
434+
"error": "Violated guardrail policy",
402435
"bedrock_guardrail_response": bedrock_guardrail_output_text,
403-
}
436+
},
404437
)
405438

406-
407439
def _should_raise_guardrail_blocked_exception(
408440
self, response: BedrockGuardrailResponse
409441
) -> bool:
@@ -416,7 +448,7 @@ def _should_raise_guardrail_blocked_exception(
416448
# if user opted into masking, return False. since we'll use the masked output from the guardrail
417449
if self.mask_request_content or self.mask_response_content:
418450
return False
419-
451+
420452
if self.disable_exception_on_block is True:
421453
return False
422454

@@ -631,9 +663,7 @@ async def async_post_call_success_hook(
631663
########## 1. Make parallel Bedrock API requests ##########
632664
#########################################################
633665
output_content_bedrock = await self.make_bedrock_api_request(
634-
source="OUTPUT",
635-
response=response,
636-
request_data=data
666+
source="OUTPUT", response=response, request_data=data
637667
) # Only response
638668

639669
#########################################################
@@ -729,16 +759,16 @@ async def async_post_call_streaming_iterator_hook(
729759
###################################################################
730760
# Create tasks for parallel execution
731761
input_task = self.make_bedrock_api_request(
732-
source="INPUT", messages=request_data.get("messages"), request_data=request_data
762+
source="INPUT",
763+
messages=request_data.get("messages"),
764+
request_data=request_data,
733765
) # Only input messages
734766
output_task = self.make_bedrock_api_request(
735767
source="OUTPUT", response=assembled_model_response
736768
) # Only response
737769

738770
# Execute both requests in parallel
739-
_, output_guardrail_response = await asyncio.gather(
740-
input_task, output_task
741-
)
771+
_, output_guardrail_response = await asyncio.gather(input_task, output_task)
742772

743773
#########################################################################
744774
########## 2. Apply masking to response with output guardrail response ##########
@@ -891,7 +921,7 @@ def _apply_masking_to_response(
891921
) -> None:
892922
"""
893923
Apply masked content from bedrock guardrail to the response object.
894-
924+
895925
Args:
896926
response: The response object to modify
897927
bedrock_guardrail_response: Response from Bedrock guardrail containing masked content
@@ -902,7 +932,9 @@ def _apply_masking_to_response(
902932
)
903933

904934
if not masked_texts:
905-
verbose_proxy_logger.debug("No masked outputs found, skipping response masking")
935+
verbose_proxy_logger.debug(
936+
"No masked outputs found, skipping response masking"
937+
)
906938
return
907939

908940
verbose_proxy_logger.debug(
@@ -922,13 +954,13 @@ def _apply_masking_to_model_response(
922954
) -> None:
923955
"""
924956
Apply masked texts to a ModelResponse object.
925-
957+
926958
Args:
927959
response: The ModelResponse object to modify in-place
928960
masked_texts: List of masked text strings from guardrail
929961
"""
930962
masking_index = 0
931-
963+
932964
for choice in response.choices:
933965
if isinstance(choice, Choices):
934966
# For chat completions

0 commit comments

Comments
 (0)