
Commit b6f6dc5

feat(vertex_ai.py): support parsing thinking content into gemini format
allows function calls with thought signatures to be sent back to gemini

Closes #13842
1 parent 51c73dc commit b6f6dc5
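Context for the change (a hedged sketch, not part of the diff): Gemini 2.5 models attach opaque thought signatures to their responses, and multi-turn function calling expects those signatures to be echoed back on the matching parts. After this commit, an assistant message carrying thinking_blocks is converted so the signature rides along. The shapes below are illustrative, not an exact wire trace:

    # OpenAI-format assistant turn, as litellm represents it (values hypothetical):
    assistant_turn = {
        "role": "assistant",
        "thinking_blocks": [
            {"thinking": "...", "signature": "opaque-sig-from-gemini"},
        ],
    }

    # Gemini-format part produced for the next request:
    gemini_part = {"text": "...", "thoughtSignature": "opaque-sig-from-gemini"}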

4 files changed: 248 additions, 9 deletions

litellm/llms/vertex_ai/gemini/transformation.py

Lines changed: 93 additions & 2 deletions
@@ -105,6 +105,64 @@ def _process_gemini_image(image_url: str, format: Optional[str] = None) -> PartType:
         raise e


+def _snake_to_camel(snake_str: str) -> str:
+    """Convert snake_case to camelCase"""
+    components = snake_str.split("_")
+    return components[0] + "".join(x.capitalize() for x in components[1:])
+
+
+def _camel_to_snake(camel_str: str) -> str:
+    """Convert camelCase to snake_case"""
+    import re
+
+    return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_str).lower()
+
+
+def _get_equivalent_key(key: str, available_keys: set) -> Optional[str]:
+    """
+    Get the equivalent key from available keys, checking both camelCase and snake_case variants
+    """
+    if key in available_keys:
+        return key
+
+    # Try camelCase version
+    camel_key = _snake_to_camel(key)
+    if camel_key in available_keys:
+        return camel_key
+
+    # Try snake_case version
+    snake_key = _camel_to_snake(key)
+    if snake_key in available_keys:
+        return snake_key
+
+    return None
+
+
+def check_if_part_exists_in_parts(
+    parts: List[PartType], part: PartType, excluded_keys: List[str] = []
+) -> bool:
+    """
+    Check if a part exists in a list of parts
+    Handles both camelCase and snake_case key variations (e.g., function_call vs functionCall)
+    """
+    keys_to_compare = set(part.keys()) - set(excluded_keys)
+    for p in parts:
+        p_keys = set(p.keys())
+        # Check if all keys in part have equivalent values in p
+        match_found = True
+        for key in keys_to_compare:
+            equivalent_key = _get_equivalent_key(key, p_keys)
+            if equivalent_key is None or p.get(equivalent_key, None) != part.get(
+                key, None
+            ):
+                match_found = False
+                break
+
+        if match_found:
+            return True
+    return False
+
+
 def _gemini_convert_messages_with_history(  # noqa: PLR0915
     messages: List[AllMessageValues],
 ) -> List[ContentType]:
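For reference, the new helpers behave as follows (a quick sketch, using only functions added in this diff):

    from litellm.llms.vertex_ai.gemini.transformation import (
        _camel_to_snake,
        _get_equivalent_key,
        _snake_to_camel,
    )

    assert _snake_to_camel("function_call") == "functionCall"
    assert _camel_to_snake("thoughtSignature") == "thought_signature"

    # _get_equivalent_key resolves a key against whichever casing a part uses:
    assert _get_equivalent_key("function_call", {"functionCall", "text"}) == "functionCall"
    assert _get_equivalent_key("thoughtSignature", {"thought_signature"}) == "thought_signature"
    assert _get_equivalent_key("missing", {"text"}) is None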
@@ -236,10 +294,33 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
             assistant_msg = ChatCompletionAssistantMessage(**msg_dict)  # type: ignore
             _message_content = assistant_msg.get("content", None)
             reasoning_content = assistant_msg.get("reasoning_content", None)
+            thinking_blocks = assistant_msg.get("thinking_blocks")
             if reasoning_content is not None:
                 assistant_content.append(
                     PartType(thought=True, text=reasoning_content)
                 )
+            if thinking_blocks is not None:
+                for block in thinking_blocks:
+                    block_thinking_str = block.get("thinking")
+                    block_signature = block.get("signature")
+                    if (
+                        block_thinking_str is not None
+                        and block_signature is not None
+                    ):
+                        try:
+                            assistant_content.append(
+                                PartType(
+                                    thoughtSignature=block_signature,
+                                    **json.loads(block_thinking_str),
+                                )
+                            )
+                        except Exception:
+                            assistant_content.append(
+                                PartType(
+                                    thoughtSignature=block_signature,
+                                    text=block_thinking_str,
+                                )
+                            )
             if _message_content is not None and isinstance(_message_content, list):
                 _parts = []
                 for element in _message_content:
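The new thinking_blocks branch has two paths, sketched below with hypothetical values: if the block's thinking string parses as JSON, its keys (e.g. a functionCall) are spread into the part next to the signature; otherwise json.loads raises and the fallback keeps the string as plain text.

    import json

    # Path 1: "thinking" holds a JSON-encoded part
    block = {
        "thinking": json.dumps(
            {"functionCall": {"name": "get_current_weather", "args": {"location": "Tokyo"}}}
        ),
        "signature": "sig-123",  # hypothetical signature value
    }
    # -> PartType(thoughtSignature="sig-123", functionCall={...})

    # Path 2: "thinking" is plain prose, so json.loads fails
    block = {"thinking": "Let me check the weather first.", "signature": "sig-456"}
    # -> PartType(thoughtSignature="sig-456", text="Let me check the weather first.")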
@@ -262,9 +343,17 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
                 assistant_msg.get("tool_calls", []) is not None
                 or assistant_msg.get("function_call") is not None
             ):  # support assistant tool invoke conversion
-                assistant_content.extend(
-                    convert_to_gemini_tool_call_invoke(assistant_msg)
+                gemini_tool_call_parts = convert_to_gemini_tool_call_invoke(
+                    assistant_msg
                 )
+                ## check if gemini_tool_call already exists in assistant_content
+                for gemini_tool_call_part in gemini_tool_call_parts:
+                    if not check_if_part_exists_in_parts(
+                        assistant_content,
+                        gemini_tool_call_part,
+                        excluded_keys=["thoughtSignature"],
+                    ):
+                        assistant_content.append(gemini_tool_call_part)
                 last_message_with_tool_calls = assistant_msg

             msg_i += 1
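The excluded_keys=["thoughtSignature"] argument is what makes the dedup work: the thinking-block loop above may already have appended the signed copy of a function call, and convert_to_gemini_tool_call_invoke then yields the same call without a signature. A minimal sketch with hypothetical values:

    assistant_content = [
        {
            "functionCall": {"name": "get_current_weather", "args": {"location": "Paris"}},
            "thoughtSignature": "sig-123",  # appended earlier from a thinking block
        }
    ]
    # Unsigned duplicate produced by convert_to_gemini_tool_call_invoke:
    duplicate = {"functionCall": {"name": "get_current_weather", "args": {"location": "Paris"}}}

    # True, because the signature is ignored in the comparison,
    # so the unsigned copy is skipped and the signed part is kept:
    assert check_if_part_exists_in_parts(
        assistant_content, duplicate, excluded_keys=["thoughtSignature"]
    )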
@@ -476,6 +565,7 @@ async def async_transform_request_body(
         optional_params=optional_params,
     )

+
 def _default_user_message_when_system_message_passed() -> ChatCompletionUserMessage:
     """
     Returns a default user message when a "system" message is passed in gemini fails.
@@ -484,6 +574,7 @@ def _default_user_message_when_system_message_passed() -> ChatCompletionUserMessage:
     """
     return ChatCompletionUserMessage(content=".", role="user")

+
 def _transform_system_message(
     supports_system_message: bool, messages: List[AllMessageValues]
 ) -> Tuple[Optional[SystemInstructions], List[AllMessageValues]]:

litellm/types/llms/vertex_ai.py

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ class PartType(TypedDict, total=False):
     function_call: FunctionCall
     function_response: FunctionResponse
     thought: bool
+    thoughtSignature: str


 class HttpxFunctionCall(TypedDict):
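Because PartType is declared with total=False, every key stays optional and a signed part type-checks directly; a minimal sketch:

    from litellm.types.llms.vertex_ai import PartType

    signed_part = PartType(
        thoughtSignature="opaque-signature",  # hypothetical value
        text="reasoning that should round-trip back to Gemini",
    )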

tests/llm_translation/test_gemini.py

Lines changed: 79 additions & 7 deletions
@@ -661,10 +661,33 @@ def test_system_message_with_no_user_message():
     assert response.choices[0].message.content is not None


+def get_current_weather(location, unit="fahrenheit"):
+    """Get the current weather in a given location"""
+    if "tokyo" in location.lower():
+        return json.dumps({"location": "Tokyo", "temperature": "10", "unit": "celsius"})
+    elif "san francisco" in location.lower():
+        return json.dumps(
+            {"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}
+        )
+    elif "paris" in location.lower():
+        return json.dumps({"location": "Paris", "temperature": "22", "unit": "celsius"})
+    else:
+        return json.dumps({"location": location, "temperature": "unknown"})
+
+
 def test_gemini_with_thinking():
     from litellm import completion

     litellm._turn_on_debug()
+    litellm.modify_params = True
+    model = "gemini/gemini-2.5-flash"
+    messages = [
+        {
+            "role": "user",
+            "content": "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses",
+        }
+    ]
+
     tools = [
         {
             "type": "function",
@@ -676,20 +699,69 @@ def test_gemini_with_thinking():
                     "properties": {
                         "location": {
                             "type": "string",
-                            "description": "The city and state, e.g. San Francisco, CA",
+                            "description": "The city and state",
+                        },
+                        "unit": {
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
                         },
-                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                     },
                     "required": ["location"],
                 },
             },
         }
     ]
-    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
-
-    result = completion(
-        model="gemini/gemini-2.5-flash",
+    response = litellm.completion(
+        model=model,
         messages=messages,
         tools=tools,
+        tool_choice="auto",  # auto is default, but we'll be explicit
+        reasoning_effort="low",
     )
-    print(f"result: {result}")
+    print("Response\n", response)
+    response_message = response.choices[0].message
+    tool_calls = response_message.tool_calls
+
+    print("Expecting there to be 3 tool calls")
+    assert len(tool_calls) > 0  # this has to call the function for SF, Tokyo and paris
+
+    # Step 2: check if the model wanted to call a function
+    print(f"tool_calls: {tool_calls}")
+    if tool_calls:
+        # Step 3: call the function
+        # Note: the JSON response may not always be valid; be sure to handle errors
+        available_functions = {
+            "get_current_weather": get_current_weather,
+        }  # only one function in this example, but you can have multiple
+        messages.append(response_message)  # extend conversation with assistant's reply
+        print("Response message\n", response_message)
+        # Step 4: send the info for each function call and function response to the model
+        for tool_call in tool_calls:
+            function_name = tool_call.function.name
+            if function_name not in available_functions:
+                # the model called a function that does not exist in available_functions - don't try calling anything
+                return
+            function_to_call = available_functions[function_name]
+            function_args = json.loads(tool_call.function.arguments)
+            function_response = function_to_call(
+                location=function_args.get("location"),
+                unit=function_args.get("unit"),
+            )
+            messages.append(
+                {
+                    "tool_call_id": tool_call.id,
+                    "role": "tool",
+                    "name": function_name,
+                    "content": function_response,
+                }
+            )  # extend conversation with function response
+        print(f"messages: {messages}")
+        second_response = litellm.completion(
+            model=model,
+            messages=messages,
+            seed=22,
+            reasoning_effort="low",
+            tools=tools,
+            drop_params=True,
+        )  # get a new response from the model where it can see the function response
+        print("second response\n", second_response)
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+from litellm.llms.vertex_ai.gemini.transformation import check_if_part_exists_in_parts
+
+
+def test_check_if_part_exists_in_parts():
+    parts = [
+        {"text": "Hello", "thought": True},
+        {"text": "World", "thought": False},
+    ]
+    part = {"text": "Hello", "thought": True}
+    new_part = {"text": "Hello World", "thought": True}
+    assert check_if_part_exists_in_parts(parts, part)
+    assert not check_if_part_exists_in_parts(parts, new_part, ["thought"])
+    assert check_if_part_exists_in_parts(parts, new_part, ["text"])
+
+
+def test_check_if_part_exists_in_parts_camel_case_snake_case():
+    """Test that function handles both camelCase and snake_case key variations"""
+    # Test snake_case to camelCase matching
+    parts_with_snake_case = [
+        {
+            "function_call": {
+                "name": "get_current_weather",
+                "args": {"location": "San Francisco, CA"},
+            }
+        },
+        {"text": "Some other content"},
+    ]
+
+    part_with_camel_case = {
+        "functionCall": {
+            "name": "get_current_weather",
+            "args": {"location": "San Francisco, CA"},
+        }
+    }
+
+    # Should find match between function_call and functionCall
+    assert check_if_part_exists_in_parts(parts_with_snake_case, part_with_camel_case)
+
+    # Test camelCase to snake_case matching
+    parts_with_camel_case = [
+        {"functionCall": {"name": "calculate_sum", "args": {"a": 1, "b": 2}}}
+    ]
+
+    part_with_snake_case = {
+        "function_call": {"name": "calculate_sum", "args": {"a": 1, "b": 2}}
+    }
+
+    # Should find match between functionCall and function_call
+    assert check_if_part_exists_in_parts(parts_with_camel_case, part_with_snake_case)
+
+    # Test no match when values differ
+    part_with_different_values = {
+        "function_call": {"name": "different_function", "args": {"x": 5}}
+    }
+
+    assert not check_if_part_exists_in_parts(
+        parts_with_snake_case, part_with_different_values
+    )
+
+    # Test multiple keys with mixed casing
+    parts_mixed = [
+        {
+            "function_call": {"name": "test"},
+            "thoughtSignature": "reasoning",
+            "text": "content",
+        }
+    ]
+
+    part_mixed_casing = {
+        "functionCall": {"name": "test"},
+        "thought_signature": "reasoning",
+        "text": "content",
+    }
+
+    assert check_if_part_exists_in_parts(parts_mixed, part_mixed_casing)
