Skip to content

fix(toolbox-llamaindex): Align ToolboxTool properties for consistency and debuggability #293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: anubhav-state-lc
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion packages/toolbox-llamaindex/src/toolbox_llamaindex/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from asyncio import to_thread
from typing import Any, Callable, Union
from typing import Any, Awaitable, Callable, Mapping, Sequence, Union

from deprecated import deprecated
from llama_index.core.tools import ToolMetadata
Expand Down Expand Up @@ -57,6 +57,32 @@ def metadata(self) -> ToolMetadata:
),
)

@property
def _bound_params(
self,
) -> Mapping[str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]]:
return self.__core_tool._bound_params

@property
def _required_authn_params(self) -> Mapping[str, list[str]]:
return self.__core_tool._required_authn_params

@property
def _required_authz_tokens(self) -> Sequence[str]:
return self.__core_tool._required_authz_tokens

@property
def _auth_service_token_getters(
self,
) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]]:
return self.__core_tool._auth_service_token_getters

@property
def _client_headers(
self,
) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]:
return self.__core_tool._client_headers

def call(self, **kwargs: Any) -> ToolOutput: # type: ignore
output_content = self.__core_tool(**kwargs)
return ToolOutput(
Expand Down
57 changes: 57 additions & 0 deletions packages/toolbox-llamaindex/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ def mock_core_tool(self, tool_schema_dict):
)
sync_mock.bind_params = Mock(return_value=new_mock_instance_for_methods)

sync_mock._bound_params = {"mock_bound_param": "mock_bound_value"}
sync_mock._required_authn_params = {"mock_auth_source": ["mock_param"]}
sync_mock._required_authz_tokens = ["mock_authz_token"]
sync_mock._auth_service_token_getters = {"mock_service": lambda: "mock_token"}
sync_mock._client_headers = {"mock_header": "mock_header_value"}

return sync_mock

@pytest.fixture
Expand Down Expand Up @@ -168,6 +174,12 @@ def mock_core_sync_auth_tool(self, auth_tool_schema_dict):
return_value=new_mock_instance_for_methods
)
sync_mock.bind_params = Mock(return_value=new_mock_instance_for_methods)
sync_mock._bound_params = {"mock_bound_param": "mock_bound_value"}
sync_mock._required_authn_params = {"mock_auth_source": ["mock_param"]}
sync_mock._required_authz_tokens = ["mock_authz_token"]
sync_mock._auth_service_token_getters = {"mock_service": lambda: "mock_token"}
sync_mock._client_headers = {"mock_header": "mock_header_value"}

return sync_mock

@pytest.fixture
Expand Down Expand Up @@ -317,3 +329,48 @@ async def to_thread_side_effect(func, *args, **kwargs_for_func):

assert mock_core_tool.call_count == 1
assert mock_core_tool.call_args == call(**kwargs_to_run)

def test_toolbox_tool_properties(self, toolbox_tool, mock_core_tool):
"""Tests that the properties correctly proxy to the core tool."""
assert toolbox_tool._bound_params == mock_core_tool._bound_params
assert (
toolbox_tool._required_authn_params == mock_core_tool._required_authn_params
)
assert (
toolbox_tool._required_authz_tokens == mock_core_tool._required_authz_tokens
)
assert (
toolbox_tool._auth_service_token_getters
== mock_core_tool._auth_service_token_getters
)
assert toolbox_tool._client_headers == mock_core_tool._client_headers

def test_toolbox_tool_add_auth_tokens_deprecated(
self, auth_toolbox_tool, mock_core_sync_auth_tool
):
"""Tests the deprecated add_auth_tokens method."""
auth_tokens = {"test-auth-source": lambda: "test-token"}
with pytest.warns(DeprecationWarning):
new_tool = auth_toolbox_tool.add_auth_tokens(auth_tokens)

# Check that the call was correctly forwarded to the new method on the core tool
mock_core_sync_auth_tool.add_auth_token_getters.assert_called_once_with(
auth_tokens
)
assert isinstance(new_tool, ToolboxTool)

def test_toolbox_tool_add_auth_token_deprecated(
self, auth_toolbox_tool, mock_core_sync_auth_tool
):
"""Tests the deprecated add_auth_token method."""
get_id_token = lambda: "test-token"
with pytest.warns(DeprecationWarning):
new_tool = auth_toolbox_tool.add_auth_token(
"test-auth-source", get_id_token
)

# Check that the call was correctly forwarded to the new method on the core tool
mock_core_sync_auth_tool.add_auth_token_getters.assert_called_once_with(
{"test-auth-source": get_id_token}
)
assert isinstance(new_tool, ToolboxTool)