diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index 1ab32d42b..a45316026 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -165,8 +165,22 @@ async def run_async( if 'tool_context' in valid_params: args_to_call['tool_context'] = tool_context - # Filter args_to_call to only include valid parameters for the function - args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params} + # Check if function accepts **kwargs + has_kwargs = any( + param.kind == inspect.Parameter.VAR_KEYWORD + for param in signature.parameters.values() + ) + + if has_kwargs: + # For functions with **kwargs, we pass all arguments. We defensively + # remove the `self` argument, which may be injected by some tool + # frameworks but is not intended for the wrapped function. + args_to_call.pop('self', None) + if 'tool_context' not in valid_params: + args_to_call.pop('tool_context', None) + else: + # For functions without **kwargs, use the original filtering. + args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params} # Before invoking the function, we check for if the list of args passed in # has all the mandatory arguments or not. diff --git a/tests/unittests/tools/test_function_tool.py b/tests/unittests/tools/test_function_tool.py index e7854a2c8..cfdbb2a21 100644 --- a/tests/unittests/tools/test_function_tool.py +++ b/tests/unittests/tools/test_function_tool.py @@ -22,6 +22,57 @@ import pytest +@pytest.fixture +def mock_tool_context() -> ToolContext: + """Fixture that provides a mock ToolContext for testing.""" + mock_invocation_context = MagicMock(spec=InvocationContext) + mock_invocation_context.session = MagicMock(spec=Session) + mock_invocation_context.session.state = MagicMock() + return ToolContext(invocation_context=mock_invocation_context) + + +def _crewai_style_tool_sync(*args, **kwargs): + """CrewAI-style tool that accepts any keyword arguments.""" + return { + "received_args": args, + "received_kwargs": kwargs, + "search_query": kwargs.get("search_query"), + "other_param": kwargs.get("other_param"), + } + + +async def _crewai_style_tool_async(*args, **kwargs): + """Async CrewAI-style tool that accepts any keyword arguments.""" + return { + "received_args": args, + "received_kwargs": kwargs, + "search_query": kwargs.get("search_query"), + "other_param": kwargs.get("other_param"), + } + + +def _func_with_context_and_kwargs_sync(arg1: str, tool_context: ToolContext, **kwargs): + """Function with explicit tool_context parameter and **kwargs.""" + return { + "arg1": arg1, + "tool_context_present": bool(tool_context), + "search_query": kwargs.get("search_query"), + "received_kwargs": kwargs, + } + + +async def _func_with_context_and_kwargs_async( + arg1: str, tool_context: ToolContext, **kwargs +): + """Async function with explicit tool_context parameter and **kwargs.""" + return { + "arg1": arg1, + "tool_context_present": bool(tool_context), + "search_query": kwargs.get("search_query"), + "received_kwargs": kwargs, + } + + def function_for_testing_with_no_args(): """Function for testing with no args.""" pass @@ -394,3 +445,97 @@ def sample_func(arg1: str): tool_context=tool_context_mock, ) assert result == {"received_arg": "hello"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "tool_function, search_query, other_param", + [ + (_crewai_style_tool_sync, "test_query", "test_value"), + (_crewai_style_tool_async, "async test query", "async test value"), + ], + ids=["sync", "async"], +) +async def test_run_async_with_kwargs_crewai_style( + mock_tool_context, tool_function, search_query, other_param +): + """Test that run_async works with CrewAI-style functions that use **kwargs.""" + tool = FunctionTool(tool_function) + + # Test with CrewAI-style parameters that should be passed through + result = await tool.run_async( + args={"search_query": search_query, "other_param": other_param}, + tool_context=mock_tool_context, + ) + + assert result["search_query"] == search_query + assert result["other_param"] == other_param + assert result["received_kwargs"]["search_query"] == search_query + assert result["received_kwargs"]["other_param"] == other_param + + +@pytest.mark.asyncio +async def test_run_async_with_kwargs_backward_compatibility(mock_tool_context): + """Test that the **kwargs fix maintains backward compatibility with explicit parameters.""" + + def explicit_params_func(arg1: str, arg2: int): + """Function with explicit parameters (no **kwargs).""" + return {"arg1": arg1, "arg2": arg2} + + tool = FunctionTool(explicit_params_func) + + # Test that unexpected parameters are still filtered out for non-kwargs functions + result = await tool.run_async( + args={ + "arg1": "test", + "arg2": 42, + "unexpected_param": "should_be_filtered" + }, + tool_context=mock_tool_context, + ) + + assert result == {"arg1": "test", "arg2": 42} + # Explicitly verify that unexpected_param was filtered out and not passed to the function + assert "unexpected_param" not in result + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "tool_function, args", + [ + ( + _func_with_context_and_kwargs_sync, + { + "arg1": "test_value", + "search_query": "omar elcircevi speaker", + "other_param": "test_value", + }, + ), + ( + _func_with_context_and_kwargs_async, + { + "arg1": "async_test_value", + "search_query": "async test query", + "other_param": "async test value", + }, + ), + ], + ids=["sync", "async"], +) +async def test_run_async_with_kwargs_and_tool_context( + mock_tool_context, tool_function, args +): + """Test that run_async works with functions that have both tool_context and **kwargs.""" + tool = FunctionTool(tool_function) + + # Test that both tool_context and **kwargs parameters work together + result = await tool.run_async( + args=args, + tool_context=mock_tool_context, + ) + + assert result["arg1"] == args["arg1"] + assert result["tool_context_present"] is True + assert result["search_query"] == args["search_query"] + assert result["received_kwargs"]["search_query"] == args["search_query"] + assert result["received_kwargs"]["other_param"] == args["other_param"]