diff --git a/contributing/samples/langchain_structured_tool_agent/agent.py b/contributing/samples/langchain_structured_tool_agent/agent.py index 5c4c5b9a2..fc0d51d54 100644 --- a/contributing/samples/langchain_structured_tool_agent/agent.py +++ b/contributing/samples/langchain_structured_tool_agent/agent.py @@ -15,6 +15,8 @@ """ This agent aims to test the Langchain tool with Langchain's StructuredTool """ +from __future__ import annotations + from google.adk.agents.llm_agent import Agent from google.adk.tools.langchain_tool import LangchainTool from langchain.tools import tool @@ -23,11 +25,13 @@ async def add(x, y) -> int: + """Adds two numbers.""" return x + y @tool def minus(x, y) -> int: + """Minus two numbers.""" return x - y diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index ae55dd1e4..ba2614f20 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -109,7 +109,8 @@ async def _convert_tool_union_to_tools( - tool_union: ToolUnion, ctx: ReadonlyContext + tool_union: ToolUnion, + ctx: ReadonlyContext, ) -> list[BaseTool]: if isinstance(tool_union, BaseTool): return [tool_union] diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 90cf0fbcf..16a7322f1 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -73,7 +73,7 @@ async def run_live( invocation_context: InvocationContext, ) -> AsyncGenerator[Event, None]: """Runs the flow using live api.""" - llm_request = LlmRequest() + llm_request = LlmRequest(live_connect_config=types.LiveConnectConfig()) event_id = Event.new_id() # Preprocess before calling the LLM. @@ -373,7 +373,9 @@ async def _run_one_step_async( yield event async def _preprocess_async( - self, invocation_context: InvocationContext, llm_request: LlmRequest + self, + invocation_context: InvocationContext, + llm_request: LlmRequest, ) -> AsyncGenerator[Event, None]: from ...agents.llm_agent import LlmAgent diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py index 549c6d875..e5b4fd627 100644 --- a/src/google/adk/flows/llm_flows/basic.py +++ b/src/google/adk/flows/llm_flows/basic.py @@ -57,30 +57,31 @@ async def run_async( if agent.output_schema and not agent.tools: llm_request.set_output_schema(agent.output_schema) - llm_request.live_connect_config.response_modalities = ( - invocation_context.run_config.response_modalities - ) - llm_request.live_connect_config.speech_config = ( - invocation_context.run_config.speech_config - ) - llm_request.live_connect_config.output_audio_transcription = ( - invocation_context.run_config.output_audio_transcription - ) - llm_request.live_connect_config.input_audio_transcription = ( - invocation_context.run_config.input_audio_transcription - ) - llm_request.live_connect_config.realtime_input_config = ( - invocation_context.run_config.realtime_input_config - ) - llm_request.live_connect_config.enable_affective_dialog = ( - invocation_context.run_config.enable_affective_dialog - ) - llm_request.live_connect_config.proactivity = ( - invocation_context.run_config.proactivity - ) - llm_request.live_connect_config.session_resumption = ( - invocation_context.run_config.session_resumption - ) + if llm_request.live_connect_config: + llm_request.live_connect_config.response_modalities = ( + invocation_context.run_config.response_modalities + ) + llm_request.live_connect_config.speech_config = ( + invocation_context.run_config.speech_config + ) + llm_request.live_connect_config.output_audio_transcription = ( + invocation_context.run_config.output_audio_transcription + ) + llm_request.live_connect_config.input_audio_transcription = ( + invocation_context.run_config.input_audio_transcription + ) + llm_request.live_connect_config.realtime_input_config = ( + invocation_context.run_config.realtime_input_config + ) + llm_request.live_connect_config.enable_affective_dialog = ( + invocation_context.run_config.enable_affective_dialog + ) + llm_request.live_connect_config.proactivity = ( + invocation_context.run_config.proactivity + ) + llm_request.live_connect_config.session_resumption = ( + invocation_context.run_config.session_resumption + ) # TODO: handle tool append here, instead of in BaseTool.process_llm_request. diff --git a/src/google/adk/models/llm_request.py b/src/google/adk/models/llm_request.py index b83fd1d99..86d1d17e1 100644 --- a/src/google/adk/models/llm_request.py +++ b/src/google/adk/models/llm_request.py @@ -14,6 +14,9 @@ from __future__ import annotations +from collections.abc import AsyncGenerator as ABCAsyncGenerator +import inspect +from typing import get_origin from typing import Optional from google.genai import types @@ -22,6 +25,7 @@ from pydantic import Field from ..tools.base_tool import BaseTool +from ..tools.function_tool import FunctionTool def _find_tool_with_function_declarations( @@ -66,13 +70,13 @@ class LlmRequest(BaseModel): config: types.GenerateContentConfig = Field( default_factory=types.GenerateContentConfig ) - live_connect_config: types.LiveConnectConfig = Field( - default_factory=types.LiveConnectConfig - ) """Additional config for the generate content request. tools in generate_content_config should not be set. """ + live_connect_config: Optional[types.LiveConnectConfig] = None + """Live connection config. + """ tools_dict: dict[str, BaseTool] = Field(default_factory=dict, exclude=True) """The tools dictionary.""" @@ -99,7 +103,23 @@ def append_tools(self, tools: list[BaseTool]) -> None: return declarations = [] for tool in tools: - declaration = tool._get_declaration() + if self.live_connect_config is not None: + # ignore response for tools that returns AsyncGenerator that the model + # can't understand yet even though the model can't handle it, streaming + # tools can handle it. + # to check type, use typing.collections.abc.AsyncGenerator and not + # typing.AsyncGenerator + is_async_generator_return = False + if isinstance(tool, FunctionTool): + signature = inspect.signature(tool.func) + is_async_generator_return = ( + get_origin(signature.return_annotation) is ABCAsyncGenerator + ) + declaration = tool._get_declaration( + ignore_return_declaration=is_async_generator_return + ) + else: + declaration = tool._get_declaration() if declaration: declarations.append(declaration) self.tools_dict[tool.name] = tool diff --git a/src/google/adk/tools/_automatic_function_calling_util.py b/src/google/adk/tools/_automatic_function_calling_util.py index 5e32f68e0..1a89442f3 100644 --- a/src/google/adk/tools/_automatic_function_calling_util.py +++ b/src/google/adk/tools/_automatic_function_calling_util.py @@ -195,6 +195,7 @@ def build_function_declaration( func: Union[Callable, BaseModel], ignore_params: Optional[list[str]] = None, variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API, + ignore_return_declaration: bool = False, ) -> types.FunctionDeclaration: signature = inspect.signature(func) should_update_signature = False @@ -232,9 +233,11 @@ def build_function_declaration( new_func.__annotations__ = func.__annotations__ return ( - from_function_with_options(func, variant) + from_function_with_options(func, variant, ignore_return_declaration) if not should_update_signature - else from_function_with_options(new_func, variant) + else from_function_with_options( + new_func, variant, ignore_return_declaration + ) ) @@ -293,6 +296,7 @@ def build_function_declaration_util( def from_function_with_options( func: Callable, variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API, + ignore_return_declaration: bool = False, ) -> 'types.FunctionDeclaration': parameters_properties = {} @@ -324,7 +328,8 @@ def from_function_with_options( declaration.parameters ) ) - if variant == GoogleLLMVariant.GEMINI_API: + + if variant == GoogleLLMVariant.GEMINI_API or ignore_return_declaration: return declaration return_annotation = inspect.signature(func).return_annotation diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index c0d07238d..abe0d13d9 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -15,6 +15,7 @@ from __future__ import annotations from typing import Any +from typing import Optional from typing import TYPE_CHECKING from google.genai import types @@ -61,7 +62,9 @@ def populate_name(cls, data: Any) -> Any: return data @override - def _get_declaration(self) -> types.FunctionDeclaration: + def _get_declaration( + self, ignore_return_declaration: bool = False + ) -> Optional[types.FunctionDeclaration]: from ..agents.llm_agent import LlmAgent from ..utils.variant_utils import GoogleLLMVariant diff --git a/src/google/adk/tools/application_integration_tool/integration_connector_tool.py b/src/google/adk/tools/application_integration_tool/integration_connector_tool.py index 0f1a6895d..7c3e27499 100644 --- a/src/google/adk/tools/application_integration_tool/integration_connector_tool.py +++ b/src/google/adk/tools/application_integration_tool/integration_connector_tool.py @@ -20,7 +20,7 @@ from typing import Optional from typing import Union -from google.genai.types import FunctionDeclaration +from google.genai import types from typing_extensions import override from ...auth.auth_credential import AuthCredential @@ -115,7 +115,9 @@ def __init__( self._auth_credential = auth_credential @override - def _get_declaration(self) -> FunctionDeclaration: + def _get_declaration( + self, ignore_return_declaration: bool = False + ) -> Optional[types.FunctionDeclaration]: """Returns the function declaration in the Gemini Schema format.""" schema_dict = self._rest_api_tool._operation_parser.get_json_schema() for field in self.EXCLUDE_FIELDS: @@ -126,7 +128,7 @@ def _get_declaration(self) -> FunctionDeclaration: schema_dict['required'].remove(field) parameters = _to_gemini_schema(schema_dict) - function_decl = FunctionDeclaration( + function_decl = types.FunctionDeclaration( name=self.name, description=self.description, parameters=parameters ) return function_decl diff --git a/src/google/adk/tools/base_tool.py b/src/google/adk/tools/base_tool.py index 90c575395..eb91adcd6 100644 --- a/src/google/adk/tools/base_tool.py +++ b/src/google/adk/tools/base_tool.py @@ -78,7 +78,9 @@ def __init__( self.is_long_running = is_long_running self.custom_metadata = custom_metadata - def _get_declaration(self) -> Optional[types.FunctionDeclaration]: + def _get_declaration( + self, ignore_return_declaration: bool = False + ) -> Optional[types.FunctionDeclaration]: """Gets the OpenAPI specification of this tool in the form of a FunctionDeclaration. NOTE: diff --git a/src/google/adk/tools/crewai_tool.py b/src/google/adk/tools/crewai_tool.py index 9dfe7cc75..7af00fe74 100644 --- a/src/google/adk/tools/crewai_tool.py +++ b/src/google/adk/tools/crewai_tool.py @@ -14,6 +14,8 @@ from __future__ import annotations +from typing import Optional + from google.genai import types from typing_extensions import override @@ -62,7 +64,9 @@ def __init__(self, tool: CrewaiBaseTool, *, name: str, description: str): self.description = tool.description @override - def _get_declaration(self) -> types.FunctionDeclaration: + def _get_declaration( + self, ignore_return_declaration: bool = False + ) -> Optional[types.FunctionDeclaration]: """Build the function declaration for the tool.""" function_declaration = _automatic_function_calling_util.build_function_declaration_for_params_for_crewai( False, diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index 2687f1200..564cd3100 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -62,7 +62,9 @@ def __init__(self, func: Callable[..., Any]): self._ignore_params = ['tool_context', 'input_stream'] @override - def _get_declaration(self) -> Optional[types.FunctionDeclaration]: + def _get_declaration( + self, ignore_return_declaration: bool = False + ) -> Optional[types.FunctionDeclaration]: function_decl = types.FunctionDeclaration.model_validate( build_function_declaration( func=self.func, @@ -70,6 +72,7 @@ def _get_declaration(self) -> Optional[types.FunctionDeclaration]: # input_stream is for streaming tool ignore_params=self._ignore_params, variant=self._api_variant, + ignore_return_declaration=ignore_return_declaration, ) ) diff --git a/src/google/adk/tools/google_api_tool/google_api_tool.py b/src/google/adk/tools/google_api_tool/google_api_tool.py index d2bac5686..645822da1 100644 --- a/src/google/adk/tools/google_api_tool/google_api_tool.py +++ b/src/google/adk/tools/google_api_tool/google_api_tool.py @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import annotations from typing import Any from typing import Dict from typing import Optional -from google.genai.types import FunctionDeclaration +from google.genai import types from typing_extensions import override from ...auth.auth_credential import AuthCredential @@ -52,7 +51,9 @@ def __init__( self.configure_auth(client_id, client_secret) @override - def _get_declaration(self) -> FunctionDeclaration: + def _get_declaration( + self, ignore_return_declaration: bool = False + ) -> Optional[types.FunctionDeclaration]: return self._rest_api_tool._get_declaration() @override diff --git a/src/google/adk/tools/langchain_tool.py b/src/google/adk/tools/langchain_tool.py index 44f884ff6..06cea4971 100644 --- a/src/google/adk/tools/langchain_tool.py +++ b/src/google/adk/tools/langchain_tool.py @@ -101,7 +101,9 @@ def __init__( # else: keep default from FunctionTool @override - def _get_declaration(self) -> types.FunctionDeclaration: + def _get_declaration( + self, ignore_return_declaration: bool = False + ) -> Optional[types.FunctionDeclaration]: """Build the function declaration for the tool. Returns: diff --git a/src/google/adk/tools/load_artifacts_tool.py b/src/google/adk/tools/load_artifacts_tool.py index db28aefb9..e5d419a09 100644 --- a/src/google/adk/tools/load_artifacts_tool.py +++ b/src/google/adk/tools/load_artifacts_tool.py @@ -16,6 +16,7 @@ import json from typing import Any +from typing import Optional from typing import TYPE_CHECKING from google.genai import types @@ -37,7 +38,10 @@ def __init__(self): description='Loads the artifacts and adds them to the session.', ) - def _get_declaration(self) -> types.FunctionDeclaration | None: + @override + def _get_declaration( + self, ignore_return_declaration: bool = False + ) -> Optional[types.FunctionDeclaration]: return types.FunctionDeclaration( name=self.name, description=self.description, diff --git a/src/google/adk/tools/load_memory_tool.py b/src/google/adk/tools/load_memory_tool.py index 8410e4114..02cb82d41 100644 --- a/src/google/adk/tools/load_memory_tool.py +++ b/src/google/adk/tools/load_memory_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +from typing import Optional from typing import TYPE_CHECKING from google.genai import types @@ -58,7 +59,9 @@ def __init__(self): super().__init__(load_memory) @override - def _get_declaration(self) -> types.FunctionDeclaration | None: + def _get_declaration( + self, ignore_return_declaration: bool = False + ) -> Optional[types.FunctionDeclaration]: return types.FunctionDeclaration( name=self.name, description=self.description, diff --git a/src/google/adk/tools/long_running_tool.py b/src/google/adk/tools/long_running_tool.py index 628d01324..e70884ee5 100644 --- a/src/google/adk/tools/long_running_tool.py +++ b/src/google/adk/tools/long_running_tool.py @@ -45,7 +45,9 @@ def __init__(self, func: Callable): self.is_long_running = True @override - def _get_declaration(self) -> Optional[types.FunctionDeclaration]: + def _get_declaration( + self, ignore_return_declaration: bool = False + ) -> Optional[types.FunctionDeclaration]: declaration = super()._get_declaration() if declaration: instruction = ( diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index af4616cae..42ae66e4d 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -19,6 +19,7 @@ from typing import Optional from fastapi.openapi.models import APIKeyIn +from google.genai import types from google.genai.types import FunctionDeclaration from typing_extensions import override @@ -97,7 +98,9 @@ def __init__( self._mcp_session_manager = mcp_session_manager @override - def _get_declaration(self) -> FunctionDeclaration: + def _get_declaration( + self, ignore_return_declaration: bool = False + ) -> Optional[types.FunctionDeclaration]: """Gets the function declaration for the tool. Returns: diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py index 0df9461b9..2323f3ffa 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py @@ -23,6 +23,7 @@ from typing import Union from fastapi.openapi.models import Operation +from google.genai import types from google.genai.types import FunctionDeclaration import requests from typing_extensions import override @@ -181,7 +182,9 @@ def from_parsed_operation_str( return RestApiTool.from_parsed_operation(operation) @override - def _get_declaration(self) -> FunctionDeclaration: + def _get_declaration( + self, ignore_return_declaration: bool = False + ) -> Optional[types.FunctionDeclaration]: """Returns the function declaration in the Gemini Schema format.""" schema_dict = self._operation_parser.get_json_schema() parameters = _to_gemini_schema(schema_dict) diff --git a/src/google/adk/tools/retrieval/base_retrieval_tool.py b/src/google/adk/tools/retrieval/base_retrieval_tool.py index 64f3ec91d..40c6c991a 100644 --- a/src/google/adk/tools/retrieval/base_retrieval_tool.py +++ b/src/google/adk/tools/retrieval/base_retrieval_tool.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from typing import Optional from google.genai import types from typing_extensions import override @@ -21,7 +24,9 @@ class BaseRetrievalTool(BaseTool): @override - def _get_declaration(self) -> types.FunctionDeclaration: + def _get_declaration( + self, ignore_return_declaration: bool = False + ) -> Optional[types.FunctionDeclaration]: return types.FunctionDeclaration( name=self.name, description=self.description, diff --git a/src/google/adk/tools/set_model_response_tool.py b/src/google/adk/tools/set_model_response_tool.py index 6b27d55c2..da8fd1536 100644 --- a/src/google/adk/tools/set_model_response_tool.py +++ b/src/google/adk/tools/set_model_response_tool.py @@ -81,7 +81,9 @@ def set_model_response() -> str: ) @override - def _get_declaration(self) -> Optional[types.FunctionDeclaration]: + def _get_declaration( + self, ignore_return_declaration: bool = False + ) -> Optional[types.FunctionDeclaration]: """Gets the OpenAPI specification of this tool.""" function_decl = types.FunctionDeclaration.model_validate( build_function_declaration( diff --git a/tests/unittests/models/test_llm_request.py b/tests/unittests/models/test_llm_request.py index 789422968..8fcd877ad 100644 --- a/tests/unittests/models/test_llm_request.py +++ b/tests/unittests/models/test_llm_request.py @@ -26,6 +26,7 @@ from google.adk.tools.tool_context import ToolContext from google.genai import types import pytest +from typing_extensions import override def dummy_tool(query: str) -> str: @@ -178,7 +179,10 @@ class _MockTool(BaseTool): def __init__(self, name: str): super().__init__(name=name, description=f'Mock tool {name}') - def _get_declaration(self) -> Optional[types.FunctionDeclaration]: + @override + def _get_declaration( + self, ignore_return_declaration: bool = False + ) -> Optional[types.FunctionDeclaration]: return types.FunctionDeclaration( name=self.name, description=self.description, diff --git a/tests/unittests/tools/test_base_tool.py b/tests/unittests/tools/test_base_tool.py index da1dda64d..4264b1d04 100644 --- a/tests/unittests/tools/test_base_tool.py +++ b/tests/unittests/tools/test_base_tool.py @@ -22,6 +22,7 @@ from google.adk.tools.tool_context import ToolContext from google.genai import types import pytest +from typing_extensions import override class _TestingTool(BaseTool): @@ -33,7 +34,10 @@ def __init__( super().__init__(name='test_tool', description='test_description') self.declaration = declaration - def _get_declaration(self) -> Optional[types.FunctionDeclaration]: + @override + def _get_declaration( + self, ignore_return_declaration: bool = False + ) -> Optional[types.FunctionDeclaration]: return self.declaration