Skip to content

Commit 571c802

Browse files
seanzhougooglecopybara-github
authored andcommitted
fix: Convert argument to pydantic model when tool declare to accept pydantic model as argument
PiperOrigin-RevId: 814273005
1 parent c46308b commit 571c802

File tree

2 files changed

+346
-2
lines changed

2 files changed

+346
-2
lines changed

src/google/adk/tools/function_tool.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,18 @@
1818
import logging
1919
from typing import Any
2020
from typing import Callable
21+
from typing import get_args
22+
from typing import get_origin
2123
from typing import Optional
2224
from typing import Union
2325

2426
from google.genai import types
27+
import pydantic
2528
from typing_extensions import override
2629

2730
from ..utils.context_utils import Aclosing
2831
from ._automatic_function_calling_util import build_function_declaration
2932
from .base_tool import BaseTool
30-
from .tool_confirmation import ToolConfirmation
3133
from .tool_context import ToolContext
3234

3335
logger = logging.getLogger('google_adk.' + __name__)
@@ -95,11 +97,69 @@ def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
9597

9698
return function_decl
9799

100+
def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]:
101+
"""Preprocess and convert function arguments before invocation.
102+
103+
Currently handles:
104+
- Converting JSON dictionaries to Pydantic model instances where expected
105+
106+
Future extensions could include:
107+
- Type coercion for other complex types
108+
- Validation and sanitization
109+
- Custom conversion logic
110+
111+
Args:
112+
args: Raw arguments from the LLM tool call
113+
114+
Returns:
115+
Processed arguments ready for function invocation
116+
"""
117+
signature = inspect.signature(self.func)
118+
converted_args = args.copy()
119+
120+
for param_name, param in signature.parameters.items():
121+
if param_name in args and param.annotation != inspect.Parameter.empty:
122+
target_type = param.annotation
123+
124+
# Handle Optional[PydanticModel] types
125+
if get_origin(param.annotation) is Union:
126+
union_args = get_args(param.annotation)
127+
# Find the non-None type in Optional[T] (which is Union[T, None])
128+
non_none_types = [arg for arg in union_args if arg is not type(None)]
129+
if len(non_none_types) == 1:
130+
target_type = non_none_types[0]
131+
132+
# Check if the target type is a Pydantic model
133+
if inspect.isclass(target_type) and issubclass(
134+
target_type, pydantic.BaseModel
135+
):
136+
# Skip conversion if the value is None and the parameter is Optional
137+
if args[param_name] is None:
138+
continue
139+
140+
# Convert to Pydantic model if it's not already the correct type
141+
if not isinstance(args[param_name], target_type):
142+
try:
143+
converted_args[param_name] = target_type.model_validate(
144+
args[param_name]
145+
)
146+
except Exception as e:
147+
logger.warning(
148+
f"Failed to convert argument '{param_name}' to Pydantic model"
149+
f' {target_type.__name__}: {e}'
150+
)
151+
# Keep the original value if conversion fails
152+
pass
153+
154+
return converted_args
155+
98156
@override
99157
async def run_async(
100158
self, *, args: dict[str, Any], tool_context: ToolContext
101159
) -> Any:
102-
args_to_call = args.copy()
160+
# Preprocess arguments (includes Pydantic model conversion)
161+
args_to_call = self._preprocess_args(args)
162+
103163
signature = inspect.signature(self.func)
104164
valid_params = {param for param in signature.parameters}
105165
if 'tool_context' in valid_params:
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Pydantic model conversion tests
16+
17+
from typing import Optional
18+
from unittest.mock import MagicMock
19+
20+
from google.adk.agents.invocation_context import InvocationContext
21+
from google.adk.sessions.session import Session
22+
from google.adk.tools.function_tool import FunctionTool
23+
from google.adk.tools.tool_context import ToolContext
24+
import pydantic
25+
import pytest
26+
27+
28+
class UserModel(pydantic.BaseModel):
29+
"""Test Pydantic model for user data."""
30+
31+
name: str
32+
age: int
33+
email: Optional[str] = None
34+
35+
36+
class PreferencesModel(pydantic.BaseModel):
37+
"""Test Pydantic model for preferences."""
38+
39+
theme: str = "light"
40+
notifications: bool = True
41+
42+
43+
def sync_function_with_pydantic_model(user: UserModel) -> dict:
44+
"""Sync function that takes a Pydantic model."""
45+
return {
46+
"name": user.name,
47+
"age": user.age,
48+
"email": user.email,
49+
"type": str(type(user).__name__),
50+
}
51+
52+
53+
async def async_function_with_pydantic_model(user: UserModel) -> dict:
54+
"""Async function that takes a Pydantic model."""
55+
return {
56+
"name": user.name,
57+
"age": user.age,
58+
"email": user.email,
59+
"type": str(type(user).__name__),
60+
}
61+
62+
63+
def function_with_optional_pydantic_model(
64+
user: UserModel, preferences: Optional[PreferencesModel] = None
65+
) -> dict:
66+
"""Function with required and optional Pydantic models."""
67+
result = {
68+
"user_name": user.name,
69+
"user_type": str(type(user).__name__),
70+
}
71+
if preferences:
72+
result.update({
73+
"theme": preferences.theme,
74+
"notifications": preferences.notifications,
75+
"preferences_type": str(type(preferences).__name__),
76+
})
77+
return result
78+
79+
80+
def function_with_mixed_args(
81+
name: str, user: UserModel, count: int = 5
82+
) -> dict:
83+
"""Function with mixed argument types including Pydantic model."""
84+
return {
85+
"name": name,
86+
"user_name": user.name,
87+
"user_type": str(type(user).__name__),
88+
"count": count,
89+
}
90+
91+
92+
def test_preprocess_args_with_dict_to_pydantic_conversion():
93+
"""Test _preprocess_args converts dict to Pydantic model."""
94+
tool = FunctionTool(sync_function_with_pydantic_model)
95+
96+
input_args = {
97+
"user": {"name": "Alice", "age": 30, "email": "[email protected]"}
98+
}
99+
100+
processed_args = tool._preprocess_args(input_args)
101+
102+
# Check that the dict was converted to a Pydantic model
103+
assert "user" in processed_args
104+
user = processed_args["user"]
105+
assert isinstance(user, UserModel)
106+
assert user.name == "Alice"
107+
assert user.age == 30
108+
assert user.email == "[email protected]"
109+
110+
111+
def test_preprocess_args_with_existing_pydantic_model():
112+
"""Test _preprocess_args leaves existing Pydantic model unchanged."""
113+
tool = FunctionTool(sync_function_with_pydantic_model)
114+
115+
# Create an existing Pydantic model
116+
existing_user = UserModel(name="Bob", age=25)
117+
input_args = {"user": existing_user}
118+
119+
processed_args = tool._preprocess_args(input_args)
120+
121+
# Check that the existing model was not changed (same object)
122+
assert "user" in processed_args
123+
user = processed_args["user"]
124+
assert user is existing_user
125+
assert isinstance(user, UserModel)
126+
assert user.name == "Bob"
127+
128+
129+
def test_preprocess_args_with_optional_pydantic_model_none():
130+
"""Test _preprocess_args handles None for optional Pydantic models."""
131+
tool = FunctionTool(function_with_optional_pydantic_model)
132+
133+
input_args = {"user": {"name": "Charlie", "age": 35}, "preferences": None}
134+
135+
processed_args = tool._preprocess_args(input_args)
136+
137+
# Check user conversion
138+
assert isinstance(processed_args["user"], UserModel)
139+
assert processed_args["user"].name == "Charlie"
140+
141+
# Check preferences remains None
142+
assert processed_args["preferences"] is None
143+
144+
145+
def test_preprocess_args_with_optional_pydantic_model_dict():
146+
"""Test _preprocess_args converts dict for optional Pydantic models."""
147+
tool = FunctionTool(function_with_optional_pydantic_model)
148+
149+
input_args = {
150+
"user": {"name": "Diana", "age": 28},
151+
"preferences": {"theme": "dark", "notifications": False},
152+
}
153+
154+
processed_args = tool._preprocess_args(input_args)
155+
156+
# Check both conversions
157+
assert isinstance(processed_args["user"], UserModel)
158+
assert processed_args["user"].name == "Diana"
159+
160+
assert isinstance(processed_args["preferences"], PreferencesModel)
161+
assert processed_args["preferences"].theme == "dark"
162+
assert processed_args["preferences"].notifications is False
163+
164+
165+
def test_preprocess_args_with_mixed_types():
166+
"""Test _preprocess_args handles mixed argument types correctly."""
167+
tool = FunctionTool(function_with_mixed_args)
168+
169+
input_args = {
170+
"name": "test_name",
171+
"user": {"name": "Eve", "age": 40},
172+
"count": 10,
173+
}
174+
175+
processed_args = tool._preprocess_args(input_args)
176+
177+
# Check that only Pydantic model was converted
178+
assert processed_args["name"] == "test_name" # string unchanged
179+
assert processed_args["count"] == 10 # int unchanged
180+
181+
# Check Pydantic model conversion
182+
assert isinstance(processed_args["user"], UserModel)
183+
assert processed_args["user"].name == "Eve"
184+
assert processed_args["user"].age == 40
185+
186+
187+
def test_preprocess_args_with_invalid_data_graceful_failure():
188+
"""Test _preprocess_args handles invalid data gracefully."""
189+
tool = FunctionTool(sync_function_with_pydantic_model)
190+
191+
# Invalid data that can't be converted to UserModel
192+
input_args = {"user": "invalid_string"} # string instead of dict/model
193+
194+
processed_args = tool._preprocess_args(input_args)
195+
196+
# Should keep original value when conversion fails
197+
assert processed_args["user"] == "invalid_string"
198+
199+
200+
def test_preprocess_args_with_non_pydantic_parameters():
201+
"""Test _preprocess_args ignores non-Pydantic parameters."""
202+
203+
def simple_function(name: str, age: int) -> dict:
204+
return {"name": name, "age": age}
205+
206+
tool = FunctionTool(simple_function)
207+
208+
input_args = {"name": "test", "age": 25}
209+
processed_args = tool._preprocess_args(input_args)
210+
211+
# Should remain unchanged (no Pydantic models to convert)
212+
assert processed_args == input_args
213+
214+
215+
@pytest.mark.asyncio
216+
async def test_run_async_with_pydantic_model_conversion_sync_function():
217+
"""Test run_async with Pydantic model conversion for sync function."""
218+
tool = FunctionTool(sync_function_with_pydantic_model)
219+
220+
tool_context_mock = MagicMock(spec=ToolContext)
221+
invocation_context_mock = MagicMock(spec=InvocationContext)
222+
session_mock = MagicMock(spec=Session)
223+
invocation_context_mock.session = session_mock
224+
tool_context_mock.invocation_context = invocation_context_mock
225+
226+
args = {"user": {"name": "Frank", "age": 45, "email": "[email protected]"}}
227+
228+
result = await tool.run_async(args=args, tool_context=tool_context_mock)
229+
230+
# Verify the function received a proper Pydantic model
231+
assert result["name"] == "Frank"
232+
assert result["age"] == 45
233+
assert result["email"] == "[email protected]"
234+
assert result["type"] == "UserModel"
235+
236+
237+
@pytest.mark.asyncio
238+
async def test_run_async_with_pydantic_model_conversion_async_function():
239+
"""Test run_async with Pydantic model conversion for async function."""
240+
tool = FunctionTool(async_function_with_pydantic_model)
241+
242+
tool_context_mock = MagicMock(spec=ToolContext)
243+
invocation_context_mock = MagicMock(spec=InvocationContext)
244+
session_mock = MagicMock(spec=Session)
245+
invocation_context_mock.session = session_mock
246+
tool_context_mock.invocation_context = invocation_context_mock
247+
248+
args = {"user": {"name": "Grace", "age": 32}}
249+
250+
result = await tool.run_async(args=args, tool_context=tool_context_mock)
251+
252+
# Verify the function received a proper Pydantic model
253+
assert result["name"] == "Grace"
254+
assert result["age"] == 32
255+
assert result["email"] is None # default value
256+
assert result["type"] == "UserModel"
257+
258+
259+
@pytest.mark.asyncio
260+
async def test_run_async_with_optional_pydantic_models():
261+
"""Test run_async with optional Pydantic models."""
262+
tool = FunctionTool(function_with_optional_pydantic_model)
263+
264+
tool_context_mock = MagicMock(spec=ToolContext)
265+
invocation_context_mock = MagicMock(spec=InvocationContext)
266+
session_mock = MagicMock(spec=Session)
267+
invocation_context_mock.session = session_mock
268+
tool_context_mock.invocation_context = invocation_context_mock
269+
270+
# Test with both required and optional models
271+
args = {
272+
"user": {"name": "Henry", "age": 50},
273+
"preferences": {"theme": "dark", "notifications": True},
274+
}
275+
276+
result = await tool.run_async(args=args, tool_context=tool_context_mock)
277+
278+
assert result["user_name"] == "Henry"
279+
assert result["user_type"] == "UserModel"
280+
assert result["theme"] == "dark"
281+
assert result["notifications"] is True
282+
assert result["preferences_type"] == "PreferencesModel"
283+
assert result["preferences_type"] == "PreferencesModel"
284+
assert result["preferences_type"] == "PreferencesModel"

0 commit comments

Comments
 (0)