Skip to content

Commit edf9596

Browse files
committed
Handle consecutive user messages
1 parent 1371abf commit edf9596

File tree

2 files changed

+276
-56
lines changed

2 files changed

+276
-56
lines changed

litellm/llms/bedrock/chat/converse_transformation.py

Lines changed: 47 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -103,67 +103,59 @@ def get_config_blocks(cls) -> dict:
103103
}
104104

105105
@staticmethod
106-
def _convert_last_user_message_to_guarded_text(
106+
def _convert_consecutive_user_messages_to_guarded_text(
107107
messages: List[AllMessageValues], optional_params: dict
108108
) -> List[AllMessageValues]:
109109
"""
110-
Convert the last user message to guarded_text type if guardrailConfig is present
111-
and no guarded_text is already present in the last user message.
110+
Convert consecutive user messages at the end to guarded_text type if guardrailConfig is present
111+
and no guarded_text is already present in those messages.
112112
"""
113113
# Check if guardrailConfig is present
114114
if "guardrailConfig" not in optional_params:
115115
return messages
116116

117-
# Find the last user message
118-
last_user_message = None
119-
last_user_message_index = -1
117+
# Find all consecutive user messages at the end
118+
consecutive_user_message_indices = []
120119
for i in range(len(messages) - 1, -1, -1):
121120
if messages[i].get("role") == "user":
122-
last_user_message = messages[i]
123-
last_user_message_index = i
121+
consecutive_user_message_indices.append(i)
122+
else:
124123
break
125124

126-
if last_user_message is None:
125+
if not consecutive_user_message_indices:
127126
return messages
128127

129-
# Check if the last user message already has guarded_text
130-
content = last_user_message.get("content", [])
131-
if isinstance(content, list):
132-
has_guarded_text = any(
133-
isinstance(item, dict) and item.get("type") == "guarded_text"
134-
for item in content
135-
)
136-
if has_guarded_text:
137-
return messages
138-
139-
# Convert text elements to guarded_text
140-
new_content = []
141-
for item in content:
142-
if isinstance(item, dict) and item.get("type") == "text":
143-
new_item = {
144-
"type": "guarded_text",
145-
"text": item["text"]
146-
}
147-
new_content.append(new_item)
148-
else:
149-
new_content.append(item)
150-
151-
# Create a copy of messages and update the last user message
152-
messages_copy = copy.deepcopy(messages)
153-
messages_copy[last_user_message_index]["content"] = new_content
154-
return messages_copy
155-
elif isinstance(content, str):
156-
# If content is a string, convert it to guarded_text
157-
messages_copy = copy.deepcopy(messages)
158-
messages_copy[last_user_message_index]["content"] = [
159-
{
160-
"type": "guarded_text",
161-
"text": content
162-
}
163-
]
164-
return messages_copy
165-
166-
return messages
128+
# Process each consecutive user message
129+
messages_copy = copy.deepcopy(messages)
130+
for user_message_index in consecutive_user_message_indices:
131+
user_message = messages_copy[user_message_index]
132+
content = user_message.get("content", [])
133+
134+
if isinstance(content, list):
135+
has_guarded_text = any(
136+
isinstance(item, dict) and item.get("type") == "guarded_text"
137+
for item in content
138+
)
139+
if has_guarded_text:
140+
continue # Skip this message if it already has guarded_text
141+
142+
# Convert text elements to guarded_text
143+
new_content = []
144+
for item in content:
145+
if isinstance(item, dict) and item.get("type") == "text":
146+
new_item = {"type": "guarded_text", "text": item["text"]} # type: ignore
147+
new_content.append(new_item)
148+
else:
149+
new_content.append(item)
150+
151+
messages_copy[user_message_index]["content"] = new_content # type: ignore
152+
elif isinstance(content, str):
153+
# If content is a string, convert it to guarded_text
154+
messages_copy[user_message_index]["content"] = [ # type: ignore
155+
{"type": "guarded_text", "text": content} # type: ignore
156+
]
157+
158+
return messages_copy
167159

168160
@classmethod
169161
def get_config(cls):
@@ -832,9 +824,11 @@ async def _async_transform_request(
832824
headers: Optional[dict] = None,
833825
) -> RequestObject:
834826
messages, system_content_blocks = self._transform_system_message(messages)
835-
827+
836828
# Convert last user message to guarded_text if guardrailConfig is present
837-
messages = self._convert_last_user_message_to_guarded_text(messages, optional_params)
829+
messages = self._convert_consecutive_user_messages_to_guarded_text(
830+
messages, optional_params
831+
)
838832
## TRANSFORMATION ##
839833

840834
_data: CommonRequestObject = self._transform_request_helper(
@@ -888,7 +882,9 @@ def _transform_request(
888882
messages, system_content_blocks = self._transform_system_message(messages)
889883

890884
# Convert last user message to guarded_text if guardrailConfig is present
891-
messages = self._convert_last_user_message_to_guarded_text(messages, optional_params)
885+
messages = self._convert_consecutive_user_messages_to_guarded_text(
886+
messages, optional_params
887+
)
892888

893889
_data: CommonRequestObject = self._transform_request_helper(
894890
model=model,
@@ -1346,4 +1342,4 @@ def should_fake_stream(
13461342
###################################################################
13471343
if "ai21" in model:
13481344
return True
1349-
return False
1345+
return False

0 commit comments

Comments
 (0)