Skip to content

Commit 1371abf

Browse files
committed
add last message as default in gaurdrail
1 parent e168161 commit 1371abf

File tree

2 files changed

+275
-1
lines changed

2 files changed

+275
-1
lines changed

litellm/llms/bedrock/chat/converse_transformation.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,69 @@ def get_config_blocks(cls) -> dict:
102102
"performanceConfig": PerformanceConfigBlock,
103103
}
104104

105+
@staticmethod
106+
def _convert_last_user_message_to_guarded_text(
107+
messages: List[AllMessageValues], optional_params: dict
108+
) -> List[AllMessageValues]:
109+
"""
110+
Convert the last user message to guarded_text type if guardrailConfig is present
111+
and no guarded_text is already present in the last user message.
112+
"""
113+
# Check if guardrailConfig is present
114+
if "guardrailConfig" not in optional_params:
115+
return messages
116+
117+
# Find the last user message
118+
last_user_message = None
119+
last_user_message_index = -1
120+
for i in range(len(messages) - 1, -1, -1):
121+
if messages[i].get("role") == "user":
122+
last_user_message = messages[i]
123+
last_user_message_index = i
124+
break
125+
126+
if last_user_message is None:
127+
return messages
128+
129+
# Check if the last user message already has guarded_text
130+
content = last_user_message.get("content", [])
131+
if isinstance(content, list):
132+
has_guarded_text = any(
133+
isinstance(item, dict) and item.get("type") == "guarded_text"
134+
for item in content
135+
)
136+
if has_guarded_text:
137+
return messages
138+
139+
# Convert text elements to guarded_text
140+
new_content = []
141+
for item in content:
142+
if isinstance(item, dict) and item.get("type") == "text":
143+
new_item = {
144+
"type": "guarded_text",
145+
"text": item["text"]
146+
}
147+
new_content.append(new_item)
148+
else:
149+
new_content.append(item)
150+
151+
# Create a copy of messages and update the last user message
152+
messages_copy = copy.deepcopy(messages)
153+
messages_copy[last_user_message_index]["content"] = new_content
154+
return messages_copy
155+
elif isinstance(content, str):
156+
# If content is a string, convert it to guarded_text
157+
messages_copy = copy.deepcopy(messages)
158+
messages_copy[last_user_message_index]["content"] = [
159+
{
160+
"type": "guarded_text",
161+
"text": content
162+
}
163+
]
164+
return messages_copy
165+
166+
return messages
167+
105168
@classmethod
106169
def get_config(cls):
107170
return {
@@ -769,6 +832,9 @@ async def _async_transform_request(
769832
headers: Optional[dict] = None,
770833
) -> RequestObject:
771834
messages, system_content_blocks = self._transform_system_message(messages)
835+
836+
# Convert last user message to guarded_text if guardrailConfig is present
837+
messages = self._convert_last_user_message_to_guarded_text(messages, optional_params)
772838
## TRANSFORMATION ##
773839

774840
_data: CommonRequestObject = self._transform_request_helper(
@@ -821,6 +887,9 @@ def _transform_request(
821887
) -> RequestObject:
822888
messages, system_content_blocks = self._transform_system_message(messages)
823889

890+
# Convert last user message to guarded_text if guardrailConfig is present
891+
messages = self._convert_last_user_message_to_guarded_text(messages, optional_params)
892+
824893
_data: CommonRequestObject = self._transform_request_helper(
825894
model=model,
826895
system_content_blocks=system_content_blocks,
@@ -1277,4 +1346,4 @@ def should_fake_stream(
12771346
###################################################################
12781347
if "ai21" in model:
12791348
return True
1280-
return False
1349+
return False

tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,3 +1868,208 @@ def test_guarded_text_guardrail_config_preserved():
18681868
assert result["inferenceConfig"]["guardrailConfig"]["guardrailIdentifier"] == "gr-abc123"
18691869

18701870

1871+
def test_auto_convert_last_user_message_to_guarded_text():
1872+
"""Test that last user message is automatically converted to guarded_text when guardrailConfig is present."""
1873+
config = AmazonConverseConfig()
1874+
1875+
messages = [
1876+
{
1877+
"role": "user",
1878+
"content": [
1879+
{
1880+
"type": "text",
1881+
"text": "What is the main topic of this legal document?"
1882+
}
1883+
]
1884+
}
1885+
]
1886+
1887+
optional_params = {
1888+
"guardrailConfig": {
1889+
"guardrailIdentifier": "gr-abc123",
1890+
"guardrailVersion": "1"
1891+
}
1892+
}
1893+
1894+
# Test the helper method directly
1895+
converted_messages = config._convert_last_user_message_to_guarded_text(messages, optional_params)
1896+
1897+
# Verify the conversion
1898+
assert len(converted_messages) == 1
1899+
assert converted_messages[0]["role"] == "user"
1900+
assert len(converted_messages[0]["content"]) == 1
1901+
assert converted_messages[0]["content"][0]["type"] == "guarded_text"
1902+
assert converted_messages[0]["content"][0]["text"] == "What is the main topic of this legal document?"
1903+
1904+
1905+
def test_auto_convert_last_user_message_string_content():
1906+
"""Test that last user message with string content is automatically converted to guarded_text when guardrailConfig is present."""
1907+
config = AmazonConverseConfig()
1908+
1909+
messages = [
1910+
{
1911+
"role": "user",
1912+
"content": "What is the main topic of this legal document?"
1913+
}
1914+
]
1915+
1916+
optional_params = {
1917+
"guardrailConfig": {
1918+
"guardrailIdentifier": "gr-abc123",
1919+
"guardrailVersion": "1"
1920+
}
1921+
}
1922+
1923+
# Test the helper method directly
1924+
converted_messages = config._convert_last_user_message_to_guarded_text(messages, optional_params)
1925+
1926+
# Verify the conversion
1927+
assert len(converted_messages) == 1
1928+
assert converted_messages[0]["role"] == "user"
1929+
assert len(converted_messages[0]["content"]) == 1
1930+
assert converted_messages[0]["content"][0]["type"] == "guarded_text"
1931+
assert converted_messages[0]["content"][0]["text"] == "What is the main topic of this legal document?"
1932+
1933+
1934+
def test_no_conversion_when_no_guardrail_config():
1935+
"""Test that no conversion happens when guardrailConfig is not present."""
1936+
config = AmazonConverseConfig()
1937+
1938+
messages = [
1939+
{
1940+
"role": "user",
1941+
"content": [
1942+
{
1943+
"type": "text",
1944+
"text": "What is the main topic of this legal document?"
1945+
}
1946+
]
1947+
}
1948+
]
1949+
1950+
optional_params = {}
1951+
1952+
# Test the helper method directly
1953+
converted_messages = config._convert_last_user_message_to_guarded_text(messages, optional_params)
1954+
1955+
# Verify no conversion happened
1956+
assert converted_messages == messages
1957+
1958+
1959+
def test_no_conversion_when_guarded_text_already_present():
1960+
"""Test that no conversion happens when guarded_text is already present in the last user message."""
1961+
config = AmazonConverseConfig()
1962+
1963+
messages = [
1964+
{
1965+
"role": "user",
1966+
"content": [
1967+
{
1968+
"type": "guarded_text",
1969+
"text": "This is already guarded"
1970+
}
1971+
]
1972+
}
1973+
]
1974+
1975+
optional_params = {
1976+
"guardrailConfig": {
1977+
"guardrailIdentifier": "gr-abc123",
1978+
"guardrailVersion": "1"
1979+
}
1980+
}
1981+
1982+
# Test the helper method directly
1983+
converted_messages = config._convert_last_user_message_to_guarded_text(messages, optional_params)
1984+
1985+
# Verify no conversion happened
1986+
assert converted_messages == messages
1987+
1988+
1989+
def test_auto_convert_with_mixed_content():
1990+
"""Test that only text elements are converted to guarded_text, other content types are preserved."""
1991+
config = AmazonConverseConfig()
1992+
1993+
messages = [
1994+
{
1995+
"role": "user",
1996+
"content": [
1997+
{
1998+
"type": "text",
1999+
"text": "What is the main topic of this legal document?"
2000+
},
2001+
{
2002+
"type": "image_url",
2003+
"image_url": {"url": "https://example.com/image.jpg"}
2004+
}
2005+
]
2006+
}
2007+
]
2008+
2009+
optional_params = {
2010+
"guardrailConfig": {
2011+
"guardrailIdentifier": "gr-abc123",
2012+
"guardrailVersion": "1"
2013+
}
2014+
}
2015+
2016+
# Test the helper method directly
2017+
converted_messages = config._convert_last_user_message_to_guarded_text(messages, optional_params)
2018+
2019+
# Verify the conversion
2020+
assert len(converted_messages) == 1
2021+
assert converted_messages[0]["role"] == "user"
2022+
assert len(converted_messages[0]["content"]) == 2
2023+
2024+
# First element should be converted to guarded_text
2025+
assert converted_messages[0]["content"][0]["type"] == "guarded_text"
2026+
assert converted_messages[0]["content"][0]["text"] == "What is the main topic of this legal document?"
2027+
2028+
# Second element should remain unchanged
2029+
assert converted_messages[0]["content"][1]["type"] == "image_url"
2030+
assert converted_messages[0]["content"][1]["image_url"]["url"] == "https://example.com/image.jpg"
2031+
2032+
2033+
def test_auto_convert_in_full_transformation():
2034+
"""Test that the automatic conversion works in the full transformation pipeline."""
2035+
config = AmazonConverseConfig()
2036+
2037+
messages = [
2038+
{
2039+
"role": "user",
2040+
"content": [
2041+
{
2042+
"type": "text",
2043+
"text": "What is the main topic of this legal document?"
2044+
}
2045+
]
2046+
}
2047+
]
2048+
2049+
optional_params = {
2050+
"guardrailConfig": {
2051+
"guardrailIdentifier": "gr-abc123",
2052+
"guardrailVersion": "1"
2053+
}
2054+
}
2055+
2056+
# Test the full transformation
2057+
result = config._transform_request(
2058+
model="anthropic.claude-3-sonnet-20240229-v1:0",
2059+
messages=messages,
2060+
optional_params=optional_params,
2061+
litellm_params={},
2062+
headers={}
2063+
)
2064+
2065+
# Verify the transformation worked
2066+
assert "messages" in result
2067+
assert len(result["messages"]) == 1
2068+
2069+
# The message should have guardrailConverseContent
2070+
message = result["messages"][0]
2071+
assert "content" in message
2072+
assert len(message["content"]) == 1
2073+
assert "guardrailConverseContent" in message["content"][0]
2074+
assert message["content"][0]["guardrailConverseContent"]["text"] == "What is the main topic of this legal document?"
2075+

0 commit comments

Comments
 (0)