Skip to content

Commit ab1fb2b

Browse files
authored
Add Support for Bedrock Guardrails to supportive selective Guarding (#14575)
* Add Support for Bedrock Guardrails to supportive selective Guarding * Add method for better handling * Add guarded_text content type * Add guarded_text content type * Update Dockerfile * Update Dockerfile
1 parent 7de8811 commit ab1fb2b

File tree

6 files changed

+427
-81
lines changed

6 files changed

+427
-81
lines changed

docs/my-website/docs/providers/bedrock.md

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,19 @@ curl http://0.0.0.0:4000/v1/chat/completions \
889889

890890
Example of using [Bedrock Guardrails with LiteLLM](https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails-use-converse-api.html)
891891

892+
### Selective Content Moderation with `guarded_text`
893+
894+
LiteLLM supports selective content moderation using the `guarded_text` content type. This allows you to wrap only specific content that should be moderated by Bedrock Guardrails, rather than evaluating the entire conversation.
895+
896+
**How it works:**
897+
- Content with `type: "guarded_text"` gets automatically wrapped in `guardrailConverseContent` blocks
898+
- Only the wrapped content is evaluated by Bedrock Guardrails
899+
- Regular content with `type: "text"` bypasses guardrail evaluation
900+
901+
:::note
902+
If `guarded_text` is not used, the entire conversation history will be sent to the guardrail for evaluation, which can increase latency and costs.
903+
:::
904+
892905
<Tabs>
893906
<TabItem value="sdk" label="LiteLLM SDK">
894907

@@ -915,6 +928,24 @@ response = completion(
915928
"trace": "disabled", # The trace behavior for the guardrail. Can either be "disabled" or "enabled"
916929
},
917930
)
931+
932+
# Selective guardrail usage with guarded_text - only specific content is evaluated
933+
response_guard = completion(
934+
model="anthropic.claude-v2",
935+
messages=[
936+
{
937+
"role": "user",
938+
"content": [
939+
{"type": "text", "text": "What is the main topic of this legal document?"},
940+
{"type": "guarded_text", "text": "This document contains sensitive legal information that should be moderated by guardrails."}
941+
]
942+
}
943+
],
944+
guardrailConfig={
945+
"guardrailIdentifier": "gr-abc123",
946+
"guardrailVersion": "DRAFT"
947+
}
948+
)
918949
```
919950
</TabItem>
920951
<TabItem value="proxy" label="Proxy on request">
@@ -993,7 +1024,20 @@ response = client.chat.completions.create(model="bedrock-claude-v1", messages =
9931024
temperature=0.7
9941025
)
9951026

996-
print(response)
1027+
# For adding selective guardrail usage with guarded_text
1028+
response_guard = client.chat.completions.create(model="bedrock-claude-v1", messages = [
1029+
{
1030+
"role": "user",
1031+
"content": [
1032+
{"type": "text", "text": "What is the main topic of this legal document?"},
1033+
{"type": "guarded_text", "text": "This document contains sensitive legal information that should be moderated by guardrails."}
1034+
]
1035+
}
1036+
],
1037+
temperature=0.7
1038+
)
1039+
1040+
print(response_guard)
9971041
```
9981042
</TabItem>
9991043
</Tabs>

litellm/litellm_core_utils/prompt_templates/factory.py

