45 changes: 35 additions & 10 deletions libs/genai/langchain_google_genai/chat_models.py
@@ -761,6 +761,8 @@
except (AttributeError, TypeError):
thought_sig = None

has_function_call = hasattr(part, "function_call") and part.function_call

if hasattr(part, "thought") and part.thought:
thinking_message = {
"type": "thinking",
@@ -770,7 +772,7 @@
if thought_sig:
thinking_message["signature"] = thought_sig
content = _append_to_content(content, thinking_message)
elif text is not None and text:
elif text is not None and text.strip() and not has_function_call:
# Check if this text Part has a signature attached
if thought_sig:
# Text with signature needs structured block to preserve signature
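A sketch of the block shapes this branch appends to `content`, reconstructed from the visible lines of this diff (the key holding the reasoning text is folded out of view, so "thinking" below is an assumption):

# Sketch only: block shapes from the branch above; one key name is assumed.
thinking_block = {
    "type": "thinking",
    "thinking": "<model reasoning text>",  # assumed key; hidden by the fold
    "signature": "<thought signature>",    # attached only when thought_sig is set
}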
@@ -896,18 +898,33 @@

# If this function_call Part has a signature, track it separately
if thought_sig:
if _FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY not in additional_kwargs:
additional_kwargs[_FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY] = {}
additional_kwargs[_FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY][
tool_call_id
] = (
_bytes_to_base64(thought_sig)
if isinstance(thought_sig, bytes)
else thought_sig
)
sig_block = {
"type": "function_call_signature",
"signature": thought_sig,
}
function_call_signatures.append(sig_block)

Check failure on line 905 (GitHub Actions / make lint, Ruff F821): langchain_google_genai/chat_models.py:905:17: Undefined name `function_call_signatures`

# Add function call signatures to content only if there's already other content
# This preserves backward compatibility where content is "" for
# function-only responses
if function_call_signatures and content is not None:

Check failure on line 910 (GitHub Actions / make lint, Ruff F821): langchain_google_genai/chat_models.py:910:12: Undefined name `function_call_signatures`
for sig_block in function_call_signatures:

Check failure on line 911 (GitHub Actions / make lint, Ruff F821): langchain_google_genai/chat_models.py:911:30: Undefined name `function_call_signatures`
content = _append_to_content(content, sig_block)

if content is None:
content = ""

if (
hasattr(response_candidate, "logprobs_result")
and response_candidate.logprobs_result
):
# Note: logprobs is flaky, sometimes available, sometimes not
# https://discuss.ai.google.dev/t/logprobs-is-not-enabled-for-gemini-models/107989/15
response_metadata["logprobs"] = MessageToDict(

Check failure on line 923 (GitHub Actions / make lint, Ruff F821): langchain_google_genai/chat_models.py:923:41: Undefined name `MessageToDict`
response_candidate.logprobs_result._pb,
preserving_proto_field_name=True,
)

if isinstance(content, list) and any(
isinstance(item, dict) and "executable_code" in item for item in content
):
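Review note: the four Ruff F821 failures flagged above share one root cause: `function_call_signatures` is read before it is ever assigned in this scope, and `MessageToDict` is referenced without an import. A minimal fix sketch, assuming the list is meant to be a per-call accumulator initialized alongside `content` (a suggestion, not part of this diff):

# Hypothetical fix sketch, not part of this PR's changes.
from google.protobuf.json_format import MessageToDict  # resolves the F821 at 923:41

content = None
function_call_signatures: list[dict] = []  # resolves the F821s at 905/910/911
# The per-part loop above then appends each sig_block to this list, and
# MessageToDict(...) serializes logprobs_result exactly as written.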
@@ -1764,6 +1781,9 @@
stop: list[str] | None = None
"""Stop sequences for the model."""

logprobs: int | None = None
"""The number of logprobs to return."""

streaming: bool | None = None
"""Whether to stream responses from the model."""

@@ -1931,6 +1951,7 @@
"media_resolution": self.media_resolution,
"thinking_budget": self.thinking_budget,
"include_thoughts": self.include_thoughts,
"logprobs": self.logprobs,
}

def invoke(
@@ -2024,6 +2045,10 @@
}.items()
if v is not None
}
logprobs = getattr(self, "logprobs", None)
if logprobs:
gen_config["logprobs"] = logprobs
gen_config["response_logprobs"] = True
if generation_config:
gen_config = {**gen_config, **generation_config}

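For reviewers, a minimal end-to-end sketch of the new `logprobs` parameter as wired above. The model name and prompt are placeholders, and per the comment earlier in this diff, logprobs availability is flaky upstream:

# Usage sketch: the model name and prompt are illustrative assumptions.
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", logprobs=5)
result = llm.invoke("Say hello")

# logprobs=5 also sets response_logprobs=True on the request; when the
# backend returns logprobs_result, it is surfaced in response metadata:
print(result.response_metadata.get("logprobs"))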
83 changes: 80 additions & 3 deletions libs/genai/tests/unit_tests/test_chat_models.py
@@ -144,24 +144,98 @@ def test_initialization_inside_threadpool() -> None:
).result()


def test_client_transport() -> None:
def test_logprobs() -> None:
"""Test that logprobs parameter is set correctly and is in the response."""
llm = ChatGoogleGenerativeAI(
model=MODEL_NAME,
google_api_key=SecretStr("secret-api-key"),
logprobs=10,
)
assert llm.logprobs == 10

# Create proper mock response with logprobs_result
raw_response = {
"candidates": [
{
"content": {"parts": [{"text": "Test response"}]},
"finish_reason": 1,
"safety_ratings": [],
"logprobs_result": {
"top_candidates": [
{
"candidates": [
{"token": "Test", "log_probability": -0.1},
]
}
]
},
}
],
"prompt_feedback": {"block_reason": 0, "safety_ratings": []},
"usage_metadata": {
"prompt_token_count": 5,
"candidates_token_count": 2,
"total_token_count": 7,
},
}
response = GenerateContentResponse(raw_response)

with patch(
"langchain_google_genai.chat_models._chat_with_retry"
) as mock_chat_with_retry:
mock_chat_with_retry.return_value = response
llm = ChatGoogleGenerativeAI(
model=MODEL_NAME,
google_api_key="test-key",
logprobs=1,
)
result = llm.invoke("test")
assert "logprobs" in result.response_metadata
assert result.response_metadata["logprobs"] == {
"top_candidates": [
{
"candidates": [
{"token": "Test", "log_probability": -0.1},
]
}
]
}

mock_chat_with_retry.assert_called_once()
request = mock_chat_with_retry.call_args.kwargs["request"]
assert request.generation_config.logprobs == 1
assert request.generation_config.response_logprobs is True


@pytest.mark.enable_socket
@patch("langchain_google_genai._genai_extension.v1betaGenerativeServiceAsyncClient")
@patch("langchain_google_genai._genai_extension.v1betaGenerativeServiceClient")
def test_client_transport(mock_client: Mock, mock_async_client: Mock) -> None:
"""Test client transport configuration."""
model = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key=FAKE_API_KEY)
mock_client.return_value.transport = Mock()
mock_client.return_value.transport.kind = "grpc"
model = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key="fake-key")
assert model.client.transport.kind == "grpc"

mock_client.return_value.transport.kind = "rest"
model = ChatGoogleGenerativeAI(
model=MODEL_NAME, google_api_key="fake-key", transport="rest"
)
assert model.client.transport.kind == "rest"

async def check_async_client() -> None:
model = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key=FAKE_API_KEY)
mock_async_client.return_value.transport = Mock()
mock_async_client.return_value.transport.kind = "grpc_asyncio"
model = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key="fake-key")
_ = model.async_client
assert model.async_client.transport.kind == "grpc_asyncio"

# Test auto conversion of transport to "grpc_asyncio" from "rest"
model = ChatGoogleGenerativeAI(
model=MODEL_NAME, google_api_key=FAKE_API_KEY, transport="rest"
)
model.async_client_running = None
_ = model.async_client
assert model.async_client.transport.kind == "grpc_asyncio"

asyncio.run(check_async_client())
Expand All @@ -175,6 +249,7 @@ def test_initalization_without_async() -> None:
assert chat.async_client is None


@pytest.mark.enable_socket
def test_initialization_with_async() -> None:
async def initialize_chat_with_async_client() -> ChatGoogleGenerativeAI:
model = ChatGoogleGenerativeAI(
@@ -1713,6 +1788,7 @@ def test_grounding_metadata_multiple_parts() -> None:
assert grounding["grounding_supports"][0]["segment"]["part_index"] == 1


@pytest.mark.enable_socket
@pytest.mark.parametrize(
"is_async,mock_target,method_name",
[
@@ -1839,6 +1915,7 @@ def mock_stream() -> Iterator[GenerateContentResponse]:
assert "timeout" not in call_kwargs


@pytest.mark.enable_socket
@pytest.mark.parametrize(
"is_async,mock_target,method_name",
[