Skip to content

fix: Ignore AsyncGenerator return types in function declarations #2514

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions contributing/samples/langchain_structured_tool_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"""
This agent aims to test the Langchain tool with Langchain's StructuredTool
"""
from __future__ import annotations

from google.adk.agents.llm_agent import Agent
from google.adk.tools.langchain_tool import LangchainTool
from langchain.tools import tool
Expand All @@ -23,11 +25,13 @@


async def add(x, y) -> int:
"""Adds two numbers."""
return x + y


@tool
def minus(x, y) -> int:
"""Minus two numbers."""
return x - y


Expand Down
3 changes: 2 additions & 1 deletion src/google/adk/agents/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@


async def _convert_tool_union_to_tools(
tool_union: ToolUnion, ctx: ReadonlyContext
tool_union: ToolUnion,
ctx: ReadonlyContext,
) -> list[BaseTool]:
if isinstance(tool_union, BaseTool):
return [tool_union]
Expand Down
6 changes: 4 additions & 2 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def run_live(
invocation_context: InvocationContext,
) -> AsyncGenerator[Event, None]:
"""Runs the flow using live api."""
llm_request = LlmRequest()
llm_request = LlmRequest(live_connect_config=types.LiveConnectConfig())
event_id = Event.new_id()

# Preprocess before calling the LLM.
Expand Down Expand Up @@ -373,7 +373,9 @@ async def _run_one_step_async(
yield event

async def _preprocess_async(
self, invocation_context: InvocationContext, llm_request: LlmRequest
self,
invocation_context: InvocationContext,
llm_request: LlmRequest,
) -> AsyncGenerator[Event, None]:
from ...agents.llm_agent import LlmAgent

Expand Down
49 changes: 25 additions & 24 deletions src/google/adk/flows/llm_flows/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,30 +57,31 @@ async def run_async(
if agent.output_schema and not agent.tools:
llm_request.set_output_schema(agent.output_schema)

llm_request.live_connect_config.response_modalities = (
invocation_context.run_config.response_modalities
)
llm_request.live_connect_config.speech_config = (
invocation_context.run_config.speech_config
)
llm_request.live_connect_config.output_audio_transcription = (
invocation_context.run_config.output_audio_transcription
)
llm_request.live_connect_config.input_audio_transcription = (
invocation_context.run_config.input_audio_transcription
)
llm_request.live_connect_config.realtime_input_config = (
invocation_context.run_config.realtime_input_config
)
llm_request.live_connect_config.enable_affective_dialog = (
invocation_context.run_config.enable_affective_dialog
)
llm_request.live_connect_config.proactivity = (
invocation_context.run_config.proactivity
)
llm_request.live_connect_config.session_resumption = (
invocation_context.run_config.session_resumption
)
if llm_request.live_connect_config:
llm_request.live_connect_config.response_modalities = (
invocation_context.run_config.response_modalities
)
llm_request.live_connect_config.speech_config = (
invocation_context.run_config.speech_config
)
llm_request.live_connect_config.output_audio_transcription = (
invocation_context.run_config.output_audio_transcription
)
llm_request.live_connect_config.input_audio_transcription = (
invocation_context.run_config.input_audio_transcription
)
llm_request.live_connect_config.realtime_input_config = (
invocation_context.run_config.realtime_input_config
)
llm_request.live_connect_config.enable_affective_dialog = (
invocation_context.run_config.enable_affective_dialog
)
llm_request.live_connect_config.proactivity = (
invocation_context.run_config.proactivity
)
llm_request.live_connect_config.session_resumption = (
invocation_context.run_config.session_resumption
)

# TODO: handle tool append here, instead of in BaseTool.process_llm_request.

Expand Down
28 changes: 24 additions & 4 deletions src/google/adk/models/llm_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

from __future__ import annotations

from collections.abc import AsyncGenerator as ABCAsyncGenerator
import inspect
from typing import get_origin
from typing import Optional

from google.genai import types
Expand All @@ -22,6 +25,7 @@
from pydantic import Field

from ..tools.base_tool import BaseTool
from ..tools.function_tool import FunctionTool


def _find_tool_with_function_declarations(
Expand Down Expand Up @@ -66,13 +70,13 @@ class LlmRequest(BaseModel):
config: types.GenerateContentConfig = Field(
default_factory=types.GenerateContentConfig
)
live_connect_config: types.LiveConnectConfig = Field(
default_factory=types.LiveConnectConfig
)
"""Additional config for the generate content request.

tools in generate_content_config should not be set.
"""
live_connect_config: Optional[types.LiveConnectConfig] = None
"""Live connection config.
"""
tools_dict: dict[str, BaseTool] = Field(default_factory=dict, exclude=True)
"""The tools dictionary."""

Expand All @@ -99,7 +103,23 @@ def append_tools(self, tools: list[BaseTool]) -> None:
return
declarations = []
for tool in tools:
declaration = tool._get_declaration()
if self.live_connect_config is not None:
# ignore response for tools that returns AsyncGenerator that the model
# can't understand yet even though the model can't handle it, streaming
# tools can handle it.
# to check type, use typing.collections.abc.AsyncGenerator and not
# typing.AsyncGenerator
is_async_generator_return = False
if isinstance(tool, FunctionTool):
signature = inspect.signature(tool.func)
is_async_generator_return = (
get_origin(signature.return_annotation) is ABCAsyncGenerator
)
declaration = tool._get_declaration(
ignore_return_declaration=is_async_generator_return
)
else:
declaration = tool._get_declaration()
if declaration:
declarations.append(declaration)
self.tools_dict[tool.name] = tool
Expand Down
11 changes: 8 additions & 3 deletions src/google/adk/tools/_automatic_function_calling_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def build_function_declaration(
func: Union[Callable, BaseModel],
ignore_params: Optional[list[str]] = None,
variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API,
ignore_return_declaration: bool = False,
) -> types.FunctionDeclaration:
signature = inspect.signature(func)
should_update_signature = False
Expand Down Expand Up @@ -232,9 +233,11 @@ def build_function_declaration(
new_func.__annotations__ = func.__annotations__

return (
from_function_with_options(func, variant)
from_function_with_options(func, variant, ignore_return_declaration)
if not should_update_signature
else from_function_with_options(new_func, variant)
else from_function_with_options(
new_func, variant, ignore_return_declaration
)
)


Expand Down Expand Up @@ -293,6 +296,7 @@ def build_function_declaration_util(
def from_function_with_options(
func: Callable,
variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API,
ignore_return_declaration: bool = False,
) -> 'types.FunctionDeclaration':

parameters_properties = {}
Expand Down Expand Up @@ -324,7 +328,8 @@ def from_function_with_options(
declaration.parameters
)
)
if variant == GoogleLLMVariant.GEMINI_API:

if variant == GoogleLLMVariant.GEMINI_API or ignore_return_declaration:
return declaration

return_annotation = inspect.signature(func).return_annotation
Expand Down
5 changes: 4 additions & 1 deletion src/google/adk/tools/agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

from typing import Any
from typing import Optional
from typing import TYPE_CHECKING

from google.genai import types
Expand Down Expand Up @@ -61,7 +62,9 @@ def populate_name(cls, data: Any) -> Any:
return data

@override
def _get_declaration(self) -> types.FunctionDeclaration:
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
from ..agents.llm_agent import LlmAgent
from ..utils.variant_utils import GoogleLLMVariant

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Optional
from typing import Union

from google.genai.types import FunctionDeclaration
from google.genai import types
from typing_extensions import override

from ...auth.auth_credential import AuthCredential
Expand Down Expand Up @@ -115,7 +115,9 @@ def __init__(
self._auth_credential = auth_credential

@override
def _get_declaration(self) -> FunctionDeclaration:
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
"""Returns the function declaration in the Gemini Schema format."""
schema_dict = self._rest_api_tool._operation_parser.get_json_schema()
for field in self.EXCLUDE_FIELDS:
Expand All @@ -126,7 +128,7 @@ def _get_declaration(self) -> FunctionDeclaration:
schema_dict['required'].remove(field)

parameters = _to_gemini_schema(schema_dict)
function_decl = FunctionDeclaration(
function_decl = types.FunctionDeclaration(
name=self.name, description=self.description, parameters=parameters
)
return function_decl
Expand Down
4 changes: 3 additions & 1 deletion src/google/adk/tools/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def __init__(
self.is_long_running = is_long_running
self.custom_metadata = custom_metadata

def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
"""Gets the OpenAPI specification of this tool in the form of a FunctionDeclaration.

NOTE:
Expand Down
6 changes: 5 additions & 1 deletion src/google/adk/tools/crewai_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from __future__ import annotations

from typing import Optional

from google.genai import types
from typing_extensions import override

Expand Down Expand Up @@ -62,7 +64,9 @@ def __init__(self, tool: CrewaiBaseTool, *, name: str, description: str):
self.description = tool.description

@override
def _get_declaration(self) -> types.FunctionDeclaration:
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
"""Build the function declaration for the tool."""
function_declaration = _automatic_function_calling_util.build_function_declaration_for_params_for_crewai(
False,
Expand Down
5 changes: 4 additions & 1 deletion src/google/adk/tools/function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,17 @@ def __init__(self, func: Callable[..., Any]):
self._ignore_params = ['tool_context', 'input_stream']

@override
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
function_decl = types.FunctionDeclaration.model_validate(
build_function_declaration(
func=self.func,
# The model doesn't understand the function context.
# input_stream is for streaming tool
ignore_params=self._ignore_params,
variant=self._api_variant,
ignore_return_declaration=ignore_return_declaration,
)
)

Expand Down
7 changes: 4 additions & 3 deletions src/google/adk/tools/google_api_tool/google_api_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any
from typing import Dict
from typing import Optional

from google.genai.types import FunctionDeclaration
from google.genai import types
from typing_extensions import override

from ...auth.auth_credential import AuthCredential
Expand Down Expand Up @@ -52,7 +51,9 @@ def __init__(
self.configure_auth(client_id, client_secret)

@override
def _get_declaration(self) -> FunctionDeclaration:
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
return self._rest_api_tool._get_declaration()

@override
Expand Down
4 changes: 3 additions & 1 deletion src/google/adk/tools/langchain_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def __init__(
# else: keep default from FunctionTool

@override
def _get_declaration(self) -> types.FunctionDeclaration:
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
"""Build the function declaration for the tool.

Returns:
Expand Down
6 changes: 5 additions & 1 deletion src/google/adk/tools/load_artifacts_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import json
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING

from google.genai import types
Expand All @@ -37,7 +38,10 @@ def __init__(self):
description='Loads the artifacts and adds them to the session.',
)

def _get_declaration(self) -> types.FunctionDeclaration | None:
@override
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
return types.FunctionDeclaration(
name=self.name,
description=self.description,
Expand Down
5 changes: 4 additions & 1 deletion src/google/adk/tools/load_memory_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

from typing import Optional
from typing import TYPE_CHECKING

from google.genai import types
Expand Down Expand Up @@ -58,7 +59,9 @@ def __init__(self):
super().__init__(load_memory)

@override
def _get_declaration(self) -> types.FunctionDeclaration | None:
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
return types.FunctionDeclaration(
name=self.name,
description=self.description,
Expand Down
4 changes: 3 additions & 1 deletion src/google/adk/tools/long_running_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def __init__(self, func: Callable):
self.is_long_running = True

@override
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
declaration = super()._get_declaration()
if declaration:
instruction = (
Expand Down
Loading