Lines changed: 69 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from litellm.llms.custom_httpx.http_handler import HTTPHandler, get_async_httpx_client
1717
from litellm.types.files import get_file_extension_from_mime_type
1818
from litellm.types.llms.anthropic import *
19-
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
2019
from litellm.types.llms.bedrock import CachePointBlock
20+
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
2121
from litellm.types.llms.custom_http import httpxSpecialProvider
2222
from litellm.types.llms.ollama import OllamaVisionModelObject
2323
from litellm.types.llms.openai import (
@@ -1067,10 +1067,10 @@ def convert_to_gemini_tool_call_invoke(
10671067
if tool_calls is not None:
10681068
for tool in tool_calls:
10691069
if "function" in tool:
1070-
gemini_function_call: Optional[VertexFunctionCall] = (
1071-
_gemini_tool_call_invoke_helper(
1072-
function_call_params=tool["function"]
1073-
)
1070+
gemini_function_call: Optional[
1071+
VertexFunctionCall
1072+
] = _gemini_tool_call_invoke_helper(
1073+
function_call_params=tool["function"]
10741074
)
10751075
if gemini_function_call is not None:
10761076
_parts_list.append(
@@ -1589,9 +1589,9 @@ def anthropic_messages_pt( # noqa: PLR0915
15891589
)
15901590

15911591
if "cache_control" in _content_element:
1592-
_anthropic_content_element["cache_control"] = (
1593-
_content_element["cache_control"]
1594-
)
1592+
_anthropic_content_element[
1593+
"cache_control"
1594+
] = _content_element["cache_control"]
15951595
user_content.append(_anthropic_content_element)
15961596
elif m.get("type", "") == "text":
15971597
m = cast(ChatCompletionTextObject, m)
@@ -1629,9 +1629,9 @@ def anthropic_messages_pt( # noqa: PLR0915
16291629
)
16301630

16311631
if "cache_control" in _content_element:
1632-
_anthropic_content_text_element["cache_control"] = (
1633-
_content_element["cache_control"]
1634-
)
1632+
_anthropic_content_text_element[
1633+
"cache_control"
1634+
] = _content_element["cache_control"]
16351635

16361636
user_content.append(_anthropic_content_text_element)
16371637

@@ -2482,8 +2482,7 @@ def _validate_format(mime_type: str, image_format: str) -> str:
24822482

24832483
if is_document:
24842484
return BedrockImageProcessor._get_document_format(
2485-
mime_type=mime_type,
2486-
supported_doc_formats=supported_doc_formats
2485+
mime_type=mime_type, supported_doc_formats=supported_doc_formats
24872486
)
24882487

24892488
else:
@@ -2495,12 +2494,9 @@ def _validate_format(mime_type: str, image_format: str) -> str:
24952494
f"Unsupported image format: {image_format}. Supported formats: {supported_image_and_video_formats}"
24962495
)
24972496
return image_format
2498-
2497+
24992498
@staticmethod
2500-
def _get_document_format(
2501-
mime_type: str,
2502-
supported_doc_formats: List[str]
2503-
) -> str:
2499+
def _get_document_format(mime_type: str, supported_doc_formats: List[str]) -> str:
25042500
"""
25052501
Get the document format from the mime type
25062502
@@ -2519,13 +2515,9 @@ def _get_document_format(
25192515
The document format
25202516
"""
25212517
valid_extensions: Optional[List[str]] = None
2522-
potential_extensions = mimetypes.guess_all_extensions(
2523-
mime_type, strict=False
2524-
)
2518+
potential_extensions = mimetypes.guess_all_extensions(mime_type, strict=False)
25252519
valid_extensions = [
2526-
ext[1:]
2527-
for ext in potential_extensions
2528-
if ext[1:] in supported_doc_formats
2520+
ext[1:] for ext in potential_extensions if ext[1:] in supported_doc_formats
25292521
]
25302522

25312523
# Fallback to types/files.py if mimetypes doesn't return valid extensions
@@ -2689,10 +2681,12 @@ def _convert_to_bedrock_tool_call_invoke(
26892681
)
26902682
bedrock_content_block = BedrockContentBlock(toolUse=bedrock_tool)
26912683
_parts_list.append(bedrock_content_block)
2692-
2684+
26932685
# Check for cache_control and add a separate cachePoint block
26942686
if tool.get("cache_control", None) is not None:
2695-
cache_point_block = BedrockContentBlock(cachePoint=CachePointBlock(type="default"))
2687+
cache_point_block = BedrockContentBlock(
2688+
cachePoint=CachePointBlock(type="default")
2689+
)
26962690
_parts_list.append(cache_point_block)
26972691
return _parts_list
26982692
except Exception as e:
@@ -2754,7 +2748,7 @@ def _convert_to_bedrock_tool_call_result(
27542748
for content in content_list:
27552749
if content["type"] == "text":
27562750
content_str += content["text"]
2757-
2751+
27582752
message.get("name", "")
27592753
id = str(message.get("tool_call_id", str(uuid.uuid4())))
27602754

@@ -2763,7 +2757,7 @@ def _convert_to_bedrock_tool_call_result(
27632757
content=[tool_result_content_block],
27642758
toolUseId=id,
27652759
)
2766-
2760+
27672761
content_block = BedrockContentBlock(toolResult=tool_result)
27682762

27692763
return content_block
@@ -3085,6 +3079,7 @@ def _initial_message_setup(
30853079
messages.append(DEFAULT_USER_CONTINUE_MESSAGE)
30863080
return messages
30873081

3082+
30883083
@staticmethod
30893084
async def _bedrock_converse_messages_pt_async( # noqa: PLR0915
30903085
messages: List,
@@ -3128,6 +3123,12 @@ async def _bedrock_converse_messages_pt_async( # noqa: PLR0915
31283123
if element["type"] == "text":
31293124
_part = BedrockContentBlock(text=element["text"])
31303125
_parts.append(_part)
3126+
elif element["type"] == "guarded_text":
3127+
# Wrap guarded_text in guardrailConverseContent block
3128+
_part = BedrockContentBlock(
3129+
guardrailConverseContent={"text": element["text"]}
3130+
)
3131+
_parts.append(_part)
31313132
elif element["type"] == "image_url":
31323133
format: Optional[str] = None
31333134
if isinstance(element["image_url"], dict):
@@ -3170,6 +3171,7 @@ async def _bedrock_converse_messages_pt_async( # noqa: PLR0915
31703171

31713172
msg_i += 1
31723173
if user_content:
3174+
31733175
if len(contents) > 0 and contents[-1]["role"] == "user":
31743176
if (
31753177
assistant_continue_message is not None
@@ -3199,26 +3201,29 @@ async def _bedrock_converse_messages_pt_async( # noqa: PLR0915
31993201
current_message = messages[msg_i]
32003202
tool_call_result = _convert_to_bedrock_tool_call_result(current_message)
32013203
tool_content.append(tool_call_result)
3202-
3204+
32033205
# Check if we need to add a separate cachePoint block
32043206
has_cache_control = False
3205-
3207+
32063208
# Check for message-level cache_control
32073209
if current_message.get("cache_control", None) is not None:
32083210
has_cache_control = True
32093211
# Check for content-level cache_control in list content
32103212
elif isinstance(current_message.get("content"), list):
32113213
for content_element in current_message["content"]:
3212-
if (isinstance(content_element, dict) and
3213-
content_element.get("cache_control", None) is not None):
3214+
if (
3215+
isinstance(content_element, dict)
3216+
and content_element.get("cache_control", None) is not None
3217+
):
32143218
has_cache_control = True
32153219
break
3216-
3220+
32173221
# Add a separate cachePoint block if cache_control is present
32183222
if has_cache_control:
3219-
cache_point_block = BedrockContentBlock(cachePoint=CachePointBlock(type="default"))
3223+
cache_point_block = BedrockContentBlock(
3224+
cachePoint=CachePointBlock(type="default")
3225+
)
32203226
tool_content.append(cache_point_block)
3221-
32223227

32233228
msg_i += 1
32243229
if tool_content:
@@ -3299,7 +3304,7 @@ async def _bedrock_converse_messages_pt_async( # noqa: PLR0915
32993304
image_url=image_url
33003305
)
33013306
assistants_parts.append(assistants_part)
3302-
# Add cache point block for assistant content elements
3307+
# Add cache point block for assistant content elements
33033308
_cache_point_block = (
33043309
litellm.AmazonConverseConfig()._get_cache_point_block(
33053310
message_block=cast(
@@ -3311,8 +3316,12 @@ async def _bedrock_converse_messages_pt_async( # noqa: PLR0915
33113316
if _cache_point_block is not None:
33123317
assistants_parts.append(_cache_point_block)
33133318
assistant_content.extend(assistants_parts)
3314-
elif _assistant_content is not None and isinstance(_assistant_content, str):
3315-
assistant_content.append(BedrockContentBlock(text=_assistant_content))
3319+
elif _assistant_content is not None and isinstance(
3320+
_assistant_content, str
3321+
):
3322+
assistant_content.append(
3323+
BedrockContentBlock(text=_assistant_content)
3324+
)
33163325
# Add cache point block for assistant string content
33173326
_cache_point_block = (
33183327
litellm.AmazonConverseConfig()._get_cache_point_block(
@@ -3496,6 +3505,12 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
34963505
if element["type"] == "text":
34973506
_part = BedrockContentBlock(text=element["text"])
34983507
_parts.append(_part)
3508+
elif element["type"] == "guarded_text":
3509+
# Wrap guarded_text in guardrailConverseContent block
3510+
_part = BedrockContentBlock(
3511+
guardrailConverseContent={"text": element["text"]}
3512+
)
3513+
_parts.append(_part)
34993514
elif element["type"] == "image_url":
35003515
format: Optional[str] = None
35013516
if isinstance(element["image_url"], dict):
@@ -3539,6 +3554,7 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
35393554

35403555
msg_i += 1
35413556
if user_content:
3557+
35423558
if len(contents) > 0 and contents[-1]["role"] == "user":
35433559
if (
35443560
assistant_continue_message is not None
@@ -3565,29 +3581,33 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
35653581
while msg_i < len(messages) and messages[msg_i]["role"] == "tool":
35663582
tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
35673583
current_message = messages[msg_i]
3568-
3584+
35693585
# Add the tool result first
35703586
tool_content.append(tool_call_result)
3571-
3587+
35723588
# Check if we need to add a separate cachePoint block
35733589
has_cache_control = False
3574-
3590+
35753591
# Check for message-level cache_control
35763592
if current_message.get("cache_control", None) is not None:
35773593
has_cache_control = True
35783594
# Check for content-level cache_control in list content
35793595
elif isinstance(current_message.get("content"), list):
35803596
for content_element in current_message["content"]:
3581-
if (isinstance(content_element, dict) and
3582-
content_element.get("cache_control", None) is not None):
3597+
if (
3598+
isinstance(content_element, dict)
3599+
and content_element.get("cache_control", None) is not None
3600+
):
35833601
has_cache_control = True
35843602
break
3585-
3603+
35863604
# Add a separate cachePoint block if cache_control is present
35873605
if has_cache_control:
3588-
cache_point_block = BedrockContentBlock(cachePoint=CachePointBlock(type="default"))
3606+
cache_point_block = BedrockContentBlock(
3607+
cachePoint=CachePointBlock(type="default")
3608+
)
35893609
tool_content.append(cache_point_block)
3590-
3610+
35913611
msg_i += 1
35923612
if tool_content:
35933613
# if last message was a 'user' message, then add a blank assistant message (bedrock requires alternating roles)
@@ -3852,10 +3872,9 @@ def function_call_prompt(messages: list, functions: list):
38523872
if isinstance(message["content"], str):
38533873
message["content"] += f""" {function_prompt}"""
38543874
else:
3855-
message["content"].append({
3856-
"type": "text",
3857-
"text": f""" {function_prompt}"""
3858-
})
3875+
message["content"].append(
3876+
{"type": "text", "text": f""" {function_prompt}"""}
3877+
)
38593878
function_added_to_prompt = True
38603879

38613880
if function_added_to_prompt is False:

0 commit comments

Comments
 (0)