diff --git a/packages/toolbox-llamaindex/src/toolbox_llamaindex/tools.py b/packages/toolbox-llamaindex/src/toolbox_llamaindex/tools.py index 841e9427..9595366d 100644 --- a/packages/toolbox-llamaindex/src/toolbox_llamaindex/tools.py +++ b/packages/toolbox-llamaindex/src/toolbox_llamaindex/tools.py @@ -13,7 +13,7 @@ # limitations under the License. from asyncio import to_thread -from typing import Any, Callable, Union +from typing import Any, Awaitable, Callable, Mapping, Sequence, Union from deprecated import deprecated from llama_index.core.tools import ToolMetadata @@ -57,6 +57,32 @@ def metadata(self) -> ToolMetadata: ), ) + @property + def _bound_params( + self, + ) -> Mapping[str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]]: + return self.__core_tool._bound_params + + @property + def _required_authn_params(self) -> Mapping[str, list[str]]: + return self.__core_tool._required_authn_params + + @property + def _required_authz_tokens(self) -> Sequence[str]: + return self.__core_tool._required_authz_tokens + + @property + def _auth_service_token_getters( + self, + ) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]]: + return self.__core_tool._auth_service_token_getters + + @property + def _client_headers( + self, + ) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]: + return self.__core_tool._client_headers + def call(self, **kwargs: Any) -> ToolOutput: # type: ignore output_content = self.__core_tool(**kwargs) return ToolOutput( diff --git a/packages/toolbox-llamaindex/tests/test_tools.py b/packages/toolbox-llamaindex/tests/test_tools.py index 3f8cbabe..4ee3d84d 100644 --- a/packages/toolbox-llamaindex/tests/test_tools.py +++ b/packages/toolbox-llamaindex/tests/test_tools.py @@ -133,6 +133,12 @@ def mock_core_tool(self, tool_schema_dict): ) sync_mock.bind_params = Mock(return_value=new_mock_instance_for_methods) + sync_mock._bound_params = {"mock_bound_param": "mock_bound_value"} + sync_mock._required_authn_params = {"mock_auth_source": ["mock_param"]} + sync_mock._required_authz_tokens = ["mock_authz_token"] + sync_mock._auth_service_token_getters = {"mock_service": lambda: "mock_token"} + sync_mock._client_headers = {"mock_header": "mock_header_value"} + return sync_mock @pytest.fixture @@ -168,6 +174,12 @@ def mock_core_sync_auth_tool(self, auth_tool_schema_dict): return_value=new_mock_instance_for_methods ) sync_mock.bind_params = Mock(return_value=new_mock_instance_for_methods) + sync_mock._bound_params = {"mock_bound_param": "mock_bound_value"} + sync_mock._required_authn_params = {"mock_auth_source": ["mock_param"]} + sync_mock._required_authz_tokens = ["mock_authz_token"] + sync_mock._auth_service_token_getters = {"mock_service": lambda: "mock_token"} + sync_mock._client_headers = {"mock_header": "mock_header_value"} + return sync_mock @pytest.fixture @@ -317,3 +329,48 @@ async def to_thread_side_effect(func, *args, **kwargs_for_func): assert mock_core_tool.call_count == 1 assert mock_core_tool.call_args == call(**kwargs_to_run) + + def test_toolbox_tool_properties(self, toolbox_tool, mock_core_tool): + """Tests that the properties correctly proxy to the core tool.""" + assert toolbox_tool._bound_params == mock_core_tool._bound_params + assert ( + toolbox_tool._required_authn_params == mock_core_tool._required_authn_params + ) + assert ( + toolbox_tool._required_authz_tokens == mock_core_tool._required_authz_tokens + ) + assert ( + toolbox_tool._auth_service_token_getters + == mock_core_tool._auth_service_token_getters + ) + assert toolbox_tool._client_headers == mock_core_tool._client_headers + + def test_toolbox_tool_add_auth_tokens_deprecated( + self, auth_toolbox_tool, mock_core_sync_auth_tool + ): + """Tests the deprecated add_auth_tokens method.""" + auth_tokens = {"test-auth-source": lambda: "test-token"} + with pytest.warns(DeprecationWarning): + new_tool = auth_toolbox_tool.add_auth_tokens(auth_tokens) + + # Check that the call was correctly forwarded to the new method on the core tool + mock_core_sync_auth_tool.add_auth_token_getters.assert_called_once_with( + auth_tokens + ) + assert isinstance(new_tool, ToolboxTool) + + def test_toolbox_tool_add_auth_token_deprecated( + self, auth_toolbox_tool, mock_core_sync_auth_tool + ): + """Tests the deprecated add_auth_token method.""" + get_id_token = lambda: "test-token" + with pytest.warns(DeprecationWarning): + new_tool = auth_toolbox_tool.add_auth_token( + "test-auth-source", get_id_token + ) + + # Check that the call was correctly forwarded to the new method on the core tool + mock_core_sync_auth_tool.add_auth_token_getters.assert_called_once_with( + {"test-auth-source": get_id_token} + ) + assert isinstance(new_tool, ToolboxTool)