Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/google/adk/tools/function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,22 @@ async def run_async(
if 'tool_context' in valid_params:
args_to_call['tool_context'] = tool_context

# Filter args_to_call to only include valid parameters for the function
args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params}
# Check if function accepts **kwargs
has_kwargs = any(
param.kind == inspect.Parameter.VAR_KEYWORD
for param in signature.parameters.values()
)

if has_kwargs:
# For functions with **kwargs, we pass all arguments. We defensively
# remove the `self` argument, which may be injected by some tool
# frameworks but is not intended for the wrapped function.
args_to_call.pop('self', None)
if 'tool_context' not in valid_params:
args_to_call.pop('tool_context', None)
else:
# For functions without **kwargs, use the original filtering.
args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params}

# Before invoking the function, we check for if the list of args passed in
# has all the mandatory arguments or not.
Expand Down
143 changes: 143 additions & 0 deletions tests/unittests/tools/test_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,57 @@
import pytest


@pytest.fixture
def mock_tool_context() -> ToolContext:
"""Fixture that provides a mock ToolContext for testing."""
mock_invocation_context = MagicMock(spec=InvocationContext)
mock_invocation_context.session = MagicMock(spec=Session)
mock_invocation_context.session.state = MagicMock()
return ToolContext(invocation_context=mock_invocation_context)


def _crewai_style_tool_sync(*args, **kwargs):
"""CrewAI-style tool that accepts any keyword arguments."""
return {
"received_args": args,
"received_kwargs": kwargs,
"search_query": kwargs.get("search_query"),
"other_param": kwargs.get("other_param"),
}


async def _crewai_style_tool_async(*args, **kwargs):
"""Async CrewAI-style tool that accepts any keyword arguments."""
return {
"received_args": args,
"received_kwargs": kwargs,
"search_query": kwargs.get("search_query"),
"other_param": kwargs.get("other_param"),
}


def _func_with_context_and_kwargs_sync(arg1: str, tool_context: ToolContext, **kwargs):
"""Function with explicit tool_context parameter and **kwargs."""
return {
"arg1": arg1,
"tool_context_present": bool(tool_context),
"search_query": kwargs.get("search_query"),
"received_kwargs": kwargs,
}


async def _func_with_context_and_kwargs_async(
arg1: str, tool_context: ToolContext, **kwargs
):
"""Async function with explicit tool_context parameter and **kwargs."""
return {
"arg1": arg1,
"tool_context_present": bool(tool_context),
"search_query": kwargs.get("search_query"),
"received_kwargs": kwargs,
}


def function_for_testing_with_no_args():
"""Function for testing with no args."""
pass
Expand Down Expand Up @@ -394,3 +445,95 @@ def sample_func(arg1: str):
tool_context=tool_context_mock,
)
assert result == {"received_arg": "hello"}


@pytest.mark.asyncio
@pytest.mark.parametrize(
"tool_function, search_query, other_param",
[
(_crewai_style_tool_sync, "test_query", "test_value"),
(_crewai_style_tool_async, "async test query", "async test value"),
],
ids=["sync", "async"],
)
async def test_run_async_with_kwargs_crewai_style(
mock_tool_context, tool_function, search_query, other_param
):
"""Test that run_async works with CrewAI-style functions that use **kwargs."""
tool = FunctionTool(tool_function)

# Test with CrewAI-style parameters that should be passed through
result = await tool.run_async(
args={"search_query": search_query, "other_param": other_param},
tool_context=mock_tool_context,
)

assert result["search_query"] == search_query
assert result["other_param"] == other_param
assert result["received_kwargs"]["search_query"] == search_query
assert result["received_kwargs"]["other_param"] == other_param


@pytest.mark.asyncio
async def test_run_async_with_kwargs_backward_compatibility(mock_tool_context):
"""Test that the **kwargs fix maintains backward compatibility with explicit parameters."""

def explicit_params_func(arg1: str, arg2: int):
"""Function with explicit parameters (no **kwargs)."""
return {"arg1": arg1, "arg2": arg2}

tool = FunctionTool(explicit_params_func)

# Test that unexpected parameters are still filtered out for non-kwargs functions
result = await tool.run_async(
args={
"arg1": "test",
"arg2": 42,
"unexpected_param": "should_be_filtered"
},
tool_context=mock_tool_context,
)

assert result == {"arg1": "test", "arg2": 42}


@pytest.mark.asyncio
@pytest.mark.parametrize(
"tool_function, args",
[
(
_func_with_context_and_kwargs_sync,
{
"arg1": "test_value",
"search_query": "omar elcircevi speaker",
"other_param": "test_value",
},
),
(
_func_with_context_and_kwargs_async,
{
"arg1": "async_test_value",
"search_query": "async test query",
"other_param": "async test value",
},
),
],
ids=["sync", "async"],
)
async def test_run_async_with_kwargs_and_tool_context(
mock_tool_context, tool_function, args
):
"""Test that run_async works with functions that have both tool_context and **kwargs."""
tool = FunctionTool(tool_function)

# Test that both tool_context and **kwargs parameters work together
result = await tool.run_async(
args=args,
tool_context=mock_tool_context,
)

assert result["arg1"] == args["arg1"]
assert result["tool_context_present"] is True
assert result["search_query"] == args["search_query"]
assert result["received_kwargs"]["search_query"] == args["search_query"]
assert result["received_kwargs"]["other_param"] == args["other_param"]