Skip to content

Commit e024080

Browse files
authored
feat(genai): store thought signatures in additional_kwargs (#1358)
1 parent 81d5498 commit e024080

File tree

2 files changed

+47
-61
lines changed

2 files changed

+47
-61
lines changed

libs/genai/langchain_google_genai/chat_models.py

Lines changed: 34 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,19 @@
134134
_FunctionDeclarationType = FunctionDeclaration | dict[str, Any] | Callable[..., Any]
135135

136136

137+
_FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY = (
138+
"__gemini_function_call_thought_signatures__"
139+
)
140+
141+
142+
def _bytes_to_base64(data: bytes) -> str:
143+
return base64.b64encode(data).decode("utf-8")
144+
145+
146+
def _base64_to_bytes(input_str: str) -> bytes:
147+
return base64.b64decode(input_str.encode("utf-8"))
148+
149+
137150
class ChatGoogleGenerativeAIError(GoogleGenerativeAIError):
138151
"""Custom exception class for errors associated with the `Google GenAI` API.
139152
@@ -387,11 +400,6 @@ def _convert_to_parts(
387400
metadata = VideoMetadata(part["video_metadata"])
388401
media_part.video_metadata = metadata
389402
parts.append(media_part)
390-
elif part["type"] == "function_call_signature":
391-
# Signature for function_call Part - skip it here as it should be
392-
# attached to the actual function_call Part
393-
# This is handled separately in the history parsing logic
394-
pass
395403
elif part["type"] == "thinking":
396404
# Pre-existing thinking block format that we continue to store as
397405
thought_sig = None
@@ -634,23 +642,9 @@ def _parse_chat_history(
634642
role = "model"
635643
if message.tool_calls:
636644
ai_message_parts = []
637-
# Extract any function_call_signature blocks from content
638-
function_call_sigs: dict[int, bytes] = {}
639-
if isinstance(message.content, list):
640-
for idx, item in enumerate(message.content):
641-
if (
642-
isinstance(item, dict)
643-
and item.get("type") == "function_call_signature"
644-
):
645-
sig_str = item.get("signature", "")
646-
if sig_str and isinstance(sig_str, str):
647-
# Decode base64-encoded signature back to bytes
648-
sig_bytes = base64.b64decode(sig_str)
649-
if "index" in item:
650-
function_call_sigs[item["index"]] = sig_bytes
651-
else:
652-
function_call_sigs[idx] = sig_bytes
653-
645+
function_call_sigs: dict[Any, str] = message.additional_kwargs.get(
646+
_FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY, {}
647+
)
654648
for tool_call_idx, tool_call in enumerate(message.tool_calls):
655649
function_call = FunctionCall(
656650
{
@@ -659,11 +653,13 @@ def _parse_chat_history(
659653
}
660654
)
661655
# Check if there's a signature for this function call
662-
# (We use the index to match signature to function call)
663-
sig = function_call_sigs.get(tool_call_idx)
656+
sig = function_call_sigs.get(tool_call.get("id"))
664657
if sig:
665658
ai_message_parts.append(
666-
Part(function_call=function_call, thought_signature=sig)
659+
Part(
660+
function_call=function_call,
661+
thought_signature=_base64_to_bytes(sig),
662+
)
667663
)
668664
else:
669665
ai_message_parts.append(Part(function_call=function_call))
@@ -742,9 +738,6 @@ def _parse_response_candidate(
742738
tool_calls = []
743739
invalid_tool_calls = []
744740
tool_call_chunks = []
745-
# Track function call signatures separately to handle them conditionally
746-
function_call_signatures: list[dict] = []
747-
748741
for part in response_candidate.content.parts:
749742
text: str | None = None
750743
try:
@@ -867,12 +860,13 @@ def _parse_response_candidate(
867860
function_call["arguments"] = json.dumps(corrected_args)
868861
additional_kwargs["function_call"] = function_call
869862

863+
tool_call_id = function_call.get("id", str(uuid.uuid4()))
870864
if streaming:
871865
tool_call_chunks.append(
872866
tool_call_chunk(
873867
name=function_call.get("name"),
874868
args=function_call.get("arguments"),
875-
id=function_call.get("id", str(uuid.uuid4())),
869+
id=tool_call_id,
876870
index=function_call.get("index"), # type: ignore
877871
)
878872
)
@@ -887,7 +881,7 @@ def _parse_response_candidate(
887881
invalid_tool_call(
888882
name=function_call.get("name"),
889883
args=function_call.get("arguments"),
890-
id=function_call.get("id", str(uuid.uuid4())),
884+
id=tool_call_id,
891885
error=str(e),
892886
)
893887
)
@@ -896,26 +890,21 @@ def _parse_response_candidate(
896890
tool_call(
897891
name=tool_call_dict["name"],
898892
args=tool_call_dict["args"],
899-
id=tool_call_dict.get("id", str(uuid.uuid4())),
893+
id=tool_call_id,
900894
)
901895
)
902896

903897
# If this function_call Part has a signature, store it (base64-encoded) in
# additional_kwargs, keyed by the tool call ID
904-
# We'll add it to content only if there's other content present
905898
if thought_sig:
906-
sig_block = {
907-
"type": "function_call_signature",
908-
"signature": thought_sig,
909-
"index": len(tool_calls) - 1,
910-
}
911-
function_call_signatures.append(sig_block)
912-
913-
# Add function call signatures to content only if there's already other content
914-
# This preserves backward compatibility where content is "" for
915-
# function-only responses
916-
if function_call_signatures and content is not None:
917-
for sig_block in function_call_signatures:
918-
content = _append_to_content(content, sig_block)
899+
if _FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY not in additional_kwargs:
900+
additional_kwargs[_FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY] = {}
901+
additional_kwargs[_FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY][
902+
tool_call_id
903+
] = (
904+
_bytes_to_base64(thought_sig)
905+
if isinstance(thought_sig, bytes)
906+
else thought_sig
907+
)
919908

920909
if content is None:
921910
content = ""

libs/genai/tests/unit_tests/test_chat_models.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2572,17 +2572,11 @@ def test_parse_response_candidate_adds_index_to_signature() -> None:
25722572
candidate = Candidate(content=Content(parts=[part1, part2]))
25732573

25742574
msg = _parse_response_candidate(candidate)
2575-
2576-
# Check if signature block is present and has index
2577-
found = False
2578-
for block in msg.content:
2579-
if isinstance(block, dict) and block.get("type") == "function_call_signature":
2580-
assert block.get("signature") == base64.b64encode(sig).decode("ascii")
2581-
assert "index" in block
2582-
assert block["index"] == 0
2583-
found = True
2584-
2585-
assert found, "Signature block not found"
2575+
function_call_map = msg.additional_kwargs[
2576+
"__gemini_function_call_thought_signatures__"
2577+
]
2578+
tool_call_id = msg.tool_calls[0]["id"]
2579+
assert function_call_map[tool_call_id] == base64.b64encode(sig).decode("ascii")
25862580

25872581

25882582
def test_parse_chat_history_uses_index_for_signature() -> None:
@@ -2592,14 +2586,17 @@ def test_parse_chat_history_uses_index_for_signature() -> None:
25922586

25932587
# Content holds only a thinking block; the signature is not a content block
25942588
# The signature is supplied via additional_kwargs, keyed by tool call ID "call_1"
2595-
content = [
2596-
{"type": "thinking", "thinking": "I should use the tool."},
2597-
{"type": "function_call_signature", "signature": sig_b64, "index": 0},
2598-
]
2589+
content = [{"type": "thinking", "thinking": "I should use the tool."}]
25992590

26002591
tool_calls = [{"name": "my_tool", "args": {"param": "value"}, "id": "call_1"}]
26012592

2602-
message = AIMessage(content=content, tool_calls=tool_calls) # type: ignore[arg-type]
2593+
message = AIMessage(
2594+
content=content, # type: ignore[arg-type]
2595+
tool_calls=tool_calls,
2596+
additional_kwargs={
2597+
"__gemini_function_call_thought_signatures__": {"call_1": sig_b64}
2598+
},
2599+
)
26032600

26042601
# Parse the history
26052602
_, formatted_messages = _parse_chat_history([message])

0 commit comments

Comments
 (0)