
Commit 81d5498

fix(genai, vertexai): thought signature handling (#1354)
Improves documentation with updated docstrings, expands test coverage, and fixes thought signature indexing so each signature is attributed to the correct tool call. Also removes the now-unnecessary check for "gemini" in request model names when filtering retry parameters.
1 parent c0ce7c3 commit 81d5498

File tree

8 files changed: +716 additions, -36 deletions

libs/genai/langchain_google_genai/_common.py

Lines changed: 26 additions & 5 deletions
```diff
@@ -80,6 +80,9 @@ class _BaseGoogleGenerativeAI(BaseModel):
     If unset, will use the model's default value, which varies by model.
 
     See [docs](https://ai.google.dev/gemini-api/docs/models) for model-specific limits.
+
+    To constrain the number of thinking tokens to use when generating a response, see
+    the `thinking_budget` parameter.
     """
 
     n: int = 1
```
```diff
@@ -157,20 +160,38 @@ class _BaseGoogleGenerativeAI(BaseModel):
     )
     """A list of modalities of the response"""
 
-    thinking_budget: int | None = Field(
+    media_resolution: MediaResolution | None = Field(
         default=None,
     )
-    """Indicates the thinking budget in tokens."""
+    """Media resolution for the input media."""
 
-    media_resolution: MediaResolution | None = Field(
+    thinking_budget: int | None = Field(
         default=None,
     )
-    """Media resolution for the input media."""
+    """Indicates the thinking budget in tokens.
+
+    Used to disable thinking for supported models (when set to `0`) or to constrain
+    the number of tokens used for thinking.
+
+    Dynamic thinking (allowing the model to decide how many tokens to use) is
+    enabled when set to `-1`.
+
+    More information, including per-model limits, can be found in the
+    [Gemini API docs](https://ai.google.dev/gemini-api/docs/thinking#set-budget).
+    """
 
     include_thoughts: bool | None = Field(
         default=None,
     )
-    """Indicates whether to include thoughts in the response."""
+    """Indicates whether to include thoughts in the response.
+
+    !!! note
+
+        This parameter is only applicable for models that support thinking.
+
+        This does not disable thinking; to disable thinking, set `thinking_budget` to
+        `0` for supported models. See the `thinking_budget` parameter for more details.
+    """
 
     safety_settings: dict[HarmCategory, HarmBlockThreshold] | None = None
     """Default safety settings to use for all generations.
```

libs/genai/langchain_google_genai/chat_models.py

Lines changed: 30 additions & 15 deletions
```diff
@@ -222,13 +222,9 @@ def _chat_with_retry(**kwargs: Any) -> Any:
         except Exception:
             raise
 
-    params = (
-        {k: v for k, v in kwargs.items() if k in _allowed_params_prediction_service}
-        if (request := kwargs.get("request"))
-        and hasattr(request, "model")
-        and "gemini" in request.model
-        else kwargs
-    )
+    params = {
+        k: v for k, v in kwargs.items() if k in _allowed_params_prediction_service
+    }
     return _chat_with_retry(**params)
 
 
```
```diff
@@ -271,13 +267,9 @@ async def _achat_with_retry(**kwargs: Any) -> Any:
         except Exception:
             raise
 
-    params = (
-        {k: v for k, v in kwargs.items() if k in _allowed_params_prediction_service}
-        if (request := kwargs.get("request"))
-        and hasattr(request, "model")
-        and "gemini" in request.model
-        else kwargs
-    )
+    params = {
+        k: v for k, v in kwargs.items() if k in _allowed_params_prediction_service
+    }
     return await _achat_with_retry(**params)
 
 
```
```diff
@@ -654,7 +646,10 @@ def _parse_chat_history(
             if sig_str and isinstance(sig_str, str):
                 # Decode base64-encoded signature back to bytes
                 sig_bytes = base64.b64decode(sig_str)
-                function_call_sigs[idx] = sig_bytes
+                if "index" in item:
+                    function_call_sigs[item["index"]] = sig_bytes
+                else:
+                    function_call_sigs[idx] = sig_bytes
 
         for tool_call_idx, tool_call in enumerate(message.tool_calls):
             function_call = FunctionCall(
```
```diff
@@ -911,6 +906,7 @@ def _parse_response_candidate(
                 sig_block = {
                     "type": "function_call_signature",
                     "signature": thought_sig,
+                    "index": len(tool_calls) - 1,
                 }
                 function_call_signatures.append(sig_block)
 
```
````diff
@@ -1651,6 +1647,25 @@ class Joke(BaseModel):
         success rates and mitigation strategies like prompt...
         ```
 
+    Thinking:
+        For thinking models, you have the option to adjust the number of internal
+        thinking tokens used (`thinking_budget`) or to disable thinking altogether.
+        Note that not all models allow disabling thinking.
+
+        See the [Gemini API docs](https://ai.google.dev/gemini-api/docs/thinking) for
+        more details on thinking models.
+
+        To see a thinking model's thoughts, set `include_thoughts=True` to have the
+        model's reasoning summaries included in the response.
+
+        ```python
+        llm = ChatGoogleGenerativeAI(
+            model="gemini-2.5-flash",
+            include_thoughts=True,
+        )
+        ai_msg = llm.invoke("How many 'r's are in the word 'strawberry'?")
+        ```
+
     Token usage:
         ```python
         ai_msg = llm.invoke(messages)
````
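The docstring's new example stops at `invoke`. As a hedged follow-on, one way to pull the reasoning summaries out of the structured content; the `"reasoning"` block type is an assumption about the v1 output format, not something this diff confirms, so adjust it to whatever your installed version emits:

```python
ai_msg = llm.invoke("How many 'r's are in the word 'strawberry'?")

# With include_thoughts=True the content is a list of typed blocks.
# The "reasoning" type below is assumed, not confirmed by this diff.
if isinstance(ai_msg.content, list):
    thoughts = [
        block
        for block in ai_msg.content
        if isinstance(block, dict) and block.get("type") == "reasoning"
    ]
    for block in thoughts:
        print(block)
```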

libs/genai/tests/integration_tests/test_chat_models.py

Lines changed: 77 additions & 1 deletion
```diff
@@ -4,6 +4,7 @@
 import json
 from collections.abc import Generator, Sequence
 from typing import Literal, cast
+from unittest.mock import patch
 
 import pytest
 from langchain_core.messages import (
```
```diff
@@ -459,7 +460,82 @@ def analyze_weather(location: str, date: str) -> dict:
 
     # Test we can pass the result back in (with signature)
     next_message = {"role": "user", "content": "Thanks!"}
-    _ = llm_with_tools.invoke([input_message, result, next_message])
+    follow_up_result = llm_with_tools.invoke([input_message, result, next_message])
+
+    # Verify the follow-up call succeeded and returned a valid response
+    assert isinstance(follow_up_result, AIMessage)
+    assert follow_up_result.content is not None
+
+    # If there were signatures in the original response, verify they were properly
+    # handled in the follow-up (no errors should occur)
+    if signature_blocks:
+        # The fact that we got a successful response means signatures were converted
+        # correctly
+        # Additional verification that response metadata is preserved
+        assert "model_provider" in follow_up_result.response_metadata
+        assert (
+            follow_up_result.response_metadata["model_provider"] == "google_genai"
+        )
+
+
+@pytest.mark.flaky(retries=3, delay=1)
+def test_thought_signature_round_trip() -> None:
+    """Test thought signatures are properly preserved in round-trip conversations."""
+
+    @tool
+    def simple_tool(query: str) -> str:
+        """A simple tool for testing."""
+        return f"Response to: {query}"
+
+    llm = ChatGoogleGenerativeAI(
+        model=_THINKING_MODEL, include_thoughts=True, output_version="v1"
+    )
+    llm_with_tools = llm.bind_tools([simple_tool])
+
+    # First call with function calling to generate signatures
+    first_message = {
+        "role": "user",
+        "content": "Use the tool to help answer: What is 2+2?",
+    }
+
+    # Patch the conversion function to verify it's called with signatures
+    with patch(
+        "langchain_google_genai.chat_models._convert_from_v1_to_generativelanguage_v1beta"
+    ) as mock_convert:
+        # Set up the mock to call the real function but also track calls
+        from langchain_google_genai._compat import (
+            _convert_from_v1_to_generativelanguage_v1beta as real_convert,
+        )
+
+        mock_convert.side_effect = real_convert
+
+        first_result = llm_with_tools.invoke([first_message])
+
+        # Verify we got a response with structured content (contains signatures)
+        assert isinstance(first_result, AIMessage)
+        assert isinstance(first_result.content, list)
+
+        # Second call - this should trigger signature conversion
+        second_message = {"role": "user", "content": "Thanks!"}
+        second_result = llm_with_tools.invoke(
+            [first_message, first_result, second_message]
+        )
+
+        # Verify the conversion function was called when processing the first_result
+        # (it should be called once for the first_result message)
+        assert mock_convert.call_count >= 1
+
+        # Find the call that processed our AI message with signatures
+        ai_message_calls = [
+            call
+            for call in mock_convert.call_args_list
+            if call[0][1] == "google_genai"  # model_provider argument
+        ]
+        assert len(ai_message_calls) >= 1
+
+        # Verify the second call succeeded (signatures were properly converted)
+        assert isinstance(second_result, AIMessage)
+        assert second_result.content is not None
 
 
 def test_chat_google_genai_invoke_thinking_disabled() -> None:
```
