diff --git a/packages/toolbox-llamaindex/README.md b/packages/toolbox-llamaindex/README.md index 06a09192..1d6a2621 100644 --- a/packages/toolbox-llamaindex/README.md +++ b/packages/toolbox-llamaindex/README.md @@ -9,29 +9,27 @@ applications, enabling advanced orchestration and interaction with GenAI models. ## Table of Contents -- [MCP Toolbox LlamaIndex SDK](#mcp-toolbox-llamaindex-sdk) - - [Installation](#installation) - - [Quickstart](#quickstart) -- [TODO: add link](#todo-add-link) - - [Usage](#usage) - - [Loading Tools](#loading-tools) - - [Load a toolset](#load-a-toolset) - - [Load a single tool](#load-a-single-tool) - - [Use with LlamaIndex](#use-with-llamaindex) - - [Maintain state](#maintain-state) - - [Manual usage](#manual-usage) - - [Authenticating Tools](#authenticating-tools) - - [Supported Authentication Mechanisms](#supported-authentication-mechanisms) - - [Configure Tools](#configure-tools) - - [Configure SDK](#configure-sdk) - - [Add Authentication to a Tool](#add-authentication-to-a-tool) - - [Add Authentication While Loading](#add-authentication-while-loading) - - [Complete Example](#complete-example) - - [Binding Parameter Values](#binding-parameter-values) - - [Binding Parameters to a Tool](#binding-parameters-to-a-tool) - - [Binding Parameters While Loading](#binding-parameters-while-loading) - - [Binding Dynamic Values](#binding-dynamic-values) - - [Asynchronous Usage](#asynchronous-usage) +- [Installation](#installation) +- [Quickstart](#quickstart) +- [Usage](#usage) +- [Loading Tools](#loading-tools) + - [Load a toolset](#load-a-toolset) + - [Load a single tool](#load-a-single-tool) +- [Use with LlamaIndex](#use-with-llamaindex) + - [Maintain state](#maintain-state) +- [Manual usage](#manual-usage) +- [Authenticating Tools](#authenticating-tools) + - [Supported Authentication Mechanisms](#supported-authentication-mechanisms) + - [Configure Tools](#configure-tools) + - [Configure SDK](#configure-sdk) + - [Add Authentication to a Tool](#add-authentication-to-a-tool) + - [Add Authentication While Loading](#add-authentication-while-loading) + - [Complete Example](#complete-example) +- [Binding Parameter Values](#binding-parameter-values) + - [Binding Parameters to a Tool](#binding-parameters-to-a-tool) + - [Binding Parameters While Loading](#binding-parameters-while-loading) + - [Binding Dynamic Values](#binding-dynamic-values) +- [Asynchronous Usage](#asynchronous-usage) @@ -44,8 +42,7 @@ pip install toolbox-llamaindex ## Quickstart Here's a minimal example to get you started using -# TODO: add link -[LlamaIndex](): +[LlamaIndex](https://docs.llamaindex.ai/en/stable/#getting-started): ```py import asyncio @@ -111,7 +108,7 @@ available to your LLM agent. ## Use with LlamaIndex -LangChain's agents can dynamically choose and execute tools based on the user +LlamaIndex's agents can dynamically choose and execute tools based on the user input. Include tools loaded from the Toolbox SDK in the agent's toolkit: ```py @@ -165,7 +162,7 @@ print(response) Execute a tool manually using the `call` method: ```py -result = tools[0].call({"name": "Alice", "age": 30}) +result = tools[0].call(name="Alice", age=30) ``` This is useful for testing tools or when you need precise control over tool @@ -210,21 +207,21 @@ async def get_auth_token(): toolbox = ToolboxClient("http://127.0.0.1:5000") tools = toolbox.load_toolset() -auth_tool = tools[0].add_auth_token("my_auth", get_auth_token) # Single token +auth_tool = tools[0].add_auth_token_getter("my_auth", get_auth_token) # Single token -multi_auth_tool = tools[0].add_auth_tokens({"my_auth", get_auth_token}) # Multiple tokens +multi_auth_tool = tools[0].add_auth_token_getters({"auth_1": get_auth_1}, {"auth_2": get_auth_2}) # Multiple tokens # OR -auth_tools = [tool.add_auth_token("my_auth", get_auth_token) for tool in tools] +auth_tools = [tool.add_auth_token_getter("my_auth", get_auth_token) for tool in tools] ``` #### Add Authentication While Loading ```py -auth_tool = toolbox.load_tool(auth_tokens={"my_auth": get_auth_token}) +auth_tool = toolbox.load_tool(auth_token_getters={"my_auth": get_auth_token}) -auth_tools = toolbox.load_toolset(auth_tokens={"my_auth": get_auth_token}) +auth_tools = toolbox.load_toolset(auth_token_getters={"my_auth": get_auth_token}) ``` > [!NOTE] @@ -245,8 +242,8 @@ async def get_auth_token(): toolbox = ToolboxClient("http://127.0.0.1:5000") tool = toolbox.load_tool("my-tool") -auth_tool = tool.add_auth_token("my_auth", get_auth_token) -result = auth_tool.call({"input": "some input"}) +auth_tool = tool.add_auth_token_getter("my_auth", get_auth_token) +result = auth_tool.call(input="some input") print(result) ``` diff --git a/packages/toolbox-llamaindex/integration.cloudbuild.yaml b/packages/toolbox-llamaindex/integration.cloudbuild.yaml index 5ed8277d..1eee540a 100644 --- a/packages/toolbox-llamaindex/integration.cloudbuild.yaml +++ b/packages/toolbox-llamaindex/integration.cloudbuild.yaml @@ -15,10 +15,11 @@ steps: - id: Install library requirements name: 'python:${_VERSION}' + dir: 'packages/toolbox-llamaindex' args: - install - '-r' - - 'packages/toolbox-llamaindex/requirements.txt' + - 'requirements.txt' - '--user' entrypoint: pip - id: Install test requirements diff --git a/packages/toolbox-llamaindex/pyproject.toml b/packages/toolbox-llamaindex/pyproject.toml index 74db5fa9..a706eab1 100644 --- a/packages/toolbox-llamaindex/pyproject.toml +++ b/packages/toolbox-llamaindex/pyproject.toml @@ -1,5 +1,5 @@ [project] -name = "toolbox-llamindex" +name = "toolbox-llamaindex" dynamic = ["version"] readme = "README.md" description = "Python SDK for interacting with the Toolbox service with LlamaIndex" @@ -9,6 +9,8 @@ authors = [ {name = "Google LLC", email = "googleapis-packages@google.com"} ] dependencies = [ + # TODO: Bump toolbox-core version to 0.2.0 + "toolbox-core==0.1.0", "llama-index>=0.12.0,<1.0.0", "PyYAML>=6.0.1,<7.0.0", "pydantic>=2.8.0,<3.0.0", diff --git a/packages/toolbox-llamaindex/requirements.txt b/packages/toolbox-llamaindex/requirements.txt index be1bc672..2825b8ad 100644 --- a/packages/toolbox-llamaindex/requirements.txt +++ b/packages/toolbox-llamaindex/requirements.txt @@ -1,3 +1,4 @@ +-e ../toolbox-core llama-index==0.12.33 PyYAML==6.0.2 pydantic==2.11.4 diff --git a/packages/toolbox-llamaindex/src/toolbox_llamaindex/async_client.py b/packages/toolbox-llamaindex/src/toolbox_llamaindex/async_client.py index b65c8ccf..95e384c8 100644 --- a/packages/toolbox-llamaindex/src/toolbox_llamaindex/async_client.py +++ b/packages/toolbox-llamaindex/src/toolbox_llamaindex/async_client.py @@ -16,9 +16,9 @@ from warnings import warn from aiohttp import ClientSession +from toolbox_core.client import ToolboxClient as ToolboxCoreClient -from .tools import AsyncToolboxTool -from .utils import ManifestSchema, _load_manifest +from .async_tools import AsyncToolboxTool # This class is an internal implementation detail and is not exposed to the @@ -38,67 +38,72 @@ def __init__( url: The base URL of the Toolbox service. session: An HTTP client session. """ - self.__url = url - self.__session = session + self.__core_client = ToolboxCoreClient(url=url, session=session) async def aload_tool( self, tool_name: str, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, ) -> AsyncToolboxTool: """ Loads the tool with the given tool name from the Toolbox service. Args: tool_name: The name of the tool to load. - auth_tokens: An optional mapping of authentication source names to - functions that retrieve ID tokens. - auth_headers: Deprecated. Use `auth_tokens` instead. + auth_token_getters: An optional mapping of authentication source + names to functions that retrieve ID tokens. + auth_tokens: Deprecated. Use `auth_token_getters` instead. + auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. Returns: A tool loaded from the Toolbox. """ + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + if auth_headers: - if auth_tokens: + if auth_token_getters: warn( - "Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.", + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", DeprecationWarning, ) else: warn( - "Argument `auth_headers` is deprecated. Use `auth_tokens` instead.", + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", DeprecationWarning, ) - auth_tokens = auth_headers - - url = f"{self.__url}/api/tool/{tool_name}" - manifest: ManifestSchema = await _load_manifest(url, self.__session) - - return AsyncToolboxTool( - tool_name, - manifest.tools[tool_name], - self.__url, - self.__session, - auth_tokens, - bound_params, - strict, + auth_token_getters = auth_headers + + core_tool = await self.__core_client.load_tool( + name=tool_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, ) + return AsyncToolboxTool(core_tool=core_tool) async def aload_toolset( self, toolset_name: Optional[str] = None, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + strict: bool = False, ) -> list[AsyncToolboxTool]: """ Loads tools from the Toolbox service, optionally filtered by toolset @@ -107,65 +112,76 @@ async def aload_toolset( Args: toolset_name: The name of the toolset to load. If not provided, all tools are loaded. - auth_tokens: An optional mapping of authentication source names to - functions that retrieve ID tokens. - auth_headers: Deprecated. Use `auth_tokens` instead. + auth_token_getters: An optional mapping of authentication source + names to functions that retrieve ID tokens. + auth_tokens: Deprecated. Use `auth_token_getters` instead. + auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. + strict: If True, raises an error if *any* loaded tool instance fails + to utilize at least one provided parameter or auth token (if any + provided). If False (default), raises an error only if a + user-provided parameter or auth token cannot be applied to *any* + loaded tool across the set. Returns: A list of all tools loaded from the Toolbox. """ - if auth_headers: - if auth_tokens: + if auth_tokens: + if auth_token_getters: warn( - "Both `auth_tokens` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_tokens` will be used.", + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", DeprecationWarning, ) else: warn( - "Argument `auth_headers` is deprecated. Use `auth_tokens` instead.", + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", DeprecationWarning, ) - auth_tokens = auth_headers - - url = f"{self.__url}/api/toolset/{toolset_name or ''}" - manifest: ManifestSchema = await _load_manifest(url, self.__session) - tools: list[AsyncToolboxTool] = [] - - for tool_name, tool_schema in manifest.tools.items(): - tools.append( - AsyncToolboxTool( - tool_name, - tool_schema, - self.__url, - self.__session, - auth_tokens, - bound_params, - strict, + auth_token_getters = auth_tokens + + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, ) - ) + auth_token_getters = auth_headers + + core_tools = await self.__core_client.load_toolset( + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=strict, + ) + + tools = [] + for core_tool in core_tools: + tools.append(AsyncToolboxTool(core_tool=core_tool)) return tools def load_tool( self, tool_name: str, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, ) -> AsyncToolboxTool: raise NotImplementedError("Synchronous methods not supported by async client.") def load_toolset( self, toolset_name: Optional[str] = None, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + strict: bool = False, ) -> list[AsyncToolboxTool]: raise NotImplementedError("Synchronous methods not supported by async client.") diff --git a/packages/toolbox-llamaindex/src/toolbox_llamaindex/async_tools.py b/packages/toolbox-llamaindex/src/toolbox_llamaindex/async_tools.py index 879df74e..6db8ae21 100644 --- a/packages/toolbox-llamaindex/src/toolbox_llamaindex/async_tools.py +++ b/packages/toolbox-llamaindex/src/toolbox_llamaindex/async_tools.py @@ -12,23 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy -from typing import Any, Callable, TypeVar, Union -from warnings import warn +from typing import Any, Callable, Union -from aiohttp import ClientResponseError, ClientSession +from deprecated import deprecated from llama_index.core.tools import ToolMetadata from llama_index.core.tools.types import AsyncBaseTool, ToolOutput - -from .utils import ( - ToolSchema, - _find_auth_params, - _find_bound_params, - _invoke_tool, - _schema_to_model, -) - -T = TypeVar("T") +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool +from toolbox_core.utils import params_to_pydantic_model # This class is an internal implementation detail and is not exposed to the @@ -42,107 +32,30 @@ class AsyncToolboxTool(AsyncBaseTool): def __init__( self, - name: str, - schema: ToolSchema, - url: str, - session: ClientSession, - auth_tokens: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + core_tool: ToolboxCoreTool, ) -> None: """ Initializes an AsyncToolboxTool instance. Args: - name: The name of the tool. - schema: The tool schema. - url: The base URL of the Toolbox service. - session: The HTTP client session. - auth_tokens: A mapping of authentication source names to functions - that retrieve ID tokens. - bound_params: A mapping of parameter names to their bound - values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. + core_tool: The underlying core async ToolboxTool instance. """ - # If the schema is not already a ToolSchema instance, we create one from - # its attributes. This allows flexibility in how the schema is provided, - # accepting both a ToolSchema object and a dictionary of schema - # attributes. - if not isinstance(schema, ToolSchema): - schema = ToolSchema(**schema) - - auth_params, non_auth_params = _find_auth_params(schema.parameters) - non_auth_bound_params, non_auth_non_bound_params = _find_bound_params( - non_auth_params, list(bound_params) - ) - - # Check if the user is trying to bind a param that is authenticated or - # is missing from the given schema. - auth_bound_params: list[str] = [] - missing_bound_params: list[str] = [] - for bound_param in bound_params: - if bound_param in [param.name for param in auth_params]: - auth_bound_params.append(bound_param) - elif bound_param not in [param.name for param in non_auth_params]: - missing_bound_params.append(bound_param) - - # Create error messages for any params that are found to be - # authenticated or missing. - messages: list[str] = [] - if auth_bound_params: - messages.append( - f"Parameter(s) {', '.join(auth_bound_params)} already authenticated and cannot be bound." - ) - if missing_bound_params: - messages.append( - f"Parameter(s) {', '.join(missing_bound_params)} missing and cannot be bound." - ) - - # Join any error messages and raise them as an error or warning, - # depending on the value of the strict flag. - if messages: - message = "\n\n".join(messages) - if strict: - raise ValueError(message) - warn(message) - - # Bind values for parameters present in the schema that don't require - # authentication. - bound_params = { - param_name: param_value - for param_name, param_value in bound_params.items() - if param_name in [param.name for param in non_auth_bound_params] - } - - # Update the tools schema to validate only the presence of parameters - # that neither require authentication nor are bound. - schema.parameters = non_auth_non_bound_params - # Due to how pydantic works, we must initialize the underlying # AsyncBaseTool class before assigning values to member variables. super().__init__() - self.__name = name - self.__schema = schema - self.__url = url - self.__session = session - self.__auth_tokens = auth_tokens - self.__auth_params = auth_params - self.__bound_params = bound_params - - # Warn users about any missing authentication so they can add it before - # tool invocation. - self.__validate_auth(strict=False) + self.__core_tool = core_tool @property def metadata(self) -> ToolMetadata: + if self.__core_tool.__doc__ is None: + raise ValueError("No description found for the tool.") + return ToolMetadata( - name=self.__name, - description=self.__schema.description, - fn_schema=_schema_to_model( - model_name=self.__name, schema=self.__schema.parameters + name=self.__core_tool.__name__, + description=self.__core_tool.__doc__, + fn_schema=params_to_pydantic_model( + self.__core_tool._name, self.__core_tool._params ), ) @@ -154,182 +67,45 @@ async def acall(self, **kwargs: Any) -> ToolOutput: # type: ignore The coroutine that invokes the tool with the given arguments. Args: - kwargs: The arguments to the tool. + **kwargs: The arguments to the tool. Returns: A dictionary containing the parsed JSON response from the tool invocation. """ - # Validate arguments with the schema - if self.metadata.fn_schema: - self.metadata.fn_schema.model_validate(kwargs) - - # If the tool had parameters that require authentication, then right - # before invoking that tool, we check whether all these required - # authentication sources have been registered or not. - self.__validate_auth() - - # Evaluate dynamic parameter values if any - evaluated_params = {} - for param_name, param_value in self.__bound_params.items(): - if callable(param_value): - evaluated_params[param_name] = param_value() - else: - evaluated_params[param_name] = param_value - - # Merge bound parameters with the provided arguments - kwargs.update(evaluated_params) - try: - response = await _invoke_tool( - self.__url, self.__session, self.__name, kwargs, self.__auth_tokens - ) - return ToolOutput( - content=str(response), - tool_name=self.__name, - raw_input=kwargs, - raw_output=response, - is_error=False, - ) - except ClientResponseError as e: - return ToolOutput( - content="Encountered error: " + str(e), - tool_name=self.__name, - raw_input=kwargs, - raw_output=str(e), - is_error=True, - ) - - def __validate_auth(self, strict: bool = True) -> None: - """ - Checks if a tool meets the authentication requirements. - - A tool is considered authenticated if all of its parameters meet at - least one of the following conditions: - - * The parameter has at least one registered authentication source. - * The parameter requires no authentication. - - Args: - strict: If True, raises a PermissionError if any required - authentication sources are not registered. If False, only issues - a warning. - - Raises: - PermissionError: If strict is True and any required authentication - sources are not registered. - """ - params_missing_auth: list[str] = [] - - # Check each parameter for at least 1 required auth source - for param in self.__auth_params: - if not param.authSources: - raise ValueError("Auth sources cannot be None.") - has_auth = False - for src in param.authSources: - - # Find first auth source that is specified - if src in self.__auth_tokens: - has_auth = True - break - if not has_auth: - params_missing_auth.append(param.name) - - if params_missing_auth: - message = f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication sources are registered. Please register the required sources before use." - - if strict: - raise PermissionError(message) - warn(message) - - def __create_copy( - self, - *, - auth_tokens: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool, - ) -> "AsyncToolboxTool": - """ - Creates a copy of the current AsyncToolboxTool instance, allowing for - modification of auth tokens and bound params. - - This method enables the creation of new tool instances with inherited - properties from the current instance, while optionally updating the auth - tokens and bound params. This is useful for creating variations of the - tool with additional auth tokens or bound params without modifying the - original instance, ensuring immutability. - - Args: - auth_tokens: A dictionary of auth source names to functions that - retrieve ID tokens. These tokens will be merged with the - existing auth tokens. - bound_params: A dictionary of parameter names to their - bound values or functions to retrieve the values. These params - will be merged with the existing bound params. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. - - Returns: - A new AsyncToolboxTool instance that is a deep copy of the current - instance, with added auth tokens or bound params. - """ - new_schema = deepcopy(self.__schema) - - # Reconstruct the complete parameter schema by merging the auth - # parameters back with the non-auth parameters. This is necessary to - # accurately validate the new combination of auth tokens and bound - # params in the constructor of the new AsyncToolboxTool instance, ensuring - # that any overlaps or conflicts are correctly identified and reported - # as errors or warnings, depending on the given `strict` flag. - new_schema.parameters += self.__auth_params - return AsyncToolboxTool( - name=self.__name, - schema=new_schema, - url=self.__url, - session=self.__session, - auth_tokens={**self.__auth_tokens, **auth_tokens}, - bound_params={**self.__bound_params, **bound_params}, - strict=strict, + output_content = await self.__core_tool(**kwargs) + return ToolOutput( + content=output_content, + tool_name=self.__core_tool.__name__, + raw_input=kwargs, + raw_output=output_content, ) - def add_auth_tokens( - self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True + def add_auth_token_getters( + self, auth_token_getters: dict[str, Callable[[], str]] ) -> "AsyncToolboxTool": """ Registers functions to retrieve ID tokens for the corresponding authentication sources. Args: - auth_tokens: A dictionary of authentication source names to the - functions that return corresponding ID token. - strict: If True, a ValueError is raised if any of the provided auth - tokens are already bound. If False, only a warning is issued. + auth_token_getters: A dictionary of authentication source names to + the functions that return corresponding ID token getters. Returns: A new AsyncToolboxTool instance that is a deep copy of the current - instance, with added auth tokens. + instance, with added auth token getters. Raises: - ValueError: If the provided auth tokens are already registered. - ValueError: If the provided auth tokens are already bound and strict - is True. - """ - - # Check if the authentication source is already registered. - dupe_tokens: list[str] = [] - for auth_token, _ in auth_tokens.items(): - if auth_token in self.__auth_tokens: - dupe_tokens.append(auth_token) + ValueError: If any of the provided auth parameters is already + registered. - if dupe_tokens: - raise ValueError( - f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`." - ) - - return self.__create_copy(auth_tokens=auth_tokens, strict=strict) + """ + new_core_tool = self.__core_tool.add_auth_token_getters(auth_token_getters) + return AsyncToolboxTool(core_tool=new_core_tool) - def add_auth_token( - self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + def add_auth_token_getter( + self, auth_source: str, get_id_token: Callable[[], str] ) -> "AsyncToolboxTool": """ Registers a function to retrieve an ID token for a given authentication @@ -338,24 +114,32 @@ def add_auth_token( Args: auth_source: The name of the authentication source. get_id_token: A function that returns the ID token. - strict: If True, a ValueError is raised if any of the provided auth - token is already bound. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current - instance, with added auth token. + instance, with added auth token getter. Raises: - ValueError: If the provided auth token is already registered. - ValueError: If the provided auth token is already bound and strict - is True. + ValueError: If the provided auth parameter is already registered. + """ - return self.add_auth_tokens({auth_source: get_id_token}, strict=strict) + return self.add_auth_token_getters({auth_source: get_id_token}) + + @deprecated("Please use `add_auth_token_getters` instead.") + def add_auth_tokens( + self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True + ) -> "AsyncToolboxTool": + return self.add_auth_token_getters(auth_tokens) + + @deprecated("Please use `add_auth_token_getter` instead.") + def add_auth_token( + self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + ) -> "AsyncToolboxTool": + return self.add_auth_token_getter(auth_source, get_id_token) def bind_params( self, bound_params: dict[str, Union[Any, Callable[[], Any]]], - strict: bool = True, ) -> "AsyncToolboxTool": """ Registers values or functions to retrieve the value for the @@ -364,38 +148,21 @@ def bind_params( Args: bound_params: A dictionary of the bound parameter name to the value or function of the bound value. - strict: If True, a ValueError is raised if any of the provided bound - params are not defined in the tool's schema, or require - authentication. If False, only a warning is issued. Returns: A new AsyncToolboxTool instance that is a deep copy of the current instance, with added bound params. Raises: - ValueError: If the provided bound params are already bound. - ValueError: if the provided bound params are not defined in the tool's schema, or require - authentication, and strict is True. + ValueError: If any of the provided bound params is already bound. """ - - # Check if the parameter is already bound. - dupe_params: list[str] = [] - for param_name, _ in bound_params.items(): - if param_name in self.__bound_params: - dupe_params.append(param_name) - - if dupe_params: - raise ValueError( - f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self.__name}`." - ) - - return self.__create_copy(bound_params=bound_params, strict=strict) + new_core_tool = self.__core_tool.bind_params(bound_params) + return AsyncToolboxTool(core_tool=new_core_tool) def bind_param( self, param_name: str, param_value: Union[Any, Callable[[], Any]], - strict: bool = True, ) -> "AsyncToolboxTool": """ Registers a value or a function to retrieve the value for a given bound @@ -405,9 +172,6 @@ def bind_param( param_name: The name of the bound parameter. param_value: The value of the bound parameter, or a callable that returns the value. - strict: If True, a ValueError is raised if any of the provided bound - params is not defined in the tool's schema, or requires - authentication. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current @@ -415,7 +179,5 @@ def bind_param( Raises: ValueError: If the provided bound param is already bound. - ValueError: if the provided bound param is not defined in the tool's - schema, or requires authentication, and strict is True. """ - return self.bind_params({param_name: param_value}, strict) + return self.bind_params({param_name: param_value}) diff --git a/packages/toolbox-llamaindex/src/toolbox_llamaindex/client.py b/packages/toolbox-llamaindex/src/toolbox_llamaindex/client.py index 5079beab..cede90e8 100644 --- a/packages/toolbox-llamaindex/src/toolbox_llamaindex/client.py +++ b/packages/toolbox-llamaindex/src/toolbox_llamaindex/client.py @@ -12,22 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -from threading import Thread -from typing import Any, Awaitable, Callable, Optional, TypeVar, Union +from asyncio import to_thread +from typing import Any, Callable, Optional, Union +from warnings import warn -from aiohttp import ClientSession +from toolbox_core.sync_client import ToolboxSyncClient as ToolboxCoreSyncClient +from toolbox_core.sync_tool import ToolboxSyncTool -from .async_client import AsyncToolboxClient from .tools import ToolboxTool -T = TypeVar("T") - class ToolboxClient: - __session: Optional[ClientSession] = None - __loop: Optional[asyncio.AbstractEventLoop] = None - __thread: Optional[Thread] = None def __init__( self, @@ -39,94 +34,73 @@ def __init__( Args: url: The base URL of the Toolbox service. """ - - # Running a loop in a background thread allows us to support async - # methods from non-async environments. - if ToolboxClient.__loop is None: - loop = asyncio.new_event_loop() - thread = Thread(target=loop.run_forever, daemon=True) - thread.start() - ToolboxClient.__thread = thread - ToolboxClient.__loop = loop - - async def __start_session() -> None: - - # Use a default session if none is provided. This leverages connection - # pooling for better performance by reusing a single session throughout - # the application's lifetime. - if ToolboxClient.__session is None: - ToolboxClient.__session = ClientSession() - - coro = __start_session() - - asyncio.run_coroutine_threadsafe(coro, ToolboxClient.__loop).result() - - if not ToolboxClient.__session: - raise ValueError("Session cannot be None.") - self.__async_client = AsyncToolboxClient(url, ToolboxClient.__session) - - def __run_as_sync(self, coro: Awaitable[T]) -> T: - """Run an async coroutine synchronously""" - if not self.__loop: - raise Exception( - "Cannot call synchronous methods before the background loop is initialized." - ) - return asyncio.run_coroutine_threadsafe(coro, self.__loop).result() - - async def __run_as_async(self, coro: Awaitable[T]) -> T: - """Run an async coroutine asynchronously""" - - # If a loop has not been provided, attempt to run in current thread. - if not self.__loop: - return await coro - - # Otherwise, run in the background thread. - return await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__loop) - ) + self.__core_client = ToolboxCoreSyncClient(url=url) async def aload_tool( self, tool_name: str, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, ) -> ToolboxTool: """ Loads the tool with the given tool name from the Toolbox service. Args: tool_name: The name of the tool to load. - auth_tokens: An optional mapping of authentication source names to - functions that retrieve ID tokens. - auth_headers: Deprecated. Use `auth_tokens` instead. + auth_token_getters: An optional mapping of authentication source + names to functions that retrieve ID tokens. + auth_tokens: Deprecated. Use `auth_token_getters` instead. + auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. Returns: A tool loaded from the Toolbox. """ - async_tool = await self.__run_as_async( - self.__async_client.aload_tool( - tool_name, auth_tokens, auth_headers, bound_params, strict - ) + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + core_tool = await to_thread( + self.__core_client.load_tool, + name=tool_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, ) - - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - return ToolboxTool(async_tool, self.__loop, self.__thread) + return ToolboxTool(core_tool=core_tool) async def aload_toolset( self, toolset_name: Optional[str] = None, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + strict: bool = False, ) -> list[ToolboxTool]: """ Loads tools from the Toolbox service, optionally filtered by toolset @@ -135,74 +109,124 @@ async def aload_toolset( Args: toolset_name: The name of the toolset to load. If not provided, all tools are loaded. - auth_tokens: An optional mapping of authentication source names to - functions that retrieve ID tokens. - auth_headers: Deprecated. Use `auth_tokens` instead. + auth_token_getters: An optional mapping of authentication source + names to functions that retrieve ID tokens. + auth_tokens: Deprecated. Use `auth_token_getters` instead. + auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. + strict: If True, raises an error if *any* loaded tool instance fails + to utilize at least one provided parameter or auth token (if any + provided). If False (default), raises an error only if a + user-provided parameter or auth token cannot be applied to *any* + loaded tool across the set. Returns: A list of all tools loaded from the Toolbox. """ - async_tools = await self.__run_as_async( - self.__async_client.aload_toolset( - toolset_name, auth_tokens, auth_headers, bound_params, strict - ) + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + core_tools = await to_thread( + self.__core_client.load_toolset, + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=strict, ) - tools: list[ToolboxTool] = [] - - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - for async_tool in async_tools: - tools.append(ToolboxTool(async_tool, self.__loop, self.__thread)) + tools = [] + for core_tool in core_tools: + tools.append(ToolboxTool(core_tool=core_tool)) return tools def load_tool( self, tool_name: str, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, ) -> ToolboxTool: """ Loads the tool with the given tool name from the Toolbox service. Args: tool_name: The name of the tool to load. - auth_tokens: An optional mapping of authentication source names to - functions that retrieve ID tokens. - auth_headers: Deprecated. Use `auth_tokens` instead. + auth_token_getters: An optional mapping of authentication source + names to functions that retrieve ID tokens. + auth_tokens: Deprecated. Use `auth_token_getters` instead. + auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. Returns: A tool loaded from the Toolbox. """ - async_tool = self.__run_as_sync( - self.__async_client.aload_tool( - tool_name, auth_tokens, auth_headers, bound_params, strict - ) + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + core_sync_tool = self.__core_client.load_tool( + name=tool_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, ) - - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - return ToolboxTool(async_tool, self.__loop, self.__thread) + return ToolboxTool(core_tool=core_sync_tool) def load_toolset( self, toolset_name: Optional[str] = None, - auth_tokens: dict[str, Callable[[], str]] = {}, + auth_token_getters: dict[str, Callable[[], str]] = {}, + auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + strict: bool = False, ) -> list[ToolboxTool]: """ Loads tools from the Toolbox service, optionally filtered by toolset @@ -211,27 +235,55 @@ def load_toolset( Args: toolset_name: The name of the toolset to load. If not provided, all tools are loaded. - auth_tokens: An optional mapping of authentication source names to - functions that retrieve ID tokens. - auth_headers: Deprecated. Use `auth_tokens` instead. + auth_token_getters: An optional mapping of authentication source + names to functions that retrieve ID tokens. + auth_tokens: Deprecated. Use `auth_token_getters` instead. + auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. + strict: If True, raises an error if *any* loaded tool instance fails + to utilize at least one provided parameter or auth token (if any + provided). If False (default), raises an error only if a + user-provided parameter or auth token cannot be applied to *any* + loaded tool across the set. Returns: A list of all tools loaded from the Toolbox. """ - async_tools = self.__run_as_sync( - self.__async_client.aload_toolset( - toolset_name, auth_tokens, auth_headers, bound_params, strict - ) + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + core_sync_tools = self.__core_client.load_toolset( + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=strict, ) - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - tools: list[ToolboxTool] = [] - for async_tool in async_tools: - tools.append(ToolboxTool(async_tool, self.__loop, self.__thread)) + tools = [] + for core_sync_tool in core_sync_tools: + tools.append(ToolboxTool(core_tool=core_sync_tool)) return tools diff --git a/packages/toolbox-llamaindex/src/toolbox_llamaindex/py.typed b/packages/toolbox-llamaindex/src/toolbox_llamaindex/py.typed index 8b137891..e69de29b 100644 --- a/packages/toolbox-llamaindex/src/toolbox_llamaindex/py.typed +++ b/packages/toolbox-llamaindex/src/toolbox_llamaindex/py.typed @@ -1 +0,0 @@ - diff --git a/packages/toolbox-llamaindex/src/toolbox_llamaindex/tools.py b/packages/toolbox-llamaindex/src/toolbox_llamaindex/tools.py index 00690dca..841e9427 100644 --- a/packages/toolbox-llamaindex/src/toolbox_llamaindex/tools.py +++ b/packages/toolbox-llamaindex/src/toolbox_llamaindex/tools.py @@ -12,17 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -from asyncio import AbstractEventLoop -from threading import Thread -from typing import Any, Awaitable, Callable, TypeVar, Union +from asyncio import to_thread +from typing import Any, Callable, Union +from deprecated import deprecated from llama_index.core.tools import ToolMetadata from llama_index.core.tools.types import AsyncBaseTool, ToolOutput - -from .async_tools import AsyncToolboxTool - -T = TypeVar("T") +from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool +from toolbox_core.utils import params_to_pydantic_model class ToolboxTool(AsyncBaseTool): @@ -33,92 +30,75 @@ class ToolboxTool(AsyncBaseTool): def __init__( self, - async_tool: AsyncToolboxTool, - loop: AbstractEventLoop, - thread: Thread, + core_tool: ToolboxCoreSyncTool, ) -> None: """ Initializes a ToolboxTool instance. Args: - async_tool: The underlying AsyncToolboxTool instance. - loop: The event loop used to run asynchronous tasks. - thread: The thread to run blocking operations in. + core_tool: The underlying core sync ToolboxTool instance. """ - # Due to how pydantic works, we must initialize the underlying # AsyncBaseTool class before assigning values to member variables. super().__init__() - self.__async_tool = async_tool - self.__loop = loop - self.__thread = thread - - def __run_as_sync(self, coro: Awaitable[T]) -> T: - """Run an async coroutine synchronously""" - if not self.__loop: - raise Exception( - "Cannot call synchronous methods before the background loop is initialized." - ) - return asyncio.run_coroutine_threadsafe(coro, self.__loop).result() - - async def __run_as_async(self, coro: Awaitable[T]) -> T: - """Run an async coroutine asynchronously""" - - # If a loop has not been provided, attempt to run in current thread. - if not self.__loop: - return await coro - - # Otherwise, run in the background thread. - return await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__loop) - ) + self.__core_tool = core_tool @property def metadata(self) -> ToolMetadata: - async_tool = self.__async_tool + if self.__core_tool.__doc__ is None: + raise ValueError("No description found for the tool.") + return ToolMetadata( - name=async_tool.metadata.name, - description=async_tool.metadata.description, - fn_schema=async_tool.metadata.fn_schema, + name=self.__core_tool.__name__, + description=self.__core_tool.__doc__, + fn_schema=params_to_pydantic_model( + self.__core_tool._name, self.__core_tool._params + ), ) def call(self, **kwargs: Any) -> ToolOutput: # type: ignore - return self.__run_as_sync(self.__async_tool.acall(**kwargs)) + output_content = self.__core_tool(**kwargs) + return ToolOutput( + content=output_content, + tool_name=self.__core_tool.__name__, + raw_input=kwargs, + raw_output=output_content, + ) async def acall(self, **kwargs: Any) -> ToolOutput: # type: ignore - return await self.__run_as_async(self.__async_tool.acall(**kwargs)) + output_content = await to_thread(self.__core_tool, **kwargs) + return ToolOutput( + content=output_content, + tool_name=self.__core_tool.__name__, + raw_input=kwargs, + raw_output=output_content, + ) - def add_auth_tokens( - self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True + def add_auth_token_getters( + self, auth_token_getters: dict[str, Callable[[], str]] ) -> "ToolboxTool": """ Registers functions to retrieve ID tokens for the corresponding authentication sources. Args: - auth_tokens: A dictionary of authentication source names to the - functions that return corresponding ID token. - strict: If True, a ValueError is raised if any of the provided auth - tokens are already bound. If False, only a warning is issued. + auth_token_getters: A dictionary of authentication source names to + the functions that return corresponding ID token. Returns: A new ToolboxTool instance that is a deep copy of the current - instance, with added auth tokens. + instance, with added auth token getters. Raises: - ValueError: If the provided auth tokens are already registered. - ValueError: If the provided auth tokens are already bound and strict - is True. + ValueError: If any of the provided auth parameters is already + registered. """ - return ToolboxTool( - self.__async_tool.add_auth_tokens(auth_tokens, strict), - self.__loop, - self.__thread, - ) + new_core_tool = self.__core_tool.add_auth_token_getters(auth_token_getters) + return ToolboxTool(core_tool=new_core_tool) - def add_auth_token( - self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + def add_auth_token_getter( + self, auth_source: str, get_id_token: Callable[[], str] ) -> "ToolboxTool": """ Registers a function to retrieve an ID token for a given authentication @@ -127,28 +107,31 @@ def add_auth_token( Args: auth_source: The name of the authentication source. get_id_token: A function that returns the ID token. - strict: If True, a ValueError is raised if any of the provided auth - token is already bound. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current - instance, with added auth token. + instance, with added auth token getter. Raises: - ValueError: If the provided auth token is already registered. - ValueError: If the provided auth token is already bound and strict - is True. + ValueError: If the provided auth parameter is already registered. """ - return ToolboxTool( - self.__async_tool.add_auth_token(auth_source, get_id_token, strict), - self.__loop, - self.__thread, - ) + return self.add_auth_token_getters({auth_source: get_id_token}) + + @deprecated("Please use `add_auth_token_getters` instead.") + def add_auth_tokens( + self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True + ) -> "ToolboxTool": + return self.add_auth_token_getters(auth_tokens) + + @deprecated("Please use `add_auth_token_getter` instead.") + def add_auth_token( + self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + ) -> "ToolboxTool": + return self.add_auth_token_getter(auth_source, get_id_token) def bind_params( self, bound_params: dict[str, Union[Any, Callable[[], Any]]], - strict: bool = True, ) -> "ToolboxTool": """ Registers values or functions to retrieve the value for the @@ -157,30 +140,21 @@ def bind_params( Args: bound_params: A dictionary of the bound parameter name to the value or function of the bound value. - strict: If True, a ValueError is raised if any of the provided bound - params are not defined in the tool's schema, or require - authentication. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current instance, with added bound params. Raises: - ValueError: If the provided bound params are already bound. - ValueError: if the provided bound params are not defined in the tool's schema, or require - authentication, and strict is True. + ValueError: If any of the provided bound params is already bound. """ - return ToolboxTool( - self.__async_tool.bind_params(bound_params, strict), - self.__loop, - self.__thread, - ) + new_core_tool = self.__core_tool.bind_params(bound_params) + return ToolboxTool(core_tool=new_core_tool) def bind_param( self, param_name: str, param_value: Union[Any, Callable[[], Any]], - strict: bool = True, ) -> "ToolboxTool": """ Registers a value or a function to retrieve the value for a given bound @@ -190,9 +164,6 @@ def bind_param( param_name: The name of the bound parameter. param_value: The value of the bound parameter, or a callable that returns the value. - strict: If True, a ValueError is raised if any of the provided bound - params is not defined in the tool's schema, or requires - authentication. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current @@ -200,11 +171,5 @@ def bind_param( Raises: ValueError: If the provided bound param is already bound. - ValueError: if the provided bound param is not defined in the tool's - schema, or requires authentication, and strict is True. """ - return ToolboxTool( - self.__async_tool.bind_param(param_name, param_value, strict), - self.__loop, - self.__thread, - ) + return self.bind_params({param_name: param_value}) diff --git a/packages/toolbox-llamaindex/src/toolbox_llamaindex/utils.py b/packages/toolbox-llamaindex/src/toolbox_llamaindex/utils.py deleted file mode 100644 index 54c55e30..00000000 --- a/packages/toolbox-llamaindex/src/toolbox_llamaindex/utils.py +++ /dev/null @@ -1,262 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from typing import Any, Callable, Optional, Type, cast -from warnings import warn - -from aiohttp import ClientSession -from deprecated import deprecated -from pydantic import BaseModel, Field, create_model - - -class ParameterSchema(BaseModel): - """ - Schema for a tool parameter. - """ - - name: str - type: str - description: str - authSources: Optional[list[str]] = None - items: Optional["ParameterSchema"] = None - - -class ToolSchema(BaseModel): - """ - Schema for a tool. - """ - - description: str - parameters: list[ParameterSchema] - - -class ManifestSchema(BaseModel): - """ - Schema for the Toolbox manifest. - """ - - serverVersion: str - tools: dict[str, ToolSchema] - - -async def _load_manifest(url: str, session: ClientSession) -> ManifestSchema: - """ - Asynchronously fetches and parses the JSON manifest schema from the given - URL. - - Args: - url: The URL to fetch the JSON from. - session: The HTTP client session. - - Returns: - The parsed Toolbox manifest. - - Raises: - json.JSONDecodeError: If the response is not valid JSON. - ValueError: If the response is not a valid manifest. - """ - async with session.get(url) as response: - # TODO: Remove as it masks error messages. - response.raise_for_status() - try: - # TODO: Simply use response.json() - parsed_json = json.loads(await response.text()) - except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Failed to parse JSON from {url}: {e}", e.doc, e.pos - ) from e - try: - return ManifestSchema(**parsed_json) - except ValueError as e: - raise ValueError(f"Invalid JSON data from {url}: {e}") from e - - -def _schema_to_model(model_name: str, schema: list[ParameterSchema]) -> Type[BaseModel]: - """ - Converts the given manifest schema to a Pydantic BaseModel class. - - Args: - model_name: The name of the model to create. - schema: The schema to convert. - - Returns: - A Pydantic BaseModel class. - """ - field_definitions = {} - for field in schema: - field_definitions[field.name] = cast( - Any, - ( - _parse_type(field), - Field(description=field.description), - ), - ) - - return create_model(model_name, **field_definitions) - - -def _parse_type(schema_: ParameterSchema) -> Any: - """ - Converts a schema type to a JSON type. - - Args: - schema_: The ParameterSchema to convert. - - Returns: - A valid JSON type. - - Raises: - ValueError: If the given type is not supported. - """ - type_ = schema_.type - - if type_ == "string": - return str - elif type_ == "integer": - return int - elif type_ == "float": - return float - elif type_ == "boolean": - return bool - elif type_ == "array": - if isinstance(schema_, ParameterSchema) and schema_.items: - return list[_parse_type(schema_.items)] # type: ignore - else: - raise ValueError(f"Schema missing field items") - else: - raise ValueError(f"Unsupported schema type: {type_}") - - -@deprecated("Please use `_get_auth_tokens` instead.") -def _get_auth_headers(id_token_getters: dict[str, Callable[[], str]]) -> dict[str, str]: - """ - Deprecated. Use `_get_auth_tokens` instead. - """ - return _get_auth_tokens(id_token_getters) - - -def _get_auth_tokens(id_token_getters: dict[str, Callable[[], str]]) -> dict[str, str]: - """ - Gets ID tokens for the given auth sources in the getters map and returns - tokens to be included in tool invocation. - - Args: - id_token_getters: A dict that maps auth source names to the functions - that return its ID token. - - Returns: - A dictionary of tokens to be included in the tool invocation. - """ - auth_tokens = {} - for auth_source, get_id_token in id_token_getters.items(): - auth_tokens[f"{auth_source}_token"] = get_id_token() - return auth_tokens - - -async def _invoke_tool( - url: str, - session: ClientSession, - tool_name: str, - data: dict, - id_token_getters: dict[str, Callable[[], str]], -) -> dict: - """ - Asynchronously makes an API call to the Toolbox service to invoke a tool. - - Args: - url: The base URL of the Toolbox service. - session: The HTTP client session. - tool_name: The name of the tool to invoke. - data: The input data for the tool. - id_token_getters: A dict that maps auth source names to the functions - that return its ID token. - - Returns: - A dictionary containing the parsed JSON response from the tool - invocation. - """ - url = f"{url}/api/tool/{tool_name}/invoke" - auth_tokens = _get_auth_tokens(id_token_getters) - - # ID tokens contain sensitive user information (claims). Transmitting these - # over HTTP exposes the data to interception and unauthorized access. Always - # use HTTPS to ensure secure communication and protect user privacy. - if auth_tokens and not url.startswith("https://"): - warn( - "Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication." - ) - - async with session.post( - url, - json=data, - headers=auth_tokens, - ) as response: - # TODO: Remove as it masks error messages. - response.raise_for_status() - return await response.json() - - -def _find_auth_params( - params: list[ParameterSchema], -) -> tuple[list[ParameterSchema], list[ParameterSchema]]: - """ - Separates parameters into those that are authenticated and those that are not. - - Args: - params: A list of ParameterSchema objects. - - Returns: - A tuple containing two lists: - - auth_params: A list of ParameterSchema objects that require authentication. - - non_auth_params: A list of ParameterSchema objects that do not require authentication. - """ - _auth_params: list[ParameterSchema] = [] - _non_auth_params: list[ParameterSchema] = [] - - for param in params: - if param.authSources: - _auth_params.append(param) - else: - _non_auth_params.append(param) - - return (_auth_params, _non_auth_params) - - -def _find_bound_params( - params: list[ParameterSchema], bound_params: list[str] -) -> tuple[list[ParameterSchema], list[ParameterSchema]]: - """ - Separates parameters into those that are bound and those that are not. - - Args: - params: A list of ParameterSchema objects. - bound_params: A list of parameter names that are bound. - - Returns: - A tuple containing two lists: - - bound_params: A list of ParameterSchema objects whose names are in the bound_params list. - - non_bound_params: A list of ParameterSchema objects whose names are not in the bound_params list. - """ - - _bound_params: list[ParameterSchema] = [] - _non_bound_params: list[ParameterSchema] = [] - - for param in params: - if param.name in bound_params: - _bound_params.append(param) - else: - _non_bound_params.append(param) - - return (_bound_params, _non_bound_params) diff --git a/packages/toolbox-llamaindex/tests/test_async_client.py b/packages/toolbox-llamaindex/tests/test_async_client.py index cdfd2cbc..9139908d 100644 --- a/packages/toolbox-llamaindex/tests/test_async_client.py +++ b/packages/toolbox-llamaindex/tests/test_async_client.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio from unittest.mock import AsyncMock, patch from warnings import catch_warnings, simplefilter import pytest from aiohttp import ClientSession +from toolbox_core.client import ToolboxClient as ToolboxCoreClient +from toolbox_core.protocol import ParameterSchema as CoreParameterSchema +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool from toolbox_llamaindex.async_client import AsyncToolboxClient from toolbox_llamaindex.async_tools import AsyncToolboxTool -from toolbox_llamaindex.utils import ManifestSchema URL = "http://test_url" MANIFEST_JSON = { @@ -53,131 +54,277 @@ @pytest.mark.asyncio class TestAsyncToolboxClient: - @pytest.fixture() - def manifest_schema(self): - return ManifestSchema(**MANIFEST_JSON) - @pytest.fixture() def mock_session(self): return AsyncMock(spec=ClientSession) + @pytest.fixture + def mock_core_client_instance(self, mock_session): + mock = AsyncMock(spec=ToolboxCoreClient) + + async def mock_load_tool_impl(name, auth_token_getters, bound_params): + tool_schema_dict = MANIFEST_JSON["tools"].get(name) + if not tool_schema_dict: + raise ValueError(f"Tool '{name}' not in mock manifest_dict") + + core_params = [ + CoreParameterSchema(**p) for p in tool_schema_dict["parameters"] + ] + # Return a mock that looks like toolbox_core.tool.ToolboxTool + core_tool_mock = AsyncMock(spec=ToolboxCoreTool) + core_tool_mock.__name__ = name + core_tool_mock.__doc__ = tool_schema_dict["description"] + core_tool_mock._name = name + core_tool_mock._params = core_params + # Add other necessary attributes or method mocks if AsyncToolboxTool uses them + return core_tool_mock + + mock.load_tool = AsyncMock(side_effect=mock_load_tool_impl) + + async def mock_load_toolset_impl( + name, auth_token_getters, bound_params, strict + ): + core_tools_list = [] + for tool_name_iter, tool_schema_dict in MANIFEST_JSON["tools"].items(): + core_params = [ + CoreParameterSchema(**p) for p in tool_schema_dict["parameters"] + ] + core_tool_mock = AsyncMock(spec=ToolboxCoreTool) + core_tool_mock.__name__ = tool_name_iter + core_tool_mock.__doc__ = tool_schema_dict["description"] + core_tool_mock._name = tool_name_iter + core_tool_mock._params = core_params + core_tools_list.append(core_tool_mock) + return core_tools_list + + mock.load_toolset = AsyncMock(side_effect=mock_load_toolset_impl) + # Mock the session attribute if it's directly accessed by AsyncToolboxClient tests + mock._ToolboxClient__session = mock_session + return mock + @pytest.fixture() - def mock_client(self, mock_session): - return AsyncToolboxClient(URL, session=mock_session) + def mock_client(self, mock_session, mock_core_client_instance): + # Patch the ToolboxCoreClient constructor used by AsyncToolboxClient + with patch( + "toolbox_llamaindex.async_client.ToolboxCoreClient", + return_value=mock_core_client_instance, + ): + client = AsyncToolboxClient(URL, session=mock_session) + # Ensure the mocked core client is used + client._AsyncToolboxClient__core_client = mock_core_client_instance + return client async def test_create_with_existing_session(self, mock_client, mock_session): - assert mock_client._AsyncToolboxClient__session == mock_session + # AsyncToolboxClient stores the core_client, which stores the session + assert ( + mock_client._AsyncToolboxClient__core_client._ToolboxClient__session + == mock_session + ) - @patch("toolbox_llamaindex.async_client._load_manifest") async def test_aload_tool( - self, mock_load_manifest, mock_client, mock_session, manifest_schema + self, + mock_client, ): tool_name = "test_tool_1" - mock_load_manifest.return_value = manifest_schema + test_bound_params = {"bp1": "value1"} - tool = await mock_client.aload_tool(tool_name) + tool = await mock_client.aload_tool(tool_name, bound_params=test_bound_params) - mock_load_manifest.assert_called_once_with( - f"{URL}/api/tool/{tool_name}", mock_session + # Assert that the core client's load_tool was called correctly + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, auth_token_getters={}, bound_params=test_bound_params ) assert isinstance(tool, AsyncToolboxTool) - assert tool._AsyncToolboxTool__name == tool_name + assert ( + tool.metadata.name == tool_name + ) # AsyncToolboxTool gets its name from the core_tool - @patch("toolbox_llamaindex.async_client._load_manifest") - async def test_aload_tool_auth_headers_deprecated( - self, mock_load_manifest, mock_client, manifest_schema - ): + async def test_aload_tool_auth_headers_deprecated(self, mock_client): tool_name = "test_tool_1" - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + auth_lambda = lambda: "Bearer token" with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_tool( - tool_name, auth_headers={"Authorization": lambda: "Bearer token"} + tool_name, + auth_headers={"Authorization": auth_lambda}, ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) assert "auth_headers" in str(w[-1].message) + assert "Use `auth_token_getters` instead" in str(w[-1].message) - @patch("toolbox_llamaindex.async_client._load_manifest") - async def test_aload_tool_auth_headers_and_tokens( - self, mock_load_manifest, mock_client, manifest_schema - ): + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, + auth_token_getters={"Authorization": auth_lambda}, + bound_params={}, + ) + + async def test_aload_tool_auth_headers_and_getters_precedence(self, mock_client): tool_name = "test_tool_1" - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + auth_getters = {"test_source": lambda: "id_token_from_getters"} + auth_headers_lambda = lambda: "Bearer token_from_headers" + with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_tool( tool_name, - auth_headers={"Authorization": lambda: "Bearer token"}, - auth_tokens={"test": lambda: "token"}, + auth_headers={"Authorization": auth_headers_lambda}, + auth_token_getters=auth_getters, ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) assert "auth_headers" in str(w[-1].message) + assert "`auth_token_getters` will be used" in str(w[-1].message) - @patch("toolbox_llamaindex.async_client._load_manifest") - async def test_aload_toolset( - self, mock_load_manifest, mock_client, mock_session, manifest_schema - ): - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest - tools = await mock_client.aload_toolset() + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, auth_token_getters=auth_getters, bound_params={} + ) - mock_load_manifest.assert_called_once_with(f"{URL}/api/toolset/", mock_session) + async def test_aload_tool_auth_tokens_deprecated(self, mock_client): + tool_name = "test_tool_1" + token_lambda = lambda: "id_token" + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_tool( + tool_name, + auth_tokens={"some_token_key": token_lambda}, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_tokens" in str(w[-1].message) + assert "Use `auth_token_getters` instead" in str(w[-1].message) + + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, + auth_token_getters={"some_token_key": token_lambda}, + bound_params={}, + ) + + async def test_aload_tool_auth_tokens_and_getters_precedence(self, mock_client): + tool_name = "test_tool_1" + auth_getters = {"real_source": lambda: "token_from_getters"} + token_lambda = lambda: "token_from_auth_tokens" + + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_tool( + tool_name, + auth_tokens={"deprecated_source": token_lambda}, + auth_token_getters=auth_getters, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_tokens" in str(w[-1].message) + assert "`auth_token_getters` will be used" in str(w[-1].message) + + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, auth_token_getters=auth_getters, bound_params={} + ) + + async def test_aload_toolset(self, mock_client): + test_bound_params = {"bp_set": "value_set"} + tools = await mock_client.aload_toolset( + bound_params=test_bound_params, strict=True + ) + + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, + auth_token_getters={}, + bound_params=test_bound_params, + strict=True, + ) assert len(tools) == 2 for tool in tools: assert isinstance(tool, AsyncToolboxTool) - assert tool._AsyncToolboxTool__name in ["test_tool_1", "test_tool_2"] + assert tool.metadata.name in ["test_tool_1", "test_tool_2"] - @patch("toolbox_llamaindex.async_client._load_manifest") - async def test_aload_toolset_with_toolset_name( - self, mock_load_manifest, mock_client, mock_session, manifest_schema - ): + async def test_aload_toolset_with_toolset_name(self, mock_client): toolset_name = "test_toolset_1" - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest tools = await mock_client.aload_toolset(toolset_name=toolset_name) - mock_load_manifest.assert_called_once_with( - f"{URL}/api/toolset/{toolset_name}", mock_session + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=toolset_name, auth_token_getters={}, bound_params={}, strict=False ) assert len(tools) == 2 for tool in tools: assert isinstance(tool, AsyncToolboxTool) - assert tool._AsyncToolboxTool__name in ["test_tool_1", "test_tool_2"] - @patch("toolbox_llamaindex.async_client._load_manifest") - async def test_aload_toolset_auth_headers_deprecated( - self, mock_load_manifest, mock_client, manifest_schema + async def test_aload_toolset_auth_headers_deprecated(self, mock_client): + auth_lambda = lambda: "Bearer token" + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_toolset(auth_headers={"Authorization": auth_lambda}) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_headers" in str(w[-1].message) + assert "Use `auth_token_getters` instead" in str(w[-1].message) + + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, + auth_token_getters={"Authorization": auth_lambda}, + bound_params={}, + strict=False, + ) + + async def test_aload_toolset_auth_headers_and_getters_precedence( # Renamed for clarity + self, mock_client ): - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + auth_getters = {"test_source": lambda: "id_token_from_getters"} + auth_headers_lambda = lambda: "Bearer token_from_headers" with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_toolset( - auth_headers={"Authorization": lambda: "Bearer token"} + auth_headers={"Authorization": auth_headers_lambda}, + auth_token_getters=auth_getters, ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) assert "auth_headers" in str(w[-1].message) + assert "`auth_token_getters` will be used" in str(w[-1].message) - @patch("toolbox_llamaindex.async_client._load_manifest") - async def test_aload_toolset_auth_headers_and_tokens( - self, mock_load_manifest, mock_client, manifest_schema - ): - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, + auth_token_getters=auth_getters, + bound_params={}, + strict=False, # auth_getters takes precedence + ) + + async def test_aload_toolset_auth_tokens_deprecated(self, mock_client): + token_lambda = lambda: "id_token" with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_toolset( - auth_headers={"Authorization": lambda: "Bearer token"}, - auth_tokens={"test": lambda: "token"}, + auth_tokens={"some_token_key": token_lambda} ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) - assert "auth_headers" in str(w[-1].message) + assert "auth_tokens" in str(w[-1].message) + assert "Use `auth_token_getters` instead" in str(w[-1].message) + + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, + auth_token_getters={"some_token_key": token_lambda}, + bound_params={}, + strict=False, + ) + + async def test_aload_toolset_auth_tokens_and_getters_precedence(self, mock_client): + auth_getters = {"real_source": lambda: "token_from_getters"} + token_lambda = lambda: "token_from_auth_tokens" + with catch_warnings(record=True) as w: + simplefilter("always") + await mock_client.aload_toolset( + auth_tokens={"deprecated_source": token_lambda}, + auth_token_getters=auth_getters, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "auth_tokens" in str(w[-1].message) + assert "`auth_token_getters` will be used" in str(w[-1].message) + + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, auth_token_getters=auth_getters, bound_params={}, strict=False + ) async def test_load_tool_not_implemented(self, mock_client): with pytest.raises(NotImplementedError) as excinfo: diff --git a/packages/toolbox-llamaindex/tests/test_async_tools.py b/packages/toolbox-llamaindex/tests/test_async_tools.py index 16b891e5..251c88cd 100644 --- a/packages/toolbox-llamaindex/tests/test_async_tools.py +++ b/packages/toolbox-llamaindex/tests/test_async_tools.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import types # For MappingProxyType from unittest.mock import AsyncMock, Mock, patch import pytest import pytest_asyncio +from llama_index.core.tools.types import ToolOutput from pydantic import ValidationError +from toolbox_core.protocol import ParameterSchema as CoreParameterSchema +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool from toolbox_llamaindex.async_tools import AsyncToolboxTool @@ -24,7 +28,7 @@ @pytest.mark.asyncio class TestAsyncToolboxTool: @pytest.fixture - def tool_schema(self): + def tool_schema_dict(self): return { "description": "Test Tool Description", "parameters": [ @@ -34,9 +38,10 @@ def tool_schema(self): } @pytest.fixture - def auth_tool_schema(self): + def auth_tool_schema_dict(self): return { "description": "Test Tool Description", + "authRequired": ["test-auth-source"], "parameters": [ { "name": "param1", @@ -48,209 +53,327 @@ def auth_tool_schema(self): ], } + def _create_core_tool_from_dict( + self, session, name, schema_dict, url, initial_auth_getters=None + ): + core_params_schemas = [ + CoreParameterSchema(**p) for p in schema_dict["parameters"] + ] + + tool_constructor_params = [] + required_authn_for_core = {} + for p_schema in core_params_schemas: + if p_schema.authSources: + required_authn_for_core[p_schema.name] = p_schema.authSources + else: + tool_constructor_params.append(p_schema) + + return ToolboxCoreTool( + session=session, + base_url=url, + name=name, + description=schema_dict["description"], + params=tool_constructor_params, + required_authn_params=types.MappingProxyType(required_authn_for_core), + required_authz_tokens=schema_dict.get("authRequired", []), + auth_service_token_getters=types.MappingProxyType( + initial_auth_getters or {} + ), + bound_params=types.MappingProxyType({}), + client_headers=types.MappingProxyType({}), + ) + @pytest_asyncio.fixture @patch("aiohttp.ClientSession") - async def toolbox_tool(self, mock_client_session, tool_schema): - mock_session = mock_client_session.return_value - mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() - mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( - return_value={"result": "test-result"} - ) - tool = AsyncToolboxTool( + async def toolbox_tool(self, MockClientSession, tool_schema_dict): + mock_session = MockClientSession.return_value + mock_response = mock_session.post.return_value.__aenter__.return_value + mock_response.raise_for_status = Mock() + mock_response.json = AsyncMock(return_value={"result": "test-result"}) + mock_response.status = 200 # *** Fix: Set status for the mock response *** + + core_tool_instance = self._create_core_tool_from_dict( + session=mock_session, name="test_tool", - schema=tool_schema, + schema_dict=tool_schema_dict, url="http://test_url", - session=mock_session, ) + tool = AsyncToolboxTool(core_tool=core_tool_instance) return tool @pytest_asyncio.fixture @patch("aiohttp.ClientSession") - async def auth_toolbox_tool(self, mock_client_session, auth_tool_schema): - mock_session = mock_client_session.return_value - mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() - mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( - return_value={"result": "test-result"} + async def auth_toolbox_tool(self, MockClientSession, auth_tool_schema_dict): + mock_session = MockClientSession.return_value + mock_response = mock_session.post.return_value.__aenter__.return_value + mock_response.raise_for_status = Mock() + mock_response.json = AsyncMock(return_value={"result": "test-result"}) + mock_response.status = 200 # *** Fix: Set status for the mock response *** + + core_tool_instance = self._create_core_tool_from_dict( + session=mock_session, + name="test_tool", + schema_dict=auth_tool_schema_dict, + url="https://test-url", ) - with pytest.warns( - UserWarning, - match=r"Parameter\(s\) `param1` of tool test_tool require authentication", - ): - tool = AsyncToolboxTool( - name="test_tool", - schema=auth_tool_schema, - url="https://test-url", - session=mock_session, - ) + tool = AsyncToolboxTool(core_tool=core_tool_instance) return tool @patch("aiohttp.ClientSession") - async def test_toolbox_tool_init(self, mock_client_session, tool_schema): - mock_session = mock_client_session.return_value - tool = AsyncToolboxTool( + async def test_toolbox_tool_init(self, MockClientSession, tool_schema_dict): + mock_session = MockClientSession.return_value + mock_response = mock_session.post.return_value.__aenter__.return_value + mock_response.status = 200 + core_tool_instance = self._create_core_tool_from_dict( + session=mock_session, name="test_tool", - schema=tool_schema, + schema_dict=tool_schema_dict, url="https://test-url", - session=mock_session, ) + tool = AsyncToolboxTool(core_tool=core_tool_instance) assert tool.metadata.name == "test_tool" - assert tool.metadata.description == "Test Tool Description" + assert tool.metadata.description == core_tool_instance.__doc__ @pytest.mark.parametrize( - "params, expected_bound_params", + "params_to_bind", [ - ({"param1": "bound-value"}, {"param1": "bound-value"}), - ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), - ( - {"param1": "bound-value", "param2": 123}, - {"param1": "bound-value", "param2": 123}, - ), + ({"param1": "bound-value"}), + ({"param1": lambda: "bound-value"}), + ({"param1": "bound-value", "param2": 123}), ], ) - async def test_toolbox_tool_bind_params( - self, toolbox_tool, params, expected_bound_params - ): - tool = toolbox_tool.bind_params(params) - for key, value in expected_bound_params.items(): - if callable(value): - assert value() == tool._AsyncToolboxTool__bound_params[key]() - else: - assert value == tool._AsyncToolboxTool__bound_params[key] - - @pytest.mark.parametrize("strict", [True, False]) - async def test_toolbox_tool_bind_params_invalid(self, toolbox_tool, strict): - if strict: - with pytest.raises(ValueError) as e: - tool = toolbox_tool.bind_params( - {"param3": "bound-value"}, strict=strict - ) - assert "Parameter(s) param3 missing and cannot be bound." in str(e.value) - else: - with pytest.warns(UserWarning) as record: - tool = toolbox_tool.bind_params( - {"param3": "bound-value"}, strict=strict - ) - assert len(record) == 1 - assert "Parameter(s) param3 missing and cannot be bound." in str( - record[0].message + async def test_toolbox_tool_bind_params(self, toolbox_tool, params_to_bind): + original_core_tool = toolbox_tool._AsyncToolboxTool__core_tool + with patch.object( + original_core_tool, "bind_params", wraps=original_core_tool.bind_params + ) as mock_core_bind_params: + new_llamaindex_tool = toolbox_tool.bind_params(params_to_bind) + mock_core_bind_params.assert_called_once_with(params_to_bind) + assert isinstance( + new_llamaindex_tool._AsyncToolboxTool__core_tool, ToolboxCoreTool + ) + new_core_tool_signature_params = ( + new_llamaindex_tool._AsyncToolboxTool__core_tool.__signature__.parameters ) + for bound_param_name in params_to_bind.keys(): + assert bound_param_name not in new_core_tool_signature_params + + async def test_toolbox_tool_bind_params_invalid(self, toolbox_tool): + with pytest.raises( + ValueError, match="unable to bind parameters: no parameter named param3" + ): + toolbox_tool.bind_params({"param3": "bound-value"}) async def test_toolbox_tool_bind_params_duplicate(self, toolbox_tool): tool = toolbox_tool.bind_params({"param1": "bound-value"}) - with pytest.raises(ValueError) as e: - tool = tool.bind_params({"param1": "bound-value"}) - assert "Parameter(s) `param1` already bound in tool `test_tool`." in str( - e.value - ) + with pytest.raises( + ValueError, + match="cannot re-bind parameter: parameter 'param1' is already bound", + ): + tool.bind_params({"param1": "bound-value"}) async def test_toolbox_tool_bind_params_invalid_params(self, auth_toolbox_tool): - with pytest.raises(ValueError) as e: + auth_core_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool + # Verify that 'param1' is not in the list of bindable parameters for the core tool + # because it requires authentication. + assert "param1" not in [p.name for p in auth_core_tool._ToolboxTool__params] + with pytest.raises( + ValueError, match="unable to bind parameters: no parameter named param1" + ): auth_toolbox_tool.bind_params({"param1": "bound-value"}) - assert "Parameter(s) param1 already authenticated and cannot be bound." in str( - e.value - ) - @pytest.mark.parametrize( - "auth_tokens, expected_auth_tokens", - [ - ( - {"test-auth-source": lambda: "test-token"}, - {"test-auth-source": lambda: "test-token"}, - ), - ( - { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, - { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, - ), - ], - ) - async def test_toolbox_tool_add_auth_tokens( - self, auth_toolbox_tool, auth_tokens, expected_auth_tokens - ): - tool = auth_toolbox_tool.add_auth_tokens(auth_tokens) - for source, getter in expected_auth_tokens.items(): - assert tool._AsyncToolboxTool__auth_tokens[source]() == getter() + async def test_toolbox_tool_add_valid_auth_token_getter(self, auth_toolbox_tool): + get_token_lambda = lambda: "test-token-value" + original_core_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool + with patch.object( + original_core_tool, + "add_auth_token_getters", + wraps=original_core_tool.add_auth_token_getters, + ) as mock_core_add_getters: + tool = auth_toolbox_tool.add_auth_token_getters( + {"test-auth-source": get_token_lambda} + ) + mock_core_add_getters.assert_called_once_with( + {"test-auth-source": get_token_lambda} + ) + core_tool_after_add = tool._AsyncToolboxTool__core_tool + assert ( + "test-auth-source" + in core_tool_after_add._ToolboxTool__auth_service_token_getters + ) + assert ( + core_tool_after_add._ToolboxTool__auth_service_token_getters[ + "test-auth-source" + ] + is get_token_lambda + ) + assert not core_tool_after_add._ToolboxTool__required_authn_params.get( + "param1" + ) + assert ( + "test-auth-source" + not in core_tool_after_add._ToolboxTool__required_authz_tokens + ) - async def test_toolbox_tool_add_auth_tokens_duplicate(self, auth_toolbox_tool): - tool = auth_toolbox_tool.add_auth_tokens( - {"test-auth-source": lambda: "test-token"} + async def test_toolbox_tool_add_unused_auth_token_getter_raises_error( + self, auth_toolbox_tool + ): + unused_lambda = lambda: "another-token" + with pytest.raises(ValueError) as excinfo: + auth_toolbox_tool.add_auth_token_getters( + {"another-auth-source": unused_lambda} + ) + assert ( + "Authentication source(s) `another-auth-source` unused by tool `test_tool`" + in str(excinfo.value) ) - with pytest.raises(ValueError) as e: - tool = tool.add_auth_tokens({"test-auth-source": lambda: "test-token"}) + + valid_lambda = lambda: "test-token" + with pytest.raises(ValueError) as excinfo_mixed: + auth_toolbox_tool.add_auth_token_getters( + { + "test-auth-source": valid_lambda, + "another-auth-source": unused_lambda, + } + ) assert ( - "Authentication source(s) `test-auth-source` already registered in tool `test_tool`." - in str(e.value) + "Authentication source(s) `another-auth-source` unused by tool `test_tool`" + in str(excinfo_mixed.value) ) - async def test_toolbox_tool_validate_auth_strict(self, auth_toolbox_tool): - with pytest.raises(PermissionError) as e: - auth_toolbox_tool._AsyncToolboxTool__validate_auth(strict=True) - assert "Parameter(s) `param1` of tool test_tool require authentication" in str( - e.value + async def test_toolbox_tool_add_auth_token_getters_duplicate( + self, auth_toolbox_tool + ): + tool = auth_toolbox_tool.add_auth_token_getters( + {"test-auth-source": lambda: "test-token"} ) + with pytest.raises( + ValueError, + match="Authentication source\\(s\\) `test-auth-source` already registered in tool `test_tool`\\.", + ): + tool.add_auth_token_getters({"test-auth-source": lambda: "test-token"}) + + async def test_toolbox_tool_call_requires_auth_strict(self, auth_toolbox_tool): + with pytest.raises( + PermissionError, + match="One or more of the following authn services are required to invoke this tool: test-auth-source", + ): + await auth_toolbox_tool.acall(param2=123) async def test_toolbox_tool_call(self, toolbox_tool): result = await toolbox_tool.acall(param1="test-value", param2=123) - assert result.content == str({"result": "test-result"}) - toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + assert result == ToolOutput( + content="test-result", + tool_name="test_tool", + raw_input={"param1": "test-value", "param2": 123}, + raw_output="test-result", + is_error=False, + ) + + core_tool = toolbox_tool._AsyncToolboxTool__core_tool + core_tool._ToolboxTool__session.post.assert_called_once_with( "http://test_url/api/tool/test_tool/invoke", json={"param1": "test-value", "param2": 123}, headers={}, ) @pytest.mark.parametrize( - "bound_param, expected_value", + "bound_param_map, expected_value", [ ({"param1": "bound-value"}, "bound-value"), ({"param1": lambda: "dynamic-value"}, "dynamic-value"), ], ) async def test_toolbox_tool_call_with_bound_params( - self, toolbox_tool, bound_param, expected_value + self, toolbox_tool, bound_param_map, expected_value ): - tool = toolbox_tool.bind_params(bound_param) + tool = toolbox_tool.bind_params(bound_param_map) result = await tool.acall(param2=123) - assert result.content == str({"result": "test-result"}) - toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + assert result == ToolOutput( + content="test-result", + tool_name="test_tool", + raw_input={"param2": 123}, + raw_output="test-result", + is_error=False, + ) + core_tool = tool._AsyncToolboxTool__core_tool + core_tool._ToolboxTool__session.post.assert_called_once_with( "http://test_url/api/tool/test_tool/invoke", json={"param1": expected_value, "param2": 123}, headers={}, ) async def test_toolbox_tool_call_with_auth_tokens(self, auth_toolbox_tool): - tool = auth_toolbox_tool.add_auth_tokens( + tool = auth_toolbox_tool.add_auth_token_getters( {"test-auth-source": lambda: "test-token"} ) result = await tool.acall(param2=123) - assert result.content == str({"result": "test-result"}) - auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + assert result == ToolOutput( + content="test-result", + tool_name="test_tool", + raw_input={"param2": 123}, + raw_output="test-result", + is_error=False, + ) + + core_tool = tool._AsyncToolboxTool__core_tool + core_tool._ToolboxTool__session.post.assert_called_once_with( "https://test-url/api/tool/test_tool/invoke", json={"param2": 123}, headers={"test-auth-source_token": "test-token"}, ) - async def test_toolbox_tool_call_with_auth_tokens_insecure(self, auth_toolbox_tool): + async def test_toolbox_tool_call_with_auth_tokens_insecure( + self, auth_toolbox_tool, auth_tool_schema_dict + ): + core_tool_of_auth_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool + mock_session = core_tool_of_auth_tool._ToolboxTool__session + with pytest.warns( UserWarning, match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", ): - auth_toolbox_tool._AsyncToolboxTool__url = "http://test-url" - tool = auth_toolbox_tool.add_auth_tokens( - {"test-auth-source": lambda: "test-token"} - ) - result = await tool.acall(param2=123) - assert result.content == str({"result": "test-result"}) - auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( - "http://test-url/api/tool/test_tool/invoke", - json={"param2": 123}, - headers={"test-auth-source_token": "test-token"}, + insecure_core_tool = self._create_core_tool_from_dict( + session=mock_session, + name="test_tool", + schema_dict=auth_tool_schema_dict, + url="http://test-url", ) + insecure_auth_langchain_tool = AsyncToolboxTool(core_tool=insecure_core_tool) + + tool_with_getter = insecure_auth_langchain_tool.add_auth_token_getters( + {"test-auth-source": lambda: "test-token"} + ) + result = await tool_with_getter.acall(param2=123) + assert result == ToolOutput( + content="test-result", + tool_name="test_tool", + raw_input={"param2": 123}, + raw_output="test-result", + is_error=False, + ) + + modified_core_tool_in_new_tool = tool_with_getter._AsyncToolboxTool__core_tool + assert ( + modified_core_tool_in_new_tool._ToolboxTool__base_url == "http://test-url" + ) + assert ( + modified_core_tool_in_new_tool._ToolboxTool__url + == "http://test-url/api/tool/test_tool/invoke" + ) + + modified_core_tool_in_new_tool._ToolboxTool__session.post.assert_called_once_with( + "http://test-url/api/tool/test_tool/invoke", + json={"param2": 123}, + headers={"test-auth-source_token": "test-token"}, + ) + + async def test_toolbox_tool_call_with_empty_input(self, toolbox_tool): + with pytest.raises(TypeError) as e: + await toolbox_tool.acall() + assert "missing a required argument: 'param1'" in str(e.value) + async def test_toolbox_tool_call_with_invalid_input(self, toolbox_tool): with pytest.raises(ValidationError) as e: await toolbox_tool.acall(param1=123, param2="invalid") @@ -258,13 +381,6 @@ async def test_toolbox_tool_call_with_invalid_input(self, toolbox_tool): assert "param1\n Input should be a valid string" in str(e.value) assert "param2\n Input should be a valid integer" in str(e.value) - async def test_toolbox_tool_call_with_empty_input(self, toolbox_tool): - with pytest.raises(ValidationError) as e: - await toolbox_tool.acall() - assert "2 validation errors for test_tool" in str(e.value) - assert "param1\n Field required" in str(e.value) - assert "param2\n Field required" in str(e.value) - async def test_toolbox_tool_run_not_implemented(self, toolbox_tool): with pytest.raises(NotImplementedError): toolbox_tool.call() diff --git a/packages/toolbox-llamaindex/tests/test_client.py b/packages/toolbox-llamaindex/tests/test_client.py index 842dae22..aec1509a 100644 --- a/packages/toolbox-llamaindex/tests/test_client.py +++ b/packages/toolbox-llamaindex/tests/test_client.py @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,290 +16,409 @@ import pytest from pydantic import BaseModel +from toolbox_core.protocol import ParameterSchema as CoreParameterSchema +from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool +from toolbox_core.utils import params_to_pydantic_model -from toolbox_llamaindex.async_tools import AsyncToolboxTool from toolbox_llamaindex.client import ToolboxClient from toolbox_llamaindex.tools import ToolboxTool -from toolbox_llamaindex.utils import _schema_to_model URL = "http://test_url" -class TestToolboxClient: - @pytest.fixture - def tool_schema(self): - return { - "description": "Test Tool Description", - "parameters": [ - {"name": "param1", "type": "string", "description": "Param 1"}, - {"name": "param2", "type": "integer", "description": "Param 2"}, - ], - } +def create_mock_core_sync_tool( + name="mock-sync-tool", + doc="Mock sync description.", + model_name="MockSyncModel", + params=None, +): + mock_tool = Mock(spec=ToolboxCoreSyncTool) + mock_tool.__name__ = name + mock_tool.__doc__ = doc + mock_tool._name = model_name + if params is None: + mock_tool._params = [ + CoreParameterSchema(name="param1", type="string", description="Param 1") + ] + else: + mock_tool._params = params + return mock_tool + + +def assert_pydantic_models_equivalent( + model_cls1: type[BaseModel], model_cls2: type[BaseModel], expected_model_name: str +): + assert issubclass(model_cls1, BaseModel), "model_cls1 is not a Pydantic BaseModel" + assert issubclass(model_cls2, BaseModel), "model_cls2 is not a Pydantic BaseModel" + + assert ( + model_cls1.__name__ == expected_model_name + ), f"model_cls1 name mismatch: expected {expected_model_name}, got {model_cls1.__name__}" + assert ( + model_cls2.__name__ == expected_model_name + ), f"model_cls2 name mismatch: expected {expected_model_name}, got {model_cls2.__name__}" + + fields1 = model_cls1.model_fields + fields2 = model_cls2.model_fields + + assert ( + fields1.keys() == fields2.keys() + ), f"Field names mismatch: {fields1.keys()} != {fields2.keys()}" + + for field_name in fields1.keys(): + field_info1 = fields1[field_name] + field_info2 = fields2[field_name] + + assert ( + field_info1.annotation == field_info2.annotation + ), f"Field '{field_name}': Annotation mismatch ({field_info1.annotation} != {field_info2.annotation})" + assert ( + field_info1.description == field_info2.description + ), f"Field '{field_name}': Description mismatch ('{field_info1.description}' != '{field_info2.description}')" + is_required1 = ( + field_info1.is_required() + if hasattr(field_info1, "is_required") + else not field_info1.is_nullable() + ) + is_required2 = ( + field_info2.is_required() + if hasattr(field_info2, "is_required") + else not field_info2.is_nullable() + ) + assert ( + is_required1 == is_required2 + ), f"Field '{field_name}': Required status mismatch ({is_required1} != {is_required2})" + +class TestToolboxClient: @pytest.fixture() def toolbox_client(self): client = ToolboxClient(URL) assert isinstance(client, ToolboxClient) - assert client._ToolboxClient__async_client is not None + assert client._ToolboxClient__core_client is not None + return client - # Check that the background loop was created and started - assert client._ToolboxClient__loop is not None - assert client._ToolboxClient__loop.is_running() + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") + def test_load_tool(self, mock_core_load_tool, toolbox_client): + mock_core_tool_instance = create_mock_core_sync_tool( + name="test_tool_sync", + doc="Sync tool description.", + model_name="TestToolSyncModel", + params=[ + CoreParameterSchema( + name="sp1", type="integer", description="Sync Param 1" + ) + ], + ) + mock_core_load_tool.return_value = mock_core_tool_instance - return client + llamaindex_tool = toolbox_client.load_tool("test_tool") - @patch("toolbox_llamaindex.client.ToolboxTool.__init__", return_value=None) - @patch("toolbox_llamaindex.client.AsyncToolboxClient.aload_tool") - def test_load_tool( - self, mock_aload_tool, mock_toolbox_tool_init, toolbox_client, tool_schema - ): - mock_async_tool = Mock(spec=AsyncToolboxTool) - mock_async_tool._AsyncToolboxTool__name = "mock-tool" # Access the mangled name - mock_async_tool._AsyncToolboxTool__schema = ( - tool_schema # Access the mangled name + assert isinstance(llamaindex_tool, ToolboxTool) + assert llamaindex_tool.metadata.name == mock_core_tool_instance.__name__ + assert llamaindex_tool.metadata.description == mock_core_tool_instance.__doc__ + + # Generate the expected schema once for comparison + expected_args_schema = params_to_pydantic_model( + mock_core_tool_instance._name, mock_core_tool_instance._params ) - mock_aload_tool.return_value = mock_async_tool - tool = toolbox_client.load_tool("test_tool") - mock_toolbox_tool_init.assert_called_once_with( - mock_async_tool, - toolbox_client._ToolboxClient__loop, - toolbox_client._ToolboxClient__thread, + assert_pydantic_models_equivalent( + llamaindex_tool.metadata.fn_schema, + expected_args_schema, + mock_core_tool_instance._name, ) - assert ( - tool_schema["description"] - == mock_async_tool._AsyncToolboxTool__schema["description"] + mock_core_load_tool.assert_called_once_with( + name="test_tool", auth_token_getters={}, bound_params={} ) - mock_aload_tool.assert_called_once_with("test_tool", {}, None, {}, True) - @patch("toolbox_llamaindex.client.ToolboxTool.__init__", return_value=None) - @patch("toolbox_llamaindex.client.AsyncToolboxClient.aload_toolset") - def test_load_toolset( - self, mock_aload_toolset, mock_toolbox_tool_init, toolbox_client, tool_schema - ): - mock_async_tool1 = Mock(spec=AsyncToolboxTool) - mock_async_tool1._AsyncToolboxTool__name = "mock-tool-0" - mock_async_tool1._AsyncToolboxTool__schema = tool_schema - - mock_async_tool2 = Mock(spec=AsyncToolboxTool) - mock_async_tool2._AsyncToolboxTool__name = "mock-tool-1" - mock_async_tool2._AsyncToolboxTool__schema = tool_schema - mock_aload_toolset.return_value = [mock_async_tool1, mock_async_tool2] - - tools = toolbox_client.load_toolset() - assert len(tools) == 2 - mock_toolbox_tool_init.assert_any_call( - mock_async_tool1, - toolbox_client._ToolboxClient__loop, - toolbox_client._ToolboxClient__thread, + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") + def test_load_toolset(self, mock_core_load_toolset, toolbox_client): + mock_core_tool_instance1 = create_mock_core_sync_tool( + name="tool-0", doc="desc 0", model_name="T0Model" ) - mock_toolbox_tool_init.assert_any_call( - mock_async_tool2, - toolbox_client._ToolboxClient__loop, - toolbox_client._ToolboxClient__thread, + mock_core_tool_instance2 = create_mock_core_sync_tool( + name="tool-1", doc="desc 1", model_name="T1Model", params=[] ) - mock_aload_toolset.assert_called_once_with(None, {}, None, {}, True) + mock_core_load_toolset.return_value = [ + mock_core_tool_instance1, + mock_core_tool_instance2, + ] + + llamaindex_tools = toolbox_client.load_toolset() + assert len(llamaindex_tools) == 2 + + tool_instances_mocks = [mock_core_tool_instance1, mock_core_tool_instance2] + for i, tool_instance_mock in enumerate(tool_instances_mocks): + llamaindex_tool = llamaindex_tools[i] + assert isinstance(llamaindex_tool, ToolboxTool) + assert llamaindex_tool.metadata.name == tool_instance_mock.__name__ + assert llamaindex_tool.metadata.description == tool_instance_mock.__doc__ + + expected_args_schema = params_to_pydantic_model( + tool_instance_mock._name, tool_instance_mock._params + ) + assert_pydantic_models_equivalent( + llamaindex_tool.metadata.fn_schema, + expected_args_schema, + tool_instance_mock._name, + ) + + mock_core_load_toolset.assert_called_once_with( + name=None, auth_token_getters={}, bound_params={}, strict=False + ) @pytest.mark.asyncio - @patch("toolbox_llamaindex.client.ToolboxTool.__init__", return_value=None) - @patch("toolbox_llamaindex.client.AsyncToolboxClient.aload_tool") - async def test_aload_tool( - self, mock_aload_tool, mock_toolbox_tool_init, toolbox_client, tool_schema - ): - mock_async_tool = Mock(spec=AsyncToolboxTool) - mock_async_tool._AsyncToolboxTool__name = "mock-tool" # Access mangled name - mock_async_tool._AsyncToolboxTool__schema = tool_schema - mock_aload_tool.return_value = mock_async_tool - - tool = await toolbox_client.aload_tool("test_tool") - mock_toolbox_tool_init.assert_called_once_with( - mock_async_tool, - toolbox_client._ToolboxClient__loop, - toolbox_client._ToolboxClient__thread, + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") + async def test_aload_tool(self, mock_sync_core_load_tool, toolbox_client): + mock_core_sync_tool_instance = create_mock_core_sync_tool( + name="test_async_loaded_tool", + doc="Async loaded sync tool description.", + model_name="AsyncTestToolModel", ) + mock_sync_core_load_tool.return_value = mock_core_sync_tool_instance + + llamaindex_tool = await toolbox_client.aload_tool("test_tool") + assert isinstance(llamaindex_tool, ToolboxTool) + assert llamaindex_tool.metadata.name == mock_core_sync_tool_instance.__name__ assert ( - tool_schema["description"] - == mock_async_tool._AsyncToolboxTool__schema["description"] + llamaindex_tool.metadata.description == mock_core_sync_tool_instance.__doc__ ) - mock_aload_tool.assert_called_once_with("test_tool", {}, None, {}, True) - @pytest.mark.asyncio - @patch("toolbox_llamaindex.client.ToolboxTool.__init__", return_value=None) - @patch("toolbox_llamaindex.client.AsyncToolboxClient.aload_toolset") - async def test_aload_toolset( - self, mock_aload_toolset, mock_toolbox_tool_init, toolbox_client, tool_schema - ): - mock_async_tool1 = Mock(spec=AsyncToolboxTool) - mock_async_tool1._AsyncToolboxTool__name = "mock-tool-0" - mock_async_tool1._AsyncToolboxTool__schema = tool_schema - - mock_async_tool2 = Mock(spec=AsyncToolboxTool) - mock_async_tool2._AsyncToolboxTool__name = "mock-tool-1" - mock_async_tool2._AsyncToolboxTool__schema = tool_schema - - mock_aload_toolset.return_value = [mock_async_tool1, mock_async_tool2] - - tools = await toolbox_client.aload_toolset() - assert len(tools) == 2 - mock_toolbox_tool_init.assert_any_call( - mock_async_tool1, - toolbox_client._ToolboxClient__loop, - toolbox_client._ToolboxClient__thread, + expected_args_schema = params_to_pydantic_model( + mock_core_sync_tool_instance._name, mock_core_sync_tool_instance._params ) - mock_toolbox_tool_init.assert_any_call( - mock_async_tool2, - toolbox_client._ToolboxClient__loop, - toolbox_client._ToolboxClient__thread, + assert_pydantic_models_equivalent( + llamaindex_tool.metadata.fn_schema, + expected_args_schema, + mock_core_sync_tool_instance._name, ) - mock_aload_toolset.assert_called_once_with(None, {}, None, {}, True) - @patch("toolbox_llamaindex.client.ToolboxTool.__init__", return_value=None) - @patch("toolbox_llamaindex.client.AsyncToolboxClient.aload_tool") - def test_load_tool_with_args( - self, mock_aload_tool, mock_toolbox_tool_init, toolbox_client, tool_schema - ): - mock_async_tool = Mock(spec=AsyncToolboxTool) - mock_async_tool._AsyncToolboxTool__name = "mock-tool" - mock_async_tool._AsyncToolboxTool__schema = tool_schema - mock_aload_tool.return_value = mock_async_tool - - auth_tokens = {"token1": lambda: "value1"} - auth_headers = {"header1": lambda: "value2"} - bound_params = {"param1": "value3"} - - tool = toolbox_client.load_tool( - "test_tool_name", - auth_tokens=auth_tokens, - auth_headers=auth_headers, - bound_params=bound_params, - strict=False, - ) - mock_toolbox_tool_init.assert_called_once_with( - mock_async_tool, - toolbox_client._ToolboxClient__loop, - toolbox_client._ToolboxClient__thread, + mock_sync_core_load_tool.assert_called_once_with( + name="test_tool", auth_token_getters={}, bound_params={} ) - assert ( - tool_schema["description"] - == mock_async_tool._AsyncToolboxTool__schema["description"] + @pytest.mark.asyncio + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") + async def test_aload_toolset(self, mock_sync_core_load_toolset, toolbox_client): + mock_core_sync_tool1 = create_mock_core_sync_tool( + name="async-tool-0", doc="async desc 0", model_name="AT0Model" ) - mock_aload_tool.assert_called_once_with( - "test_tool_name", auth_tokens, auth_headers, bound_params, False + mock_core_sync_tool2 = create_mock_core_sync_tool( + name="async-tool-1", + doc="async desc 1", + model_name="AT1Model", + params=[CoreParameterSchema(name="p1", type="string", description="P1")], ) - @patch("toolbox_llamaindex.client.ToolboxTool.__init__", return_value=None) - @patch("toolbox_llamaindex.client.AsyncToolboxClient.aload_toolset") - def test_load_toolset_with_args( - self, mock_aload_toolset, mock_toolbox_tool_init, toolbox_client, tool_schema - ): - mock_async_tool1 = Mock(spec=AsyncToolboxTool) - mock_async_tool1._AsyncToolboxTool__name = "mock-tool-0" - mock_async_tool1._AsyncToolboxTool__schema = tool_schema - - mock_async_tool2 = Mock(spec=AsyncToolboxTool) - mock_async_tool2._AsyncToolboxTool__name = "mock-tool-1" - mock_async_tool2._AsyncToolboxTool__schema = tool_schema - - mock_aload_toolset.return_value = [mock_async_tool1, mock_async_tool2] - - auth_tokens = {"token1": lambda: "value1"} - auth_headers = {"header1": lambda: "value2"} - bound_params = {"param1": "value3"} - - tools = toolbox_client.load_toolset( - toolset_name="my_toolset", - auth_tokens=auth_tokens, - auth_headers=auth_headers, - bound_params=bound_params, - strict=False, + mock_sync_core_load_toolset.return_value = [ + mock_core_sync_tool1, + mock_core_sync_tool2, + ] + + llamaindex_tools = await toolbox_client.aload_toolset() + assert len(llamaindex_tools) == 2 + + tool_instances_mocks = [mock_core_sync_tool1, mock_core_sync_tool2] + for i, tool_instance_mock in enumerate(tool_instances_mocks): + llamaindex_tool = llamaindex_tools[i] + assert isinstance(llamaindex_tool, ToolboxTool) + assert llamaindex_tool.metadata.name == tool_instance_mock.__name__ + + expected_args_schema = params_to_pydantic_model( + tool_instance_mock._name, tool_instance_mock._params + ) + assert_pydantic_models_equivalent( + llamaindex_tool.metadata.fn_schema, + expected_args_schema, + tool_instance_mock._name, + ) + + mock_sync_core_load_toolset.assert_called_once_with( + name=None, auth_token_getters={}, bound_params={}, strict=False ) - assert len(tools) == 2 - mock_toolbox_tool_init.assert_any_call( - mock_async_tool1, - toolbox_client._ToolboxClient__loop, - toolbox_client._ToolboxClient__thread, + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") + def test_load_tool_with_args(self, mock_core_load_tool, toolbox_client): + mock_core_tool_instance = create_mock_core_sync_tool() + mock_core_load_tool.return_value = mock_core_tool_instance + + auth_token_getters = {"token_getter1": lambda: "value1"} + auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} + auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} + bound_params = {"param1": "value4"} + # Scenario 1: auth_token_getters takes precedence + with pytest.warns(DeprecationWarning) as record: + tool = toolbox_client.load_tool( + "test_tool_name", + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + ) + assert len(record) == 2 + messages = sorted([str(r.message) for r in record]) + # Warning for auth_headers when auth_token_getters is also present + assert ( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used." + in messages ) - mock_toolbox_tool_init.assert_any_call( - mock_async_tool2, - toolbox_client._ToolboxClient__loop, - toolbox_client._ToolboxClient__thread, + # Warning for auth_tokens when auth_token_getters is also present + assert ( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used." + in messages ) - mock_aload_toolset.assert_called_once_with( - "my_toolset", auth_tokens, auth_headers, bound_params, False + assert isinstance(tool, ToolboxTool) + mock_core_load_tool.assert_called_with( + name="test_tool_name", + auth_token_getters=auth_token_getters, + bound_params=bound_params, ) + mock_core_load_tool.reset_mock() + + # Scenario 2: auth_tokens and auth_headers provided, auth_token_getters is default (empty initially) + with pytest.warns(DeprecationWarning) as record: + toolbox_client.load_tool( + "test_tool_name_2", + auth_tokens=auth_tokens_deprecated, # This will be used for auth_token_getters + auth_headers=auth_headers_deprecated, # This will warn as auth_token_getters is now populated + bound_params=bound_params, + ) + assert len(record) == 2 + messages = sorted([str(r.message) for r in record]) - @pytest.mark.asyncio - @patch("toolbox_llamaindex.client.ToolboxTool.__init__", return_value=None) - @patch("toolbox_llamaindex.client.AsyncToolboxClient.aload_tool") - async def test_aload_tool_with_args( - self, mock_aload_tool, mock_toolbox_tool_init, toolbox_client, tool_schema - ): - mock_async_tool = Mock(spec=AsyncToolboxTool) - mock_async_tool._AsyncToolboxTool__name = "mock-tool" - mock_async_tool._AsyncToolboxTool__schema = tool_schema - mock_aload_tool.return_value = mock_async_tool + assert ( + messages[0] + == "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead." + ) + assert ( + messages[1] + == "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used." + ) - auth_tokens = {"token1": lambda: "value1"} - auth_headers = {"header1": lambda: "value2"} - bound_params = {"param1": "value3"} + expected_getters_for_call = auth_tokens_deprecated - tool = await toolbox_client.aload_tool( - "test_tool", auth_tokens, auth_headers, bound_params, False + mock_core_load_tool.assert_called_with( + name="test_tool_name_2", + auth_token_getters=expected_getters_for_call, + bound_params=bound_params, ) - mock_toolbox_tool_init.assert_called_once_with( - mock_async_tool, - toolbox_client._ToolboxClient__loop, - toolbox_client._ToolboxClient__thread, + mock_core_load_tool.reset_mock() + + with pytest.warns( + DeprecationWarning, + match="Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + ) as record: + toolbox_client.load_tool( + "test_tool_name_3", + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + ) + assert len(record) == 1 + + mock_core_load_tool.assert_called_with( + name="test_tool_name_3", + auth_token_getters=auth_headers_deprecated, + bound_params=bound_params, ) - assert ( - tool_schema["description"] - == mock_async_tool._AsyncToolboxTool__schema["description"] + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") + def test_load_toolset_with_args(self, mock_core_load_toolset, toolbox_client): + mock_core_tool_instance = create_mock_core_sync_tool(model_name="MySetModel") + mock_core_load_toolset.return_value = [mock_core_tool_instance] + + auth_token_getters = {"token_getter1": lambda: "value1"} + auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} + auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} + bound_params = {"param1": "value4"} + toolset_name = "my_toolset" + + with pytest.warns(DeprecationWarning) as record: + tools = toolbox_client.load_toolset( + toolset_name=toolset_name, + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + strict=True, + ) + assert len(record) == 2 + + assert len(tools) == 1 + assert isinstance(tools[0], ToolboxTool) + mock_core_load_toolset.assert_called_with( + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=True, ) - mock_aload_tool.assert_called_once_with( - "test_tool", auth_tokens, auth_headers, bound_params, False + + @pytest.mark.asyncio + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") + async def test_aload_tool_with_args(self, mock_sync_core_load_tool, toolbox_client): + mock_core_tool_instance = create_mock_core_sync_tool( + model_name="MyAsyncToolModel" + ) + mock_sync_core_load_tool.return_value = mock_core_tool_instance + + auth_token_getters = {"token_getter1": lambda: "value1"} + auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} + auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} + bound_params = {"param1": "value4"} + + with pytest.warns(DeprecationWarning) as record: + tool = await toolbox_client.aload_tool( + "test_tool", + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + ) + assert len(record) == 2 + + assert isinstance(tool, ToolboxTool) + mock_sync_core_load_tool.assert_called_with( + name="test_tool", + auth_token_getters=auth_token_getters, + bound_params=bound_params, ) @pytest.mark.asyncio - @patch("toolbox_llamaindex.client.ToolboxTool.__init__", return_value=None) - @patch("toolbox_llamaindex.client.AsyncToolboxClient.aload_toolset") + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") async def test_aload_toolset_with_args( - self, mock_aload_toolset, mock_toolbox_tool_init, toolbox_client, tool_schema + self, mock_sync_core_load_toolset, toolbox_client ): - mock_async_tool1 = Mock(spec=AsyncToolboxTool) - mock_async_tool1._AsyncToolboxTool__name = "mock-tool-0" - mock_async_tool1._AsyncToolboxTool__schema = tool_schema - - mock_async_tool2 = Mock(spec=AsyncToolboxTool) - mock_async_tool2._AsyncToolboxTool__name = "mock-tool-1" - mock_async_tool2._AsyncToolboxTool__schema = tool_schema - mock_aload_toolset.return_value = [mock_async_tool1, mock_async_tool2] - - auth_tokens = {"token1": lambda: "value1"} - auth_headers = {"header1": lambda: "value2"} - bound_params = {"param1": "value3"} - - tools = await toolbox_client.aload_toolset( - "my_toolset", auth_tokens, auth_headers, bound_params, False - ) - assert len(tools) == 2 - mock_toolbox_tool_init.assert_any_call( - mock_async_tool1, - toolbox_client._ToolboxClient__loop, - toolbox_client._ToolboxClient__thread, + mock_core_tool_instance = create_mock_core_sync_tool( + model_name="MyAsyncSetModel" ) - mock_toolbox_tool_init.assert_any_call( - mock_async_tool2, - toolbox_client._ToolboxClient__loop, - toolbox_client._ToolboxClient__thread, - ) - - mock_aload_toolset.assert_called_once_with( - "my_toolset", auth_tokens, auth_headers, bound_params, False + mock_sync_core_load_toolset.return_value = [mock_core_tool_instance] + + auth_token_getters = {"token_getter1": lambda: "value1"} + auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} + auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} + bound_params = {"param1": "value4"} + toolset_name = "my_async_toolset" + + with pytest.warns(DeprecationWarning) as record: + tools = await toolbox_client.aload_toolset( + toolset_name, + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + strict=True, + ) + assert len(record) == 2 + + assert len(tools) == 1 + assert isinstance(tools[0], ToolboxTool) + mock_sync_core_load_toolset.assert_called_with( + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=True, ) diff --git a/packages/toolbox-llamaindex/tests/test_e2e.py b/packages/toolbox-llamaindex/tests/test_e2e.py index 55b2e522..5f389b86 100644 --- a/packages/toolbox-llamaindex/tests/test_e2e.py +++ b/packages/toolbox-llamaindex/tests/test_e2e.py @@ -36,7 +36,6 @@ import pytest import pytest_asyncio -from aiohttp import ClientResponseError from pydantic import ValidationError from toolbox_llamaindex.client import ToolboxClient @@ -54,7 +53,7 @@ def toolbox(self): @pytest_asyncio.fixture(scope="function") async def get_n_rows_tool(self, toolbox): tool = await toolbox.aload_tool("get-n-rows") - assert tool._ToolboxTool__async_tool._AsyncToolboxTool__name == "get-n-rows" + assert tool._ToolboxTool__core_tool.__name__ == "get-n-rows" return tool #### Basic e2e tests @@ -71,7 +70,7 @@ async def test_aload_toolset_specific( toolset = await toolbox.aload_toolset(toolset_name) assert len(toolset) == expected_length for tool in toolset: - name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + name = tool._ToolboxTool__core_tool.__name__ assert name in expected_tools async def test_aload_toolset_all(self, toolbox): @@ -85,27 +84,25 @@ async def test_aload_toolset_all(self, toolbox): "get-row-by-content-auth", ] for tool in toolset: - name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + name = tool._ToolboxTool__core_tool.__name__ assert name in tool_names async def test_run_tool_async(self, get_n_rows_tool): response = await get_n_rows_tool.acall(num_rows="2") - result = response.content - assert "row1" in result - assert "row2" in result - assert "row3" not in result + assert "row1" in response.content + assert "row2" in response.content + assert "row3" not in response.content async def test_run_tool_sync(self, get_n_rows_tool): response = get_n_rows_tool.call(num_rows="2") - result = response.content - assert "row1" in result - assert "row2" in result - assert "row3" not in result + assert "row1" in response.content + assert "row2" in response.content + assert "row3" not in response.content async def test_run_tool_missing_params(self, get_n_rows_tool): - with pytest.raises(ValidationError, match="Field required"): + with pytest.raises(TypeError, match="missing a required argument: 'num_rows'"): await get_n_rows_tool.acall() async def test_run_tool_wrong_param_type(self, get_n_rows_tool): @@ -116,71 +113,78 @@ async def test_run_tool_wrong_param_type(self, get_n_rows_tool): @pytest.mark.asyncio async def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): """Tests running a tool that doesn't require auth, with auth provided.""" - tool = await toolbox.aload_tool( - "get-row-by-id", auth_tokens={"my-test-auth": lambda: auth_token2} - ) - response = await tool.acall(id="2") - assert "row2" in response.content + with pytest.raises( + ValueError, + match="Validation failed for tool 'get-row-by-id': unused auth tokens: my-test-auth.", + ): + await toolbox.aload_tool( + "get-row-by-id", + auth_token_getters={"my-test-auth": lambda: auth_token2}, + ) async def test_run_tool_no_auth(self, toolbox): """Tests running a tool requiring auth without providing auth.""" tool = await toolbox.aload_tool( "get-row-by-id-auth", ) - response = await tool.acall(id="2") - assert response.is_error == True - assert "401, message='Unauthorized'" in response.content - assert isinstance(response.raw_output, str) + with pytest.raises( + PermissionError, + match="One or more of the following authn services are required to invoke this tool: my-test-auth", + ): + await tool.acall(id="2") async def test_run_tool_wrong_auth(self, toolbox, auth_token2): """Tests running a tool with incorrect auth.""" tool = await toolbox.aload_tool( "get-row-by-id-auth", ) - auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token2) - response = await auth_tool.acall(id="2") - assert response.is_error == True - assert "401, message='Unauthorized'" in response.content - assert isinstance(response.raw_output, str) + auth_tool = tool.add_auth_token_getter("my-test-auth", lambda: auth_token2) + with pytest.raises( + Exception, + match="tool invocation not authorized. Please make sure your specify correct auth headers", + ): + await auth_tool.acall(id="2") async def test_run_tool_auth(self, toolbox, auth_token1): """Tests running a tool with correct auth.""" tool = await toolbox.aload_tool( "get-row-by-id-auth", ) - auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token1) + auth_tool = tool.add_auth_token_getter("my-test-auth", lambda: auth_token1) response = await auth_tool.acall(id="2") assert "row2" in response.content async def test_run_tool_param_auth_no_auth(self, toolbox): - """Tests runningP a tool with a param requiring auth, without auth.""" + """Tests running a tool with a param requiring auth, without auth.""" tool = await toolbox.aload_tool("get-row-by-email-auth") with pytest.raises( PermissionError, - match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): await tool.acall(email="") async def test_run_tool_param_auth(self, toolbox, auth_token1): """Tests running a tool with a param requiring auth, with correct auth.""" tool = await toolbox.aload_tool( - "get-row-by-email-auth", auth_tokens={"my-test-auth": lambda: auth_token1} + "get-row-by-email-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, ) response = await tool.acall() - result = response.content - assert "row4" in result - assert "row5" in result - assert "row6" in result + assert "row4" in response.content + assert "row5" in response.content + assert "row6" in response.content async def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): """Tests running a tool with a param requiring auth, with insufficient auth.""" tool = await toolbox.aload_tool( - "get-row-by-content-auth", auth_tokens={"my-test-auth": lambda: auth_token1} + "get-row-by-content-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, ) - response = await tool.acall() - assert response.is_error == True - assert "400, message='Bad Request'" in response.content - assert isinstance(response.raw_output, str) + with pytest.raises( + Exception, + match='provided parameters were invalid: error parsing authenticated parameter "data": no field named row_data in claims', + ): + await tool.acall() @pytest.mark.usefixtures("toolbox_server") @@ -194,7 +198,7 @@ def toolbox(self): @pytest.fixture(scope="function") def get_n_rows_tool(self, toolbox): tool = toolbox.load_tool("get-n-rows") - assert tool._ToolboxTool__async_tool._AsyncToolboxTool__name == "get-n-rows" + assert tool._ToolboxTool__core_tool.__name__ == "get-n-rows" return tool #### Basic e2e tests @@ -211,7 +215,7 @@ def test_load_toolset_specific( toolset = toolbox.load_toolset(toolset_name) assert len(toolset) == expected_length for tool in toolset: - name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + name = tool._ToolboxTool__core_tool.__name__ assert name in expected_tools def test_aload_toolset_all(self, toolbox): @@ -225,28 +229,26 @@ def test_aload_toolset_all(self, toolbox): "get-row-by-content-auth", ] for tool in toolset: - name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + name = tool._ToolboxTool__core_tool.__name__ assert name in tool_names @pytest.mark.asyncio async def test_run_tool_async(self, get_n_rows_tool): response = await get_n_rows_tool.acall(num_rows="2") - result = response.content - assert "row1" in result - assert "row2" in result - assert "row3" not in result + assert "row1" in response.content + assert "row2" in response.content + assert "row3" not in response.content def test_run_tool_sync(self, get_n_rows_tool): response = get_n_rows_tool.call(num_rows="2") - result = response.content - assert "row1" in result - assert "row2" in result - assert "row3" not in result + assert "row1" in response.content + assert "row2" in response.content + assert "row3" not in response.content def test_run_tool_missing_params(self, get_n_rows_tool): - with pytest.raises(ValidationError, match="Field required"): + with pytest.raises(TypeError, match="missing a required argument: 'num_rows'"): get_n_rows_tool.call() def test_run_tool_wrong_param_type(self, get_n_rows_tool): @@ -256,39 +258,44 @@ def test_run_tool_wrong_param_type(self, get_n_rows_tool): #### Auth tests def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): """Tests running a tool that doesn't require auth, with auth provided.""" - tool = toolbox.load_tool( - "get-row-by-id", auth_tokens={"my-test-auth": lambda: auth_token2} - ) - response = tool.call(id="2") - assert "row2" in response.content + with pytest.raises( + ValueError, + match="Validation failed for tool 'get-row-by-id': unused auth tokens: my-test-auth.", + ): + toolbox.load_tool( + "get-row-by-id", + auth_token_getters={"my-test-auth": lambda: auth_token2}, + ) def test_run_tool_no_auth(self, toolbox): """Tests running a tool requiring auth without providing auth.""" tool = toolbox.load_tool( "get-row-by-id-auth", ) - response = tool.call(id="2") - assert response.is_error == True - assert "401, message='Unauthorized'" in response.content - assert isinstance(response.raw_output, str) + with pytest.raises( + PermissionError, + match="One or more of the following authn services are required to invoke this tool: my-test-auth", + ): + tool.call(id="2") def test_run_tool_wrong_auth(self, toolbox, auth_token2): """Tests running a tool with incorrect auth.""" tool = toolbox.load_tool( "get-row-by-id-auth", ) - auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token2) - response = auth_tool.call(id="2") - assert response.is_error == True - assert "401, message='Unauthorized'" in response.content - assert isinstance(response.raw_output, str) + auth_tool = tool.add_auth_token_getter("my-test-auth", lambda: auth_token2) + with pytest.raises( + Exception, + match="tool invocation not authorized. Please make sure your specify correct auth headers", + ): + auth_tool.call(id="2") def test_run_tool_auth(self, toolbox, auth_token1): """Tests running a tool with correct auth.""" tool = toolbox.load_tool( "get-row-by-id-auth", ) - auth_tool = tool.add_auth_token("my-test-auth", lambda: auth_token1) + auth_tool = tool.add_auth_token_getter("my-test-auth", lambda: auth_token1) response = auth_tool.call(id="2") assert "row2" in response.content @@ -297,27 +304,29 @@ def test_run_tool_param_auth_no_auth(self, toolbox): tool = toolbox.load_tool("get-row-by-email-auth") with pytest.raises( PermissionError, - match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): tool.call(email="") def test_run_tool_param_auth(self, toolbox, auth_token1): """Tests running a tool with a param requiring auth, with correct auth.""" tool = toolbox.load_tool( - "get-row-by-email-auth", auth_tokens={"my-test-auth": lambda: auth_token1} + "get-row-by-email-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, ) response = tool.call() - result = response.content - assert "row4" in result - assert "row5" in result - assert "row6" in result + assert "row4" in response.content + assert "row5" in response.content + assert "row6" in response.content def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): """Tests running a tool with a param requiring auth, with insufficient auth.""" tool = toolbox.load_tool( - "get-row-by-content-auth", auth_tokens={"my-test-auth": lambda: auth_token1} + "get-row-by-content-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, ) - response = tool.call() - assert response.is_error == True - assert "400, message='Bad Request'" in response.content - assert isinstance(response.raw_output, str) + with pytest.raises( + Exception, + match='provided parameters were invalid: error parsing authenticated parameter "data": no field named row_data in claims', + ): + tool.call() diff --git a/packages/toolbox-llamaindex/tests/test_tools.py b/packages/toolbox-llamaindex/tests/test_tools.py index faeefd20..3f8cbabe 100644 --- a/packages/toolbox-llamaindex/tests/test_tools.py +++ b/packages/toolbox-llamaindex/tests/test_tools.py @@ -11,19 +11,69 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import concurrent.futures -from unittest.mock import Mock, patch + +import asyncio +from unittest.mock import AsyncMock, Mock, call, patch import pytest +from llama_index.core.tools.types import ToolOutput from pydantic import BaseModel +from toolbox_core.protocol import ParameterSchema as CoreParameterSchema +from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool +from toolbox_core.tool import ToolboxTool as CoreAsyncTool +from toolbox_core.utils import params_to_pydantic_model from toolbox_llamaindex.async_tools import AsyncToolboxTool from toolbox_llamaindex.tools import ToolboxTool +def assert_pydantic_models_equivalent( + model_cls1: type[BaseModel], model_cls2: type[BaseModel], expected_model_name: str +): + assert issubclass(model_cls1, BaseModel), "model_cls1 is not a Pydantic BaseModel" + assert issubclass(model_cls2, BaseModel), "model_cls2 is not a Pydantic BaseModel" + assert ( + model_cls1.__name__ == expected_model_name + ), f"model_cls1 name mismatch: expected {expected_model_name}, got {model_cls1.__name__}" + assert ( + model_cls2.__name__ == expected_model_name + ), f"model_cls2 name mismatch: expected {expected_model_name}, got {model_cls2.__name__}" + + fields1 = model_cls1.model_fields + fields2 = model_cls2.model_fields + + assert ( + fields1.keys() == fields2.keys() + ), f"Field names mismatch: {fields1.keys()} != {fields2.keys()}" + + for field_name in fields1.keys(): + field_info1 = fields1[field_name] + field_info2 = fields2[field_name] + + assert ( + field_info1.annotation == field_info2.annotation + ), f"Field '{field_name}': Annotation mismatch ({field_info1.annotation} != {field_info2.annotation})" + assert ( + field_info1.description == field_info2.description + ), f"Field '{field_name}': Description mismatch ('{field_info1.description}' != '{field_info2.description}')" + is_required1 = ( + field_info1.is_required() + if hasattr(field_info1, "is_required") + else not field_info1.is_nullable() + ) + is_required2 = ( + field_info2.is_required() + if hasattr(field_info2, "is_required") + else not field_info2.is_nullable() + ) + assert ( + is_required1 == is_required2 + ), f"Field '{field_name}': Required status mismatch ({is_required1} != {is_required2})" + + class TestToolboxTool: @pytest.fixture - def tool_schema(self): + def tool_schema_dict(self): return { "description": "Test Tool Description", "parameters": [ @@ -33,9 +83,10 @@ def tool_schema(self): } @pytest.fixture - def auth_tool_schema(self): + def auth_tool_schema_dict(self): return { - "description": "Test Tool Description", + "description": "Test Auth Tool Description", + "authRequired": ["test-auth-source"], "parameters": [ { "name": "param1", @@ -47,197 +98,222 @@ def auth_tool_schema(self): ], } - @pytest.fixture(scope="function") - def mock_async_tool(self, tool_schema): - mock_async_tool = Mock(spec=AsyncToolboxTool) - mock_async_tool.name = "test_tool" - mock_async_tool.description = "test description" - mock_async_tool.args_schema = BaseModel - mock_async_tool._AsyncToolboxTool__name = "test_tool" - mock_async_tool._AsyncToolboxTool__schema = tool_schema - mock_async_tool._AsyncToolboxTool__url = "http://test_url" - mock_async_tool._AsyncToolboxTool__session = Mock() - mock_async_tool._AsyncToolboxTool__auth_tokens = {} - mock_async_tool._AsyncToolboxTool__bound_params = {} - return mock_async_tool - - @pytest.fixture(scope="function") - def mock_async_auth_tool(self, auth_tool_schema): - mock_async_tool = Mock(spec=AsyncToolboxTool) - mock_async_tool.name = "test_tool" - mock_async_tool.description = "test description" - mock_async_tool.args_schema = BaseModel - mock_async_tool._AsyncToolboxTool__name = "test_tool" - mock_async_tool._AsyncToolboxTool__schema = auth_tool_schema - mock_async_tool._AsyncToolboxTool__url = "http://test_url" - mock_async_tool._AsyncToolboxTool__session = Mock() - mock_async_tool._AsyncToolboxTool__auth_tokens = {} - mock_async_tool._AsyncToolboxTool__bound_params = {} - return mock_async_tool - @pytest.fixture - def toolbox_tool(self, mock_async_tool): - return ToolboxTool( - async_tool=mock_async_tool, - loop=Mock(), - thread=Mock(), + def mock_core_tool(self, tool_schema_dict): + sync_mock = Mock(spec=ToolboxCoreSyncTool) + + sync_mock.__name__ = "test_tool_name_for_llamaindex" + sync_mock.__doc__ = tool_schema_dict["description"] + sync_mock._name = "TestToolPydanticModel" + sync_mock._params = [ + CoreParameterSchema(**p) for p in tool_schema_dict["parameters"] + ] + + mock_async_tool_attr = AsyncMock(spec=CoreAsyncTool) + mock_async_tool_attr.return_value = "dummy_internal_async_tool_result" + sync_mock._ToolboxSyncTool__async_tool = mock_async_tool_attr + sync_mock._ToolboxSyncTool__loop = Mock(spec=asyncio.AbstractEventLoop) + sync_mock._ToolboxSyncTool__thread = Mock() + + new_mock_instance_for_methods = Mock(spec=ToolboxCoreSyncTool) + new_mock_instance_for_methods.__name__ = sync_mock.__name__ + new_mock_instance_for_methods.__doc__ = sync_mock.__doc__ + new_mock_instance_for_methods._name = sync_mock._name + new_mock_instance_for_methods._params = sync_mock._params + new_mock_instance_for_methods._ToolboxSyncTool__async_tool = AsyncMock( + spec=CoreAsyncTool + ) + new_mock_instance_for_methods._ToolboxSyncTool__loop = Mock( + spec=asyncio.AbstractEventLoop + ) + new_mock_instance_for_methods._ToolboxSyncTool__thread = Mock() + + sync_mock.add_auth_token_getters = Mock( + return_value=new_mock_instance_for_methods ) + sync_mock.bind_params = Mock(return_value=new_mock_instance_for_methods) + + return sync_mock @pytest.fixture - def auth_toolbox_tool(self, mock_async_auth_tool): - return ToolboxTool( - async_tool=mock_async_auth_tool, - loop=Mock(), - thread=Mock(), + def mock_core_sync_auth_tool(self, auth_tool_schema_dict): + sync_mock = Mock(spec=ToolboxCoreSyncTool) + sync_mock.__name__ = "test_auth_tool_lc_name" + sync_mock.__doc__ = auth_tool_schema_dict["description"] + sync_mock._name = "TestAuthToolPydanticModel" + sync_mock._params = [ + CoreParameterSchema(**p) for p in auth_tool_schema_dict["parameters"] + ] + + mock_async_tool_attr = AsyncMock(spec=CoreAsyncTool) + mock_async_tool_attr.return_value = "dummy_internal_async_auth_tool_result" + sync_mock._ToolboxSyncTool__async_tool = mock_async_tool_attr + sync_mock._ToolboxSyncTool__loop = Mock(spec=asyncio.AbstractEventLoop) + sync_mock._ToolboxSyncTool__thread = Mock() + + new_mock_instance_for_methods = Mock(spec=ToolboxCoreSyncTool) + new_mock_instance_for_methods.__name__ = sync_mock.__name__ + new_mock_instance_for_methods.__doc__ = sync_mock.__doc__ + new_mock_instance_for_methods._name = sync_mock._name + new_mock_instance_for_methods._params = sync_mock._params + new_mock_instance_for_methods._ToolboxSyncTool__async_tool = AsyncMock( + spec=CoreAsyncTool + ) + new_mock_instance_for_methods._ToolboxSyncTool__loop = Mock( + spec=asyncio.AbstractEventLoop ) + new_mock_instance_for_methods._ToolboxSyncTool__thread = Mock() + + sync_mock.add_auth_token_getters = Mock( + return_value=new_mock_instance_for_methods + ) + sync_mock.bind_params = Mock(return_value=new_mock_instance_for_methods) + return sync_mock + + @pytest.fixture + def toolbox_tool(self, mock_core_tool): + return ToolboxTool(core_tool=mock_core_tool) + + @pytest.fixture + def auth_toolbox_tool(self, mock_core_sync_auth_tool): + return ToolboxTool(core_tool=mock_core_sync_auth_tool) + + def test_toolbox_tool_init(self, mock_core_tool): + tool = ToolboxTool(core_tool=mock_core_tool) - def test_toolbox_tool_init(self, mock_async_tool, toolbox_tool): - assert toolbox_tool._ToolboxTool__async_tool == mock_async_tool + assert tool.metadata.name == mock_core_tool.__name__ + assert tool.metadata.description == mock_core_tool.__doc__ + assert tool._ToolboxTool__core_tool == mock_core_tool + + expected_args_schema = params_to_pydantic_model( + mock_core_tool._name, mock_core_tool._params + ) + assert_pydantic_models_equivalent( + tool.metadata.fn_schema, expected_args_schema, mock_core_tool._name + ) @pytest.mark.parametrize( - "params, expected_bound_params", + "params", [ - ({"param1": "bound-value"}, {"param1": "bound-value"}), - ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), - ( - {"param1": "bound-value", "param2": 123}, - {"param1": "bound-value", "param2": 123}, - ), + ({"param1": "bound-value"}), + ({"param1": lambda: "bound-value"}), + ({"param1": "bound-value", "param2": 123}), ], ) def test_toolbox_tool_bind_params( self, params, - expected_bound_params, toolbox_tool, - mock_async_tool, + mock_core_tool, ): - mock_async_tool._AsyncToolboxTool__bound_params = expected_bound_params - mock_async_tool.bind_params.return_value = mock_async_tool - - tool = toolbox_tool.bind_params(params) - mock_async_tool.bind_params.assert_called_once_with(params, True) - - assert isinstance(tool, ToolboxTool) - - for key, value in expected_bound_params.items(): - async_tool_bound_param_val = ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__bound_params[key] - ) - if callable(value): - assert value() == async_tool_bound_param_val() - else: - assert value == async_tool_bound_param_val + returned_core_tool_mock = mock_core_tool.bind_params.return_value + new_llamaindex_tool = toolbox_tool.bind_params(params) - def test_toolbox_tool_bind_param(self, mock_async_tool, toolbox_tool): - expected_bound_param = {"param1": "bound-value"} - mock_async_tool._AsyncToolboxTool__bound_params = expected_bound_param - mock_async_tool.bind_param.return_value = mock_async_tool + mock_core_tool.bind_params.assert_called_once_with(params) + assert isinstance(new_llamaindex_tool, ToolboxTool) + assert new_llamaindex_tool._ToolboxTool__core_tool == returned_core_tool_mock - tool = toolbox_tool.bind_param("param1", "bound-value") - mock_async_tool.bind_param.assert_called_once_with( - "param1", "bound-value", True - ) + def test_toolbox_tool_bind_param(self, toolbox_tool, mock_core_tool): + returned_core_tool_mock = mock_core_tool.bind_params.return_value + new_llamaindex_tool = toolbox_tool.bind_param("param1", "bound-value") - assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__bound_params - == expected_bound_param - ) - assert isinstance(tool, ToolboxTool) + mock_core_tool.bind_params.assert_called_once_with({"param1": "bound-value"}) + assert isinstance(new_llamaindex_tool, ToolboxTool) + assert new_llamaindex_tool._ToolboxTool__core_tool == returned_core_tool_mock @pytest.mark.parametrize( - "auth_tokens, expected_auth_tokens", + "auth_token_getters", [ - ( - {"test-auth-source": lambda: "test-token"}, - {"test-auth-source": lambda: "test-token"}, - ), + ({"test-auth-source": lambda: "test-token"}), ( { "test-auth-source": lambda: "test-token", "another-auth-source": lambda: "another-token", - }, - { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, + } ), ], ) - def test_toolbox_tool_add_auth_tokens( + def test_toolbox_tool_add_auth_token_getters( self, - auth_tokens, - expected_auth_tokens, - mock_async_auth_tool, + auth_token_getters, auth_toolbox_tool, + mock_core_sync_auth_tool, ): - auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens = ( - expected_auth_tokens + returned_core_tool_mock = ( + mock_core_sync_auth_tool.add_auth_token_getters.return_value ) - auth_toolbox_tool._ToolboxTool__async_tool.add_auth_tokens.return_value = ( - mock_async_auth_tool + new_llamaindex_tool = auth_toolbox_tool.add_auth_token_getters( + auth_token_getters ) - tool = auth_toolbox_tool.add_auth_tokens(auth_tokens) - mock_async_auth_tool.add_auth_tokens.assert_called_once_with(auth_tokens, True) - for source, getter in expected_auth_tokens.items(): - assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens[source]() - == getter() - ) - assert isinstance(tool, ToolboxTool) + mock_core_sync_auth_tool.add_auth_token_getters.assert_called_once_with( + auth_token_getters + ) + assert isinstance(new_llamaindex_tool, ToolboxTool) + assert new_llamaindex_tool._ToolboxTool__core_tool == returned_core_tool_mock - def test_toolbox_tool_add_auth_token(self, mock_async_auth_tool, auth_toolbox_tool): + def test_toolbox_tool_add_auth_token_getter( + self, auth_toolbox_tool, mock_core_sync_auth_tool + ): get_id_token = lambda: "test-token" - expected_auth_tokens = {"test-auth-source": get_id_token} - auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens = ( - expected_auth_tokens + returned_core_tool_mock = ( + mock_core_sync_auth_tool.add_auth_token_getters.return_value ) - auth_toolbox_tool._ToolboxTool__async_tool.add_auth_token.return_value = ( - mock_async_auth_tool + + new_llamaindex_tool = auth_toolbox_tool.add_auth_token_getter( + "test-auth-source", get_id_token ) - tool = auth_toolbox_tool.add_auth_token("test-auth-source", get_id_token) - mock_async_auth_tool.add_auth_token.assert_called_once_with( - "test-auth-source", get_id_token, True + mock_core_sync_auth_tool.add_auth_token_getters.assert_called_once_with( + {"test-auth-source": get_id_token} ) + assert isinstance(new_llamaindex_tool, ToolboxTool) + assert new_llamaindex_tool._ToolboxTool__core_tool == returned_core_tool_mock - assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_tokens[ - "test-auth-source" - ]() - == "test-token" + def test_toolbox_tool_run(self, toolbox_tool, mock_core_tool): + kwargs_to_run = {"param1": "run_value1", "param2": 100} + expected_result = ToolOutput( + content="sync_run_output", + tool_name="test_tool_name_for_llamaindex", + raw_input=kwargs_to_run, + raw_output="sync_run_output", + is_error=False, ) - assert isinstance(tool, ToolboxTool) + mock_core_tool.return_value = "sync_run_output" + + result = toolbox_tool.call(**kwargs_to_run) + + assert result == expected_result + assert mock_core_tool.call_count == 1 + assert mock_core_tool.call_args == call(**kwargs_to_run) @pytest.mark.asyncio - async def test_toolbox_tool_validate_auth_strict(self, auth_toolbox_tool): - auth_toolbox_tool._ToolboxTool__async_tool.acall = Mock( - side_effect=PermissionError( - "Parameter(s) `param1` of tool test_tool require authentication" - ) + @patch("toolbox_llamaindex.tools.to_thread", new_callable=AsyncMock) + async def test_toolbox_tool_arun( + self, mock_to_thread_in_tools, toolbox_tool, mock_core_tool + ): + kwargs_to_run = {"param1": "arun_value1", "param2": 200} + expected_result = ToolOutput( + content="async_run_output", + tool_name="test_tool_name_for_llamaindex", + raw_input=kwargs_to_run, + raw_output="async_run_output", + is_error=False, ) - with pytest.raises(PermissionError) as e: - await auth_toolbox_tool.acall() - assert "Parameter(s) `param1` of tool test_tool require authentication" in str( - e.value + + mock_core_tool.return_value = "async_run_output" + + async def to_thread_side_effect(func, *args, **kwargs_for_func): + return func(**kwargs_for_func) + + mock_to_thread_in_tools.side_effect = to_thread_side_effect + + result = await toolbox_tool.acall(**kwargs_to_run) + + assert result == expected_result + mock_to_thread_in_tools.assert_awaited_once_with( + mock_core_tool, **kwargs_to_run ) - @pytest.mark.asyncio - @patch("asyncio.run_coroutine_threadsafe") - async def test_toolbox_tool_run(self, mock_run_coroutine_threadsafe, toolbox_tool): - future = concurrent.futures.Future() - future.set_result({"result": "async success"}) - mock_run_coroutine_threadsafe.return_value = future - result = await toolbox_tool.acall(param1="value1", param2=3) - mock_run_coroutine_threadsafe.assert_called_once() - assert result == {"result": "async success"} - - @patch("asyncio.run_coroutine_threadsafe") - def test_toolbox_tool_sync_run(self, mock_run_coroutine_threadsafe, toolbox_tool): - future = concurrent.futures.Future() - future.set_result({"result": "sync success"}) - mock_run_coroutine_threadsafe.return_value = future - result = toolbox_tool.call(param1="value1", param2=3) - mock_run_coroutine_threadsafe.assert_called_once() - assert result == {"result": "sync success"} + assert mock_core_tool.call_count == 1 + assert mock_core_tool.call_args == call(**kwargs_to_run) diff --git a/packages/toolbox-llamaindex/tests/test_utils.py b/packages/toolbox-llamaindex/tests/test_utils.py deleted file mode 100644 index 78d64e26..00000000 --- a/packages/toolbox-llamaindex/tests/test_utils.py +++ /dev/null @@ -1,291 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import json -import re -import warnings -from typing import Union -from unittest.mock import AsyncMock, Mock, patch - -import aiohttp -import pytest -from pydantic import BaseModel - -from toolbox_llamaindex.utils import ( - ParameterSchema, - _get_auth_headers, - _invoke_tool, - _load_manifest, - _parse_type, - _schema_to_model, -) - -URL = "https://my-toolbox.com/test" -MOCK_MANIFEST = """ -{ - "serverVersion": "0.0.1", - "tools": { - "test_tool": { - "summary": "Test Tool", - "description": "This is a test tool.", - "parameters": [ - { - "name": "param1", - "type": "string", - "description": "Parameter 1" - }, - { - "name": "param2", - "type": "integer", - "description": "Parameter 2" - } - ] - } - } -} -""" - - -class TestUtils: - @pytest.fixture(scope="module") - def mock_manifest(self): - return aiohttp.ClientResponse( - method="GET", - url=aiohttp.client.URL(URL), - writer=None, - continue100=None, - timer=None, - request_info=None, - traces=None, - session=None, - loop=asyncio.get_event_loop(), - ) - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.get") - async def test_load_manifest(self, mock_get, mock_manifest): - mock_manifest.raise_for_status = Mock() - mock_manifest.text = AsyncMock(return_value=MOCK_MANIFEST) - - mock_get.return_value = mock_manifest - session = aiohttp.ClientSession() - manifest = await _load_manifest(URL, session) - await session.close() - mock_get.assert_called_once_with(URL) - - assert manifest.serverVersion == "0.0.1" - assert len(manifest.tools) == 1 - - tool = manifest.tools["test_tool"] - assert tool.description == "This is a test tool." - assert tool.parameters == [ - ParameterSchema(name="param1", type="string", description="Parameter 1"), - ParameterSchema(name="param2", type="integer", description="Parameter 2"), - ] - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.get") - async def test_load_manifest_invalid_json(self, mock_get, mock_manifest): - mock_manifest.raise_for_status = Mock() - mock_manifest.text = AsyncMock(return_value="{ invalid manifest") - mock_get.return_value = mock_manifest - - with pytest.raises(Exception) as e: - session = aiohttp.ClientSession() - await _load_manifest(URL, session) - - mock_get.assert_called_once_with(URL) - assert isinstance(e.value, json.JSONDecodeError) - assert ( - str(e.value) - == "Failed to parse JSON from https://my-toolbox.com/test: Expecting property name enclosed in double quotes: line 1 column 3 (char 2): line 1 column 3 (char 2)" - ) - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.get") - async def test_load_manifest_invalid_manifest(self, mock_get, mock_manifest): - mock_manifest.raise_for_status = Mock() - mock_manifest.text = AsyncMock(return_value='{ "something": "invalid" }') - mock_get.return_value = mock_manifest - - with pytest.raises(Exception) as e: - session = aiohttp.ClientSession() - await _load_manifest(URL, session) - - mock_get.assert_called_once_with(URL) - assert isinstance(e.value, ValueError) - assert re.match( - r"Invalid JSON data from https://my-toolbox.com/test: 2 validation errors for ManifestSchema\nserverVersion\n Field required \[type=missing, input_value={'something': 'invalid'}, input_type=dict]\n For further information visit https://errors.pydantic.dev/\d+\.\d+/v/missing\ntools\n Field required \[type=missing, input_value={'something': 'invalid'}, input_type=dict]\n For further information visit https://errors.pydantic.dev/\d+\.\d+/v/missing", - str(e.value), - ) - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.get") - async def test_load_manifest_api_error(self, mock_get, mock_manifest): - error = aiohttp.ClientError("Simulated HTTP Error") - mock_manifest.raise_for_status = Mock() - mock_manifest.text = AsyncMock(side_effect=error) - mock_get.return_value = mock_manifest - - with pytest.raises(aiohttp.ClientError) as exc_info: - session = aiohttp.ClientSession() - await _load_manifest(URL, session) - mock_get.assert_called_once_with(URL) - assert exc_info.value == error - - def test_schema_to_model(self): - schema = [ - ParameterSchema(name="param1", type="string", description="Parameter 1"), - ParameterSchema(name="param2", type="integer", description="Parameter 2"), - ] - model = _schema_to_model("TestModel", schema) - assert issubclass(model, BaseModel) - - assert model.model_fields["param1"].annotation == str - assert model.model_fields["param1"].description == "Parameter 1" - assert model.model_fields["param2"].annotation == int - assert model.model_fields["param2"].description == "Parameter 2" - - def test_schema_to_model_empty(self): - model = _schema_to_model("TestModel", []) - assert issubclass(model, BaseModel) - assert len(model.model_fields) == 0 - - @pytest.mark.parametrize( - "parameter_schema, expected_type", - [ - (ParameterSchema(name="foo", description="bar", type="string"), str), - (ParameterSchema(name="foo", description="bar", type="integer"), int), - (ParameterSchema(name="foo", description="bar", type="float"), float), - (ParameterSchema(name="foo", description="bar", type="boolean"), bool), - ( - ParameterSchema( - name="foo", - description="bar", - type="array", - items=ParameterSchema( - name="foo", description="bar", type="integer" - ), - ), - list[int], - ), - ], - ) - def test_parse_type(self, parameter_schema, expected_type): - assert _parse_type(parameter_schema) == expected_type - - @pytest.mark.parametrize( - "fail_parameter_schema", - [ - (ParameterSchema(name="foo", description="bar", type="invalid")), - ( - ParameterSchema( - name="foo", - description="bar", - type="array", - items=ParameterSchema( - name="foo", description="bar", type="invalid" - ), - ) - ), - ], - ) - def test_parse_type_invalid(self, fail_parameter_schema): - with pytest.raises(ValueError): - _parse_type(fail_parameter_schema) - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.post") - async def test_invoke_tool(self, mock_post): - mock_response = Mock() - mock_response.raise_for_status = Mock() - mock_response.json = AsyncMock(return_value={"key": "value"}) - mock_post.return_value.__aenter__.return_value = mock_response - - result = await _invoke_tool( - "http://localhost:8000", - aiohttp.ClientSession(), - "tool_name", - {"input": "data"}, - {}, - ) - - mock_post.assert_called_once_with( - "http://localhost:8000/api/tool/tool_name/invoke", - json={"input": "data"}, - headers={}, - ) - assert result == {"key": "value"} - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.post") - async def test_invoke_tool_unsecure_with_auth(self, mock_post): - mock_response = Mock() - mock_response.raise_for_status = Mock() - mock_response.json = AsyncMock(return_value={"key": "value"}) - mock_post.return_value.__aenter__.return_value = mock_response - - with pytest.warns( - UserWarning, - match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", - ): - result = await _invoke_tool( - "http://localhost:8000", - aiohttp.ClientSession(), - "tool_name", - {"input": "data"}, - {"my_test_auth": lambda: "fake_id_token"}, - ) - - mock_post.assert_called_once_with( - "http://localhost:8000/api/tool/tool_name/invoke", - json={"input": "data"}, - headers={"my_test_auth_token": "fake_id_token"}, - ) - assert result == {"key": "value"} - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.post") - async def test_invoke_tool_secure_with_auth(self, mock_post): - session = aiohttp.ClientSession() - mock_response = Mock() - mock_response.raise_for_status = Mock() - mock_response.json = AsyncMock(return_value={"key": "value"}) - mock_post.return_value.__aenter__.return_value = mock_response - - with warnings.catch_warnings(): - warnings.simplefilter("error") - result = await _invoke_tool( - "https://localhost:8000", - session, - "tool_name", - {"input": "data"}, - {"my_test_auth": lambda: "fake_id_token"}, - ) - - mock_post.assert_called_once_with( - "https://localhost:8000/api/tool/tool_name/invoke", - json={"input": "data"}, - headers={"my_test_auth_token": "fake_id_token"}, - ) - assert result == {"key": "value"} - - def test_get_auth_headers_deprecation_warning(self): - """Test _get_auth_headers deprecation warning.""" - with pytest.warns( - DeprecationWarning, - match=r"Call to deprecated function \(or staticmethod\) _get_auth_headers\. \(Please use `_get_auth_tokens` instead\.\)$", - ): - _get_auth_headers({"auth_source1": lambda: "test_token"})