diff --git a/libs/genai/langchain_google_genai/chat_models.py b/libs/genai/langchain_google_genai/chat_models.py index 803aafebb..7d3704011 100644 --- a/libs/genai/langchain_google_genai/chat_models.py +++ b/libs/genai/langchain_google_genai/chat_models.py @@ -761,6 +761,8 @@ def _parse_response_candidate( except (AttributeError, TypeError): thought_sig = None + has_function_call = hasattr(part, "function_call") and part.function_call + if hasattr(part, "thought") and part.thought: thinking_message = { "type": "thinking", @@ -770,7 +772,7 @@ def _parse_response_candidate( if thought_sig: thinking_message["signature"] = thought_sig content = _append_to_content(content, thinking_message) - elif text is not None and text: + elif text is not None and text.strip() and not has_function_call: # Check if this text Part has a signature attached if thought_sig: # Text with signature needs structured block to preserve signature @@ -896,18 +898,33 @@ def _parse_response_candidate( # If this function_call Part has a signature, track it separately if thought_sig: - if _FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY not in additional_kwargs: - additional_kwargs[_FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY] = {} - additional_kwargs[_FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY][ - tool_call_id - ] = ( - _bytes_to_base64(thought_sig) - if isinstance(thought_sig, bytes) - else thought_sig - ) + sig_block = { + "type": "function_call_signature", + "signature": thought_sig, + } + function_call_signatures.append(sig_block) + + # Add function call signatures to content only if there's already other content + # This preserves backward compatibility where content is "" for + # function-only responses + if function_call_signatures and content is not None: + for sig_block in function_call_signatures: + content = _append_to_content(content, sig_block) if content is None: content = "" + + if ( + hasattr(response_candidate, "logprobs_result") + and response_candidate.logprobs_result + ): + # Note: logprobs is flaky, sometimes available, sometimes not + # https://discuss.ai.google.dev/t/logprobs-is-not-enabled-for-gemini-models/107989/15 + response_metadata["logprobs"] = MessageToDict( + response_candidate.logprobs_result._pb, + preserving_proto_field_name=True, + ) + if isinstance(content, list) and any( isinstance(item, dict) and "executable_code" in item for item in content ): @@ -1764,6 +1781,9 @@ class Joke(BaseModel): stop: list[str] | None = None """Stop sequences for the model.""" + logprobs: int | None = None + """The number of logprobs to return.""" + streaming: bool | None = None """Whether to stream responses from the model.""" @@ -1931,6 +1951,7 @@ def _identifying_params(self) -> dict[str, Any]: "media_resolution": self.media_resolution, "thinking_budget": self.thinking_budget, "include_thoughts": self.include_thoughts, + "logprobs": self.logprobs, } def invoke( @@ -2024,6 +2045,10 @@ def _prepare_params( }.items() if v is not None } + logprobs = getattr(self, "logprobs", None) + if logprobs: + gen_config["logprobs"] = logprobs + gen_config["response_logprobs"] = True if generation_config: gen_config = {**gen_config, **generation_config} diff --git a/libs/genai/tests/unit_tests/test_chat_models.py b/libs/genai/tests/unit_tests/test_chat_models.py index 3f15ff750..4d748db2c 100644 --- a/libs/genai/tests/unit_tests/test_chat_models.py +++ b/libs/genai/tests/unit_tests/test_chat_models.py @@ -144,24 +144,98 @@ def test_initialization_inside_threadpool() -> None: ).result() -def test_client_transport() -> None: +def test_logprobs() -> None: + """Test that logprobs parameter is set correctly and is in the response.""" + llm = ChatGoogleGenerativeAI( + model=MODEL_NAME, + google_api_key=SecretStr("secret-api-key"), + logprobs=10, + ) + assert llm.logprobs == 10 + + # Create proper mock response with logprobs_result + raw_response = { + "candidates": [ + { + "content": {"parts": [{"text": "Test response"}]}, + "finish_reason": 1, + "safety_ratings": [], + "logprobs_result": { + "top_candidates": [ + { + "candidates": [ + {"token": "Test", "log_probability": -0.1}, + ] + } + ] + }, + } + ], + "prompt_feedback": {"block_reason": 0, "safety_ratings": []}, + "usage_metadata": { + "prompt_token_count": 5, + "candidates_token_count": 2, + "total_token_count": 7, + }, + } + response = GenerateContentResponse(raw_response) + + with patch( + "langchain_google_genai.chat_models._chat_with_retry" + ) as mock_chat_with_retry: + mock_chat_with_retry.return_value = response + llm = ChatGoogleGenerativeAI( + model=MODEL_NAME, + google_api_key="test-key", + logprobs=1, + ) + result = llm.invoke("test") + assert "logprobs" in result.response_metadata + assert result.response_metadata["logprobs"] == { + "top_candidates": [ + { + "candidates": [ + {"token": "Test", "log_probability": -0.1}, + ] + } + ] + } + + mock_chat_with_retry.assert_called_once() + request = mock_chat_with_retry.call_args.kwargs["request"] + assert request.generation_config.logprobs == 1 + assert request.generation_config.response_logprobs is True + + +@pytest.mark.enable_socket +@patch("langchain_google_genai._genai_extension.v1betaGenerativeServiceAsyncClient") +@patch("langchain_google_genai._genai_extension.v1betaGenerativeServiceClient") +def test_client_transport(mock_client: Mock, mock_async_client: Mock) -> None: """Test client transport configuration.""" - model = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key=FAKE_API_KEY) + mock_client.return_value.transport = Mock() + mock_client.return_value.transport.kind = "grpc" + model = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key="fake-key") assert model.client.transport.kind == "grpc" + mock_client.return_value.transport.kind = "rest" model = ChatGoogleGenerativeAI( model=MODEL_NAME, google_api_key="fake-key", transport="rest" ) assert model.client.transport.kind == "rest" async def check_async_client() -> None: - model = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key=FAKE_API_KEY) + mock_async_client.return_value.transport = Mock() + mock_async_client.return_value.transport.kind = "grpc_asyncio" + model = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key="fake-key") + _ = model.async_client assert model.async_client.transport.kind == "grpc_asyncio" # Test auto conversion of transport to "grpc_asyncio" from "rest" model = ChatGoogleGenerativeAI( model=MODEL_NAME, google_api_key=FAKE_API_KEY, transport="rest" ) + model.async_client_running = None + _ = model.async_client assert model.async_client.transport.kind == "grpc_asyncio" asyncio.run(check_async_client()) @@ -175,6 +249,7 @@ def test_initalization_without_async() -> None: assert chat.async_client is None +@pytest.mark.enable_socket def test_initialization_with_async() -> None: async def initialize_chat_with_async_client() -> ChatGoogleGenerativeAI: model = ChatGoogleGenerativeAI( @@ -1713,6 +1788,7 @@ def test_grounding_metadata_multiple_parts() -> None: assert grounding["grounding_supports"][0]["segment"]["part_index"] == 1 +@pytest.mark.enable_socket @pytest.mark.parametrize( "is_async,mock_target,method_name", [ @@ -1839,6 +1915,7 @@ def mock_stream() -> Iterator[GenerateContentResponse]: assert "timeout" not in call_kwargs +@pytest.mark.enable_socket @pytest.mark.parametrize( "is_async,mock_target,method_name", [