Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/google/adk/tools/function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,21 @@ async def run_async(
if 'tool_context' in valid_params:
args_to_call['tool_context'] = tool_context

# Filter args_to_call to only include valid parameters for the function
args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params}
# Check if function accepts **kwargs
has_kwargs = any(
param.kind == inspect.Parameter.VAR_KEYWORD
for param in signature.parameters.values()
)

if has_kwargs:
# For functions with **kwargs, pass all arguments except 'self' and 'tool_context'
args_to_call = {
k: v for k, v in args_to_call.items()
if k not in ('self', 'tool_context')
}
else:
# For functions without **kwargs, use the original filtering
args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params}

# Before invoking the function, we check for if the list of args passed in
# has all the mandatory arguments or not.
Expand Down
95 changes: 95 additions & 0 deletions tests/unittests/tools/test_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,98 @@ def sample_func(arg1: str):
tool_context=tool_context_mock,
)
assert result == {"received_arg": "hello"}


@pytest.mark.asyncio
async def test_run_async_with_kwargs_crewai_style():
"""Test that run_async works with CrewAI-style functions that use **kwargs."""

def crewai_style_tool(*args, **kwargs):
"""CrewAI-style tool that accepts any keyword arguments."""
return {
"received_args": args,
"received_kwargs": kwargs,
"search_query": kwargs.get("search_query"),
"other_param": kwargs.get("other_param")
}

tool = FunctionTool(crewai_style_tool)
mock_invocation_context = MagicMock(spec=InvocationContext)
mock_invocation_context.session = MagicMock(spec=Session)
mock_invocation_context.session.state = MagicMock()
tool_context_mock = ToolContext(invocation_context=mock_invocation_context)

# Test with CrewAI-style parameters that should be passed through
result = await tool.run_async(
args={
"search_query": "test_query",
"other_param": "test_value"
},
tool_context=tool_context_mock,
)

assert result["search_query"] == "test_query"
assert result["other_param"] == "test_value"
assert result["received_kwargs"]["search_query"] == "test_query"
assert result["received_kwargs"]["other_param"] == "test_value"


@pytest.mark.asyncio
async def test_run_async_with_kwargs_crewai_style_async():
"""Test that run_async works with async CrewAI-style functions that use **kwargs."""

async def async_crewai_style_tool(*args, **kwargs):
"""Async CrewAI-style tool that accepts any keyword arguments."""
return {
"received_args": args,
"received_kwargs": kwargs,
"search_query": kwargs.get("search_query"),
"other_param": kwargs.get("other_param")
}

tool = FunctionTool(async_crewai_style_tool)
mock_invocation_context = MagicMock(spec=InvocationContext)
mock_invocation_context.session = MagicMock(spec=Session)
mock_invocation_context.session.state = MagicMock()
tool_context_mock = ToolContext(invocation_context=mock_invocation_context)

# Test with CrewAI-style parameters that should be passed through
result = await tool.run_async(
args={
"search_query": "async test query",
"other_param": "async test value"
},
tool_context=tool_context_mock,
)

assert result["search_query"] == "async test query"
assert result["other_param"] == "async test value"
assert result["received_kwargs"]["search_query"] == "async test query"
assert result["received_kwargs"]["other_param"] == "async test value"


@pytest.mark.asyncio
async def test_run_async_with_kwargs_backward_compatibility():
"""Test that the **kwargs fix maintains backward compatibility with explicit parameters."""

def explicit_params_func(arg1: str, arg2: int):
"""Function with explicit parameters (no **kwargs)."""
return {"arg1": arg1, "arg2": arg2}

tool = FunctionTool(explicit_params_func)
mock_invocation_context = MagicMock(spec=InvocationContext)
mock_invocation_context.session = MagicMock(spec=Session)
mock_invocation_context.session.state = MagicMock()
tool_context_mock = ToolContext(invocation_context=mock_invocation_context)

# Test that unexpected parameters are still filtered out for non-kwargs functions
result = await tool.run_async(
args={
"arg1": "test",
"arg2": 42,
"unexpected_param": "should_be_filtered"
},
tool_context=tool_context_mock,
)

assert result == {"arg1": "test", "arg2": 42}