Skip to content

Commit e7bc700

Browse files
Merge pull request #14640 from Sameerlite/litellm_gardrail_default_latest_message
[Feat]Add last message as default in gaurdrail
2 parents 8d96626 + edf9596 commit e7bc700

File tree

2 files changed

+494
-0
lines changed

2 files changed

+494
-0
lines changed

litellm/llms/bedrock/chat/converse_transformation.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,61 @@ def get_config_blocks(cls) -> dict:
102102
"performanceConfig": PerformanceConfigBlock,
103103
}
104104

105+
@staticmethod
106+
def _convert_consecutive_user_messages_to_guarded_text(
107+
messages: List[AllMessageValues], optional_params: dict
108+
) -> List[AllMessageValues]:
109+
"""
110+
Convert consecutive user messages at the end to guarded_text type if guardrailConfig is present
111+
and no guarded_text is already present in those messages.
112+
"""
113+
# Check if guardrailConfig is present
114+
if "guardrailConfig" not in optional_params:
115+
return messages
116+
117+
# Find all consecutive user messages at the end
118+
consecutive_user_message_indices = []
119+
for i in range(len(messages) - 1, -1, -1):
120+
if messages[i].get("role") == "user":
121+
consecutive_user_message_indices.append(i)
122+
else:
123+
break
124+
125+
if not consecutive_user_message_indices:
126+
return messages
127+
128+
# Process each consecutive user message
129+
messages_copy = copy.deepcopy(messages)
130+
for user_message_index in consecutive_user_message_indices:
131+
user_message = messages_copy[user_message_index]
132+
content = user_message.get("content", [])
133+
134+
if isinstance(content, list):
135+
has_guarded_text = any(
136+
isinstance(item, dict) and item.get("type") == "guarded_text"
137+
for item in content
138+
)
139+
if has_guarded_text:
140+
continue # Skip this message if it already has guarded_text
141+
142+
# Convert text elements to guarded_text
143+
new_content = []
144+
for item in content:
145+
if isinstance(item, dict) and item.get("type") == "text":
146+
new_item = {"type": "guarded_text", "text": item["text"]} # type: ignore
147+
new_content.append(new_item)
148+
else:
149+
new_content.append(item)
150+
151+
messages_copy[user_message_index]["content"] = new_content # type: ignore
152+
elif isinstance(content, str):
153+
# If content is a string, convert it to guarded_text
154+
messages_copy[user_message_index]["content"] = [ # type: ignore
155+
{"type": "guarded_text", "text": content} # type: ignore
156+
]
157+
158+
return messages_copy
159+
105160
@classmethod
106161
def get_config(cls):
107162
return {
@@ -769,6 +824,11 @@ async def _async_transform_request(
769824
headers: Optional[dict] = None,
770825
) -> RequestObject:
771826
messages, system_content_blocks = self._transform_system_message(messages)
827+
828+
# Convert last user message to guarded_text if guardrailConfig is present
829+
messages = self._convert_consecutive_user_messages_to_guarded_text(
830+
messages, optional_params
831+
)
772832
## TRANSFORMATION ##
773833

774834
_data: CommonRequestObject = self._transform_request_helper(
@@ -821,6 +881,11 @@ def _transform_request(
821881
) -> RequestObject:
822882
messages, system_content_blocks = self._transform_system_message(messages)
823883

884+
# Convert last user message to guarded_text if guardrailConfig is present
885+
messages = self._convert_consecutive_user_messages_to_guarded_text(
886+
messages, optional_params
887+
)
888+
824889
_data: CommonRequestObject = self._transform_request_helper(
825890
model=model,
826891
system_content_blocks=system_content_blocks,

0 commit comments

Comments
 (0)