|
1 | 1 | import base64 |
2 | 2 | import io |
3 | 3 | import json |
| 4 | +import logging |
4 | 5 | import time |
5 | 6 | from collections.abc import Generator, Sequence |
6 | 7 | from typing import Optional, Union, cast |
|
52 | 53 | # For more information about the models, please refer to https://ai.google.dev/gemini-api/docs/thinking |
53 | 54 | DEFAULT_NO_THINKING_MODELS = ["gemini-2.5-flash-lite"] |
54 | 55 |
|
| 56 | +logger = logging.getLogger(__name__) |
| 57 | + |
55 | 58 |
|
56 | 59 | class VertexAiLargeLanguageModel(LargeLanguageModel): |
57 | 60 | def _invoke( |
@@ -634,6 +637,12 @@ def _handle_generate_response( |
634 | 637 | ), |
635 | 638 | ) |
636 | 639 | assistant_prompt_message.tool_calls.append(tool_call) |
| 640 | + # Capture thought_signature if the SDK surfaced it on the same part |
| 641 | + sig = self._extract_thought_signature(part) |
| 642 | + if sig: |
| 643 | + if not hasattr(self, "_last_function_call_signatures"): |
| 644 | + self._last_function_call_signatures = [] |
| 645 | + self._last_function_call_signatures.append(sig) |
637 | 646 | # Check for text |
638 | 647 | elif hasattr(part, 'text') and part.text: |
639 | 648 | if part.thought is True and not is_thinking: |
@@ -698,6 +707,12 @@ def _handle_generate_stream_response( |
698 | 707 | ), |
699 | 708 | ) |
700 | 709 | ) |
| 710 | + # Capture thought_signature if present on the streaming part |
| 711 | + sig = self._extract_thought_signature(part) |
| 712 | + if sig: |
| 713 | + if not hasattr(self, "_last_function_call_signatures"): |
| 714 | + self._last_function_call_signatures = [] |
| 715 | + self._last_function_call_signatures.append(sig) |
701 | 716 | # Check for text |
702 | 717 | elif hasattr(part, 'text') and part.text: |
703 | 718 | if part.thought is True and not is_thinking: |
@@ -774,6 +789,32 @@ def _handle_generate_stream_response( |
774 | 789 | ), |
775 | 790 | ) |
776 | 791 |
|
def _extract_thought_signature(self, part) -> Optional[str]:
    """
    Best-effort extraction of a thought signature string from a response part.

    Lookup order:
      1. the ``thought_signature`` / ``thoughtSignature`` attribute of ``part``;
      2. the ``thoughtSignature`` / ``thought_signature`` key of a dict view of
         ``part`` (``part`` itself when it is a dict, else ``part.to_dict()``
         when available, else ``part.__dict__``);
      3. the same key, in either spelling, inside the ``google`` entry nested
         under ``extraContent`` / ``extra_content`` of that dict view.

    Only non-empty ``str`` values are accepted; anything else is ignored.
    NOTE(review): if the SDK in use surfaces signatures as ``bytes`` they are
    dropped here -- confirm whether they should be encoded instead.

    :param part: response part (SDK object or plain dict)
    :return: the signature string, or None when no usable signature is found
    """
    # Direct attributes first.
    sig = getattr(part, "thought_signature", None) or getattr(part, "thoughtSignature", None)
    if isinstance(sig, str) and sig:
        return sig

    # Build a dict view of the part. Only the conversion can raise, so the
    # try block is limited to it; extraction must never break response
    # handling, hence the deliberately broad catch.
    if isinstance(part, dict):
        data = part
    else:
        try:
            data = part.to_dict() if hasattr(part, "to_dict") else getattr(part, "__dict__", None)
        except Exception:
            logger.warning(
                "Failed to convert response part to dict while extracting thought_signature",
                exc_info=True,
            )
            return None
    if not isinstance(data, dict):
        return None

    sig = data.get("thoughtSignature") or data.get("thought_signature")
    if not sig:
        # Fall back to the nested extra-content payload, accepting both
        # spellings of the key just like the top-level lookups do.
        extra = data.get("extraContent") or data.get("extra_content")
        if isinstance(extra, dict):
            google = extra.get("google")
            if isinstance(google, dict):
                sig = google.get("thought_signature") or google.get("thoughtSignature")

    if isinstance(sig, str) and sig:
        return sig
    return None
| 817 | + |
777 | 818 | def _convert_one_message_to_text(self, message: PromptMessage) -> str: |
778 | 819 | """ |
779 | 820 | Convert a single message to a string. |
@@ -830,15 +871,20 @@ def _format_message_to_genai_content(self, message: PromptMessage) -> dict: |
830 | 871 | return {"role": "user", "parts": parts} |
831 | 872 | elif isinstance(message, AssistantPromptMessage): |
832 | 873 | if message.tool_calls: |
833 | | - parts = [ |
834 | | - { |
| 874 | + parts = [] |
| 875 | + for tool_call in message.tool_calls: |
| 876 | + part_dict = { |
835 | 877 | "function_call": { |
836 | 878 | "name": tool_call.function.name, |
837 | 879 | "args": json.loads(tool_call.function.arguments), |
838 | 880 | } |
839 | 881 | } |
840 | | - for tool_call in message.tool_calls |
841 | | - ] |
| 882 | + # Attach thought_signature if we captured one from the previous model output |
| 883 | + if hasattr(self, "_last_function_call_signatures") and self._last_function_call_signatures: |
| 884 | + sig = self._last_function_call_signatures.pop(0) |
| 885 | + if sig: |
| 886 | + part_dict["thought_signature"] = sig |
| 887 | + parts.append(part_dict) |
842 | 888 | else: |
843 | 889 | parts = [{"text": message.content}] |
844 | 890 | return {"role": "model", "parts": parts} |
|
0 commit comments