diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 8e5ca278..cd3879ff 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -308,11 +308,11 @@ async def load_toolset( return tools - async def add_headers( + def add_headers( self, headers: Mapping[str, Union[Callable, Coroutine, str]] ) -> None: """ - Asynchronously Add headers to be included in each request sent through this client. + Add headers to be included in each request sent through this client. Args: headers: Headers to include in each request sent through this client. diff --git a/packages/toolbox-core/src/toolbox_core/sync_client.py b/packages/toolbox-core/src/toolbox_core/sync_client.py index 22e708b7..e1b06d30 100644 --- a/packages/toolbox-core/src/toolbox_core/sync_client.py +++ b/packages/toolbox-core/src/toolbox_core/sync_client.py @@ -14,13 +14,13 @@ import asyncio +from asyncio import AbstractEventLoop from threading import Thread -from typing import Any, Callable, Coroutine, Mapping, Optional, TypeVar, Union +from typing import Any, Callable, Coroutine, Mapping, Optional, Union from .client import ToolboxClient from .sync_tool import ToolboxSyncTool - -T = TypeVar("T") +from concurrent.futures import Future class ToolboxSyncClient: @@ -31,7 +31,7 @@ class ToolboxSyncClient: service endpoint. """ - __loop: Optional[asyncio.AbstractEventLoop] = None + __loop: Optional[AbstractEventLoop] = None __thread: Optional[Thread] = None def __init__( @@ -58,11 +58,22 @@ def __init__( async def create_client(): return ToolboxClient(url, client_headers=client_headers) - # Ignoring type since we're already checking the existence of a loop above. self.__async_client = asyncio.run_coroutine_threadsafe( create_client(), self.__class__.__loop ).result() + @property + def _async_client(self) -> ToolboxClient: + return self.__async_client + + @property + def _loop(self) -> Optional[AbstractEventLoop]: + return self.__class__.__loop + + @property + def _thread(self) -> Optional[Thread]: + return self.__class__.__thread + def close(self): """ Synchronously closes the underlying client session. Doing so will cause @@ -75,6 +86,46 @@ def close(self): coro = self.__async_client.close() asyncio.run_coroutine_threadsafe(coro, self.__loop).result() + def load_tool_future( + self, + name: str, + auth_token_getters: dict[str, Callable[[], str]] = {}, + bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, + ) -> Future[ToolboxSyncTool]: + """ + Returns a future that loads a tool from the server. + """ + if not self.__loop or not self.__thread: + raise ValueError("Background loop or thread cannot be None.") + + async def async_worker() -> ToolboxSyncTool: + async_tool = await self.__async_client.load_tool(name, auth_token_getters, bound_params) + return ToolboxSyncTool(async_tool, self.__loop, self.__thread) + return asyncio.run_coroutine_threadsafe(async_worker(), self.__loop) + + def load_toolset_future( + self, + name: Optional[str] = None, + auth_token_getters: dict[str, Callable[[], str]] = {}, + bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, + strict: bool = False, + ) -> Future[list[ToolboxSyncTool]]: + """ + Returns a future that fetches a toolset and loads all tools defined within it. + """ + if not self.__loop or not self.__thread: + raise ValueError("Background loop or thread cannot be None.") + + async def async_worker() -> list[ToolboxSyncTool]: + async_tools = await self.__async_client.load_toolset( + name, auth_token_getters, bound_params, strict + ) + return [ + ToolboxSyncTool(async_tool, self.__loop, self.__thread) + for async_tool in async_tools + ] + return asyncio.run_coroutine_threadsafe(async_worker(), self.__loop) + def load_tool( self, name: str, @@ -100,50 +151,44 @@ def load_tool( for execution. The specific arguments and behavior of the callable depend on the tool itself. """ - coro = self.__async_client.load_tool(name, auth_token_getters, bound_params) - - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - - async_tool = asyncio.run_coroutine_threadsafe(coro, self.__loop).result() - return ToolboxSyncTool(async_tool, self.__loop, self.__thread) + return self.load_tool_future(name, auth_token_getters, bound_params).result() def load_toolset( self, - name: str, + name: Optional[str] = None, auth_token_getters: dict[str, Callable[[], str]] = {}, bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, + strict: bool = False, ) -> list[ToolboxSyncTool]: """ Synchronously fetches a toolset and loads all tools defined within it. Args: - name: Name of the toolset to load tools. + name: Name of the toolset to load. If None, loads the default toolset. auth_token_getters: A mapping of authentication service names to callables that return the corresponding authentication token. bound_params: A mapping of parameter names to bind to specific values or callables that are called to produce values as needed. + strict: If True, raises an error if *any* loaded tool instance fails + to utilize at least one provided parameter or auth token (if any + provided). If False (default), raises an error only if a + user-provided parameter or auth token cannot be applied to *any* + loaded tool across the set. Returns: list[ToolboxSyncTool]: A list of callables, one for each tool defined in the toolset. - """ - coro = self.__async_client.load_toolset(name, auth_token_getters, bound_params) - - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - async_tools = asyncio.run_coroutine_threadsafe(coro, self.__loop).result() # type: ignore - return [ - ToolboxSyncTool(async_tool, self.__loop, self.__thread) - for async_tool in async_tools - ] + Raises: + ValueError: If validation fails based on the `strict` flag. + """ + return self.load_toolset_future(name, auth_token_getters, bound_params, strict).result() def add_headers( self, headers: Mapping[str, Union[Callable, Coroutine, str]] ) -> None: """ - Synchronously Add headers to be included in each request sent through this client. + Add headers to be included in each request sent through this client. Args: headers: Headers to include in each request sent through this client. @@ -151,10 +196,7 @@ def add_headers( Raises: ValueError: If any of the headers are already registered in the client. """ - coro = self.__async_client.add_headers(headers) - - # We have already created a new loop in the init method in case it does not already exist - asyncio.run_coroutine_threadsafe(coro, self.__loop).result() # type: ignore + self.__async_client.add_headers(headers) def __enter__(self): """Enter the runtime context related to this client instance.""" diff --git a/packages/toolbox-core/src/toolbox_core/sync_tool.py b/packages/toolbox-core/src/toolbox_core/sync_tool.py index 74f6f0bf..791e6b46 100644 --- a/packages/toolbox-core/src/toolbox_core/sync_tool.py +++ b/packages/toolbox-core/src/toolbox_core/sync_tool.py @@ -17,12 +17,11 @@ from asyncio import AbstractEventLoop from inspect import Signature from threading import Thread -from typing import Any, Callable, Coroutine, Mapping, Sequence, TypeVar, Union +from typing import Any, Callable, Coroutine, Mapping, Sequence, Union from .protocol import ParameterSchema from .tool import ToolboxTool - -T = TypeVar("T") +from concurrent.futures import Future class ToolboxSyncTool: @@ -69,6 +68,18 @@ def __init__( f"{self.__class__.__qualname__}.{self.__async_tool.__name__}" ) + @property + def _async_tool(self) -> ToolboxTool: + return self.__async_tool + + @property + def _loop(self) -> AbstractEventLoop: + return self.__loop + + @property + def _thread(self) -> Thread: + return self.__thread + @property def __name__(self) -> str: return self.__async_tool.__name__ @@ -119,6 +130,13 @@ def _auth_service_token_getters(self) -> Mapping[str, Callable[[], str]]: def _client_headers(self) -> Mapping[str, Union[Callable, Coroutine, str]]: return self.__async_tool._client_headers + def call_future(self, *args: Any, **kwargs: Any) -> Future[str]: + """ + Returns future that calls the remote tool with the provided arguments. + """ + coro = self.__async_tool(*args, **kwargs) + return asyncio.run_coroutine_threadsafe(coro, self.__loop) + def __call__(self, *args: Any, **kwargs: Any) -> str: """ Synchronously calls the remote tool with the provided arguments. @@ -133,8 +151,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> str: Returns: The string result returned by the remote tool execution. """ - coro = self.__async_tool(*args, **kwargs) - return asyncio.run_coroutine_threadsafe(coro, self.__loop).result() + return self.call_future(*args, **kwargs).result() def add_auth_token_getters( self, diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 98d04e17..ad997d40 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -19,6 +19,7 @@ from warnings import warn from aiohttp import ClientSession +from pydantic import BaseModel from .protocol import ParameterSchema from .utils import ( @@ -158,6 +159,10 @@ def _auth_service_token_getters(self) -> Mapping[str, Callable[[], str]]: def _client_headers(self) -> Mapping[str, Union[Callable, Coroutine, str]]: return MappingProxyType(self.__client_headers) + @property + def _pydantic_model(self) -> type[BaseModel]: + return self.__pydantic_model + def __copy( self, session: Optional[ClientSession] = None, @@ -239,7 +244,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: for s in self.__required_authn_params.values(): req_auth_services.update(s) req_auth_services.update(self.__required_authz_tokens) - raise ValueError( + raise PermissionError( f"One or more of the following authn services are required to invoke this tool" f": {','.join(req_auth_services)}" ) diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index b57624b7..1b36fe0d 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -1443,7 +1443,7 @@ async def test_add_headers_success( ) async with ToolboxClient(TEST_BASE_URL) as client: - await client.add_headers(static_header) + client.add_headers(static_header) assert client._ToolboxClient__client_headers == static_header tool = await client.load_tool(tool_name) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index c8111b6f..8920bc3b 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -147,7 +147,7 @@ async def test_run_tool_no_auth(self, toolbox: ToolboxClient): """Tests running a tool requiring auth without providing auth.""" tool = await toolbox.load_tool("get-row-by-id-auth") with pytest.raises( - Exception, + PermissionError, match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): await tool(id="2") @@ -188,7 +188,7 @@ async def test_run_tool_param_auth_no_auth(self, toolbox: ToolboxClient): """Tests running a tool with a param requiring auth, without auth.""" tool = await toolbox.load_tool("get-row-by-email-auth") with pytest.raises( - ValueError, + PermissionError, match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): await tool() diff --git a/packages/toolbox-core/tests/test_sync_client.py b/packages/toolbox-core/tests/test_sync_client.py index 51a4a288..5bbedb7d 100644 --- a/packages/toolbox-core/tests/test_sync_client.py +++ b/packages/toolbox-core/tests/test_sync_client.py @@ -14,8 +14,10 @@ import inspect -from typing import Any, Callable, Mapping, Optional -from unittest.mock import AsyncMock, patch +from asyncio import AbstractEventLoop +from threading import Thread +from typing import Any, Callable, Generator, Mapping, Optional +from unittest.mock import AsyncMock, Mock, patch import pytest from aioresponses import CallbackResult, aioresponses @@ -24,6 +26,7 @@ from toolbox_core.protocol import ManifestSchema, ParameterSchema, ToolSchema from toolbox_core.sync_client import ToolboxSyncClient from toolbox_core.sync_tool import ToolboxSyncTool +from toolbox_core.tool import ToolboxTool TEST_BASE_URL = "http://toolbox.example.com" @@ -44,8 +47,12 @@ def sync_client_environment(): # This ensures any client created will start a new loop/thread. # Ensure no loop/thread is running from a previous misbehaving test or setup - assert original_loop is None or not original_loop.is_running() - assert original_thread is None or not original_thread.is_alive() + if original_loop and original_loop.is_running(): + original_loop.call_soon_threadsafe(original_loop.stop) + if original_thread and original_thread.is_alive(): + original_thread.join(timeout=5) + ToolboxSyncClient._ToolboxSyncClient__loop = None + ToolboxSyncClient._ToolboxSyncClient__thread = None ToolboxSyncClient._ToolboxSyncClient__loop = None ToolboxSyncClient._ToolboxSyncClient__thread = None @@ -67,21 +74,18 @@ def sync_client_environment(): @pytest.fixture -def sync_client(sync_client_environment, request): +def sync_client(sync_client_environment): """ Provides a ToolboxSyncClient instance within an isolated environment. The client's underlying async session is automatically closed after the test. The class-level loop/thread are managed by sync_client_environment. """ - # `sync_client_environment` has prepared the class state. client = ToolboxSyncClient(TEST_BASE_URL) - def finalizer(): - client.close() # Closes the async_client's session. - # Loop/thread shutdown is handled by sync_client_environment's teardown. + yield client - request.addfinalizer(finalizer) - return client + client.close() # Closes the async_client's session. + # Loop/thread shutdown is handled by sync_client_environment's teardown. @pytest.fixture() @@ -231,6 +235,41 @@ def test_sync_load_toolset_success( assert result1 == f"{TOOL1_NAME}_ok" +def test_sync_tool_internal_properties(aioresponses, tool_schema_minimal, sync_client): + """ + Tests that the internal properties _async_tool, _loop, and _thread + of a ToolboxSyncTool instance are correctly initialized and accessible. + This directly covers the respective @property methods in ToolboxSyncTool. + """ + TOOL_NAME = "test_tool_for_internal_properties" + mock_tool_load(aioresponses, TOOL_NAME, tool_schema_minimal) + + loaded_sync_tool = sync_client.load_tool(TOOL_NAME) + + assert isinstance(loaded_sync_tool, ToolboxSyncTool) + + # 1. Test the _async_tool property + internal_async_tool = loaded_sync_tool._async_tool + assert isinstance(internal_async_tool, ToolboxTool) + assert internal_async_tool.__name__ == TOOL_NAME + + # 2. Test the _loop property + internal_loop = loaded_sync_tool._loop + assert isinstance(internal_loop, AbstractEventLoop) + assert internal_loop is sync_client._ToolboxSyncClient__loop + assert ( + internal_loop.is_running() + ), "The event loop used by ToolboxSyncTool should be running." + + # 3. Test the _thread property + internal_thread = loaded_sync_tool._thread + assert isinstance(internal_thread, Thread) + assert internal_thread is sync_client._ToolboxSyncClient__thread + assert ( + internal_thread.is_alive() + ), "The thread used by ToolboxSyncTool should be alive." + + def test_sync_invoke_tool_server_error(aioresponses, test_tool_str_schema, sync_client): TOOL_NAME = "sync_server_error_tool" ERROR_MESSAGE = "Simulated Server Error for Sync Client" @@ -411,20 +450,32 @@ def post_callback(url, **kwargs): result = tool(param1="test") assert result == expected_payload["result"] - @pytest.mark.usefixtures("sync_client_environment") def test_sync_add_headers_duplicate_fail(self): - """ - Tests that adding a duplicate header via add_headers raises ValueError. - Manually create client to control initial headers. - """ + """Tests that adding a duplicate header via add_headers raises ValueError (from async client).""" initial_headers = {"X-Initial-Header": "initial_value"} + mock_async_client = AsyncMock(spec=ToolboxClient) + + # Configure add_headers to simulate the ValueError from ToolboxClient + def mock_add_headers(headers): + # Simulate ToolboxClient's check + if "X-Initial-Header" in headers: + raise ValueError( + "Client header(s) `X-Initial-Header` already registered" + ) - with ToolboxSyncClient(TEST_BASE_URL, client_headers=initial_headers) as client: - with pytest.raises( - ValueError, - match="Client header\\(s\\) `X-Initial-Header` already registered", - ): - client.add_headers({"X-Initial-Header": "another_value"}) + mock_async_client.add_headers = Mock(side_effect=mock_add_headers) + + with patch( + "toolbox_core.sync_client.ToolboxClient", return_value=mock_async_client + ): + with ToolboxSyncClient( + TEST_BASE_URL, client_headers=initial_headers + ) as client: + with pytest.raises( + ValueError, + match="Client header\\(s\\) `X-Initial-Header` already registered", + ): + client.add_headers({"X-Initial-Header": "another_value"}) class TestSyncAuth: @@ -528,7 +579,7 @@ def test_auth_with_load_tool_fail_no_token( tool = sync_client.load_tool(tool_name_auth) with pytest.raises( - ValueError, + PermissionError, match="One or more of the following authn services are required to invoke this tool: my-auth-service", ): tool(argA=15, argB=True) @@ -592,3 +643,53 @@ def test_constructor_getters_missing_fail( tool_name_auth, auth_token_getters={UNUSED_AUTH_SERVICE: lambda: "token"}, ) + + +# --- Tests for @property methods of ToolboxSyncClient --- + + +@pytest.fixture +def sync_client_with_mocks() -> Generator[ToolboxSyncClient, Any, Any]: + """ + Fixture to create a ToolboxSyncClient with mocked internal async client + without relying on actual network calls during init. + """ + with patch( + "toolbox_core.sync_client.ToolboxClient", autospec=True + ) as MockToolboxClient: + # Mock the async client's constructor to return an AsyncMock instance + mock_async_client_instance = AsyncMock(spec=ToolboxClient) + MockToolboxClient.return_value = mock_async_client_instance + + # Mock the run_coroutine_threadsafe and its result() + with patch("asyncio.run_coroutine_threadsafe") as mock_run_coroutine_threadsafe: + mock_future = Mock() + mock_future.result.return_value = mock_async_client_instance + mock_run_coroutine_threadsafe.return_value = mock_future + + client = ToolboxSyncClient(TEST_BASE_URL) + yield client + + +def test_sync_client_underscore_async_client_property( + sync_client_with_mocks: ToolboxSyncClient, +): + """Tests the _async_client property.""" + assert isinstance(sync_client_with_mocks._async_client, AsyncMock) + + +def test_sync_client_underscore_loop_property( + sync_client_with_mocks: ToolboxSyncClient, +): + """Tests the _loop property.""" + assert sync_client_with_mocks._loop is not None + assert isinstance(sync_client_with_mocks._loop, AbstractEventLoop) + + +def test_sync_client_underscore_thread_property( + sync_client_with_mocks: ToolboxSyncClient, +): + """Tests the _thread property.""" + assert sync_client_with_mocks._thread is not None + assert isinstance(sync_client_with_mocks._thread, Thread) + assert sync_client_with_mocks._thread.is_alive() diff --git a/packages/toolbox-core/tests/test_sync_e2e.py b/packages/toolbox-core/tests/test_sync_e2e.py index 885724e9..f2730e47 100644 --- a/packages/toolbox-core/tests/test_sync_e2e.py +++ b/packages/toolbox-core/tests/test_sync_e2e.py @@ -129,7 +129,7 @@ def test_run_tool_no_auth(self, toolbox: ToolboxSyncClient): """Tests running a tool requiring auth without providing auth.""" tool = toolbox.load_tool("get-row-by-id-auth") with pytest.raises( - Exception, + PermissionError, match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): tool(id="2") @@ -156,7 +156,7 @@ def test_run_tool_param_auth_no_auth(self, toolbox: ToolboxSyncClient): """Tests running a tool with a param requiring auth, without auth.""" tool = toolbox.load_tool("get-row-by-email-auth") with pytest.raises( - ValueError, + PermissionError, match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): tool() diff --git a/packages/toolbox-core/tests/test_tool.py b/packages/toolbox-core/tests/test_tool.py index c64149f5..7dffad91 100644 --- a/packages/toolbox-core/tests/test_tool.py +++ b/packages/toolbox-core/tests/test_tool.py @@ -23,7 +23,7 @@ import pytest_asyncio from aiohttp import ClientSession from aioresponses import aioresponses -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from toolbox_core.protocol import ParameterSchema from toolbox_core.tool import ToolboxTool, create_func_docstring, resolve_value @@ -578,6 +578,24 @@ def test_toolbox_tool_underscore_client_headers_property(toolbox_tool: ToolboxTo client_headers["new_header"] = "new_value" +def test_toolbox_tool_underscore_pydantic_model_property(toolbox_tool: ToolboxTool): + """Tests the _pydantic_model property returns the correct Pydantic model.""" + pydantic_model = toolbox_tool._pydantic_model + assert issubclass(pydantic_model, BaseModel) + assert pydantic_model.__name__ == TEST_TOOL_NAME + + # Test that the model can validate expected data + valid_data = {"message": "test", "count": 10} + validated_data = pydantic_model.model_validate(valid_data) + assert validated_data.message == "test" + assert validated_data.count == 10 + + # Test that the model raises ValidationError for invalid data + invalid_data = {"message": 123, "count": "not_an_int"} + with pytest.raises(ValidationError): + pydantic_model.model_validate(invalid_data) + + # --- Test for the HTTP Warning --- @pytest.mark.parametrize( "trigger_condition_params", diff --git a/packages/toolbox-langchain/README.md b/packages/toolbox-langchain/README.md index 9f698694..fca7736b 100644 --- a/packages/toolbox-langchain/README.md +++ b/packages/toolbox-langchain/README.md @@ -227,7 +227,7 @@ tools = toolbox.load_toolset() auth_tool = tools[0].add_auth_token_getter("my_auth", get_auth_token) # Single token -multi_auth_tool = tools[0].add_auth_token_getters({"my_auth", get_auth_token}) # Multiple tokens +multi_auth_tool = tools[0].add_auth_token_getters({"auth_1": get_auth_1}, {"auth_2": get_auth_2}) # Multiple tokens # OR diff --git a/packages/toolbox-langchain/integration.cloudbuild.yaml b/packages/toolbox-langchain/integration.cloudbuild.yaml index 644794fb..51f0ce81 100644 --- a/packages/toolbox-langchain/integration.cloudbuild.yaml +++ b/packages/toolbox-langchain/integration.cloudbuild.yaml @@ -15,10 +15,11 @@ steps: - id: Install library requirements name: 'python:${_VERSION}' + dir: 'packages/toolbox-langchain' args: - install - '-r' - - 'packages/toolbox-langchain/requirements.txt' + - 'requirements.txt' - '--user' entrypoint: pip - id: Install test requirements diff --git a/packages/toolbox-langchain/pyproject.toml b/packages/toolbox-langchain/pyproject.toml index f4f5b7aa..9aaa254a 100644 --- a/packages/toolbox-langchain/pyproject.toml +++ b/packages/toolbox-langchain/pyproject.toml @@ -9,6 +9,8 @@ authors = [ {name = "Google LLC", email = "googleapis-packages@google.com"} ] dependencies = [ + # TODO: Bump lowest supported version to 0.2.0 + "toolbox-core>=0.1.0,<1.0.0", "langchain-core>=0.2.23,<1.0.0", "PyYAML>=6.0.1,<7.0.0", "pydantic>=2.7.0,<3.0.0", diff --git a/packages/toolbox-langchain/requirements.txt b/packages/toolbox-langchain/requirements.txt index 5fd65843..3ada831d 100644 --- a/packages/toolbox-langchain/requirements.txt +++ b/packages/toolbox-langchain/requirements.txt @@ -1,3 +1,4 @@ +-e ../toolbox-core langchain-core==0.3.56 PyYAML==6.0.2 pydantic==2.11.4 diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py index aacbc5af..95e384c8 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_client.py @@ -16,9 +16,9 @@ from warnings import warn from aiohttp import ClientSession +from toolbox_core.client import ToolboxClient as ToolboxCoreClient -from .tools import AsyncToolboxTool -from .utils import ManifestSchema, _load_manifest +from .async_tools import AsyncToolboxTool # This class is an internal implementation detail and is not exposed to the @@ -38,8 +38,7 @@ def __init__( url: The base URL of the Toolbox service. session: An HTTP client session. """ - self.__url = url - self.__session = session + self.__core_client = ToolboxCoreClient(url=url, session=session) async def aload_tool( self, @@ -48,7 +47,6 @@ async def aload_tool( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, ) -> AsyncToolboxTool: """ Loads the tool with the given tool name from the Toolbox service. @@ -61,51 +59,42 @@ async def aload_tool( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. Returns: A tool loaded from the Toolbox. """ - if auth_headers: + if auth_tokens: if auth_token_getters: warn( - "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", DeprecationWarning, ) else: warn( - "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", DeprecationWarning, ) - auth_token_getters = auth_headers + auth_token_getters = auth_tokens - if auth_tokens: + if auth_headers: if auth_token_getters: warn( - "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", DeprecationWarning, ) else: warn( - "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", DeprecationWarning, ) - auth_token_getters = auth_tokens + auth_token_getters = auth_headers - url = f"{self.__url}/api/tool/{tool_name}" - manifest: ManifestSchema = await _load_manifest(url, self.__session) - - return AsyncToolboxTool( - tool_name, - manifest.tools[tool_name], - self.__url, - self.__session, - auth_token_getters, - bound_params, - strict, + core_tool = await self.__core_client.load_tool( + name=tool_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, ) + return AsyncToolboxTool(core_tool=core_tool) async def aload_toolset( self, @@ -114,7 +103,7 @@ async def aload_toolset( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + strict: bool = False, ) -> list[AsyncToolboxTool]: """ Loads tools from the Toolbox service, optionally filtered by toolset @@ -129,55 +118,51 @@ async def aload_toolset( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. + strict: If True, raises an error if *any* loaded tool instance fails + to utilize at least one provided parameter or auth token (if any + provided). If False (default), raises an error only if a + user-provided parameter or auth token cannot be applied to *any* + loaded tool across the set. Returns: A list of all tools loaded from the Toolbox. """ - if auth_headers: + if auth_tokens: if auth_token_getters: warn( - "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", DeprecationWarning, ) else: warn( - "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", DeprecationWarning, ) - auth_token_getters = auth_headers + auth_token_getters = auth_tokens - if auth_tokens: + if auth_headers: if auth_token_getters: warn( - "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", DeprecationWarning, ) else: warn( - "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", DeprecationWarning, ) - auth_token_getters = auth_tokens + auth_token_getters = auth_headers - url = f"{self.__url}/api/toolset/{toolset_name or ''}" - manifest: ManifestSchema = await _load_manifest(url, self.__session) - tools: list[AsyncToolboxTool] = [] - - for tool_name, tool_schema in manifest.tools.items(): - tools.append( - AsyncToolboxTool( - tool_name, - tool_schema, - self.__url, - self.__session, - auth_token_getters, - bound_params, - strict, - ) - ) + core_tools = await self.__core_client.load_toolset( + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=strict, + ) + + tools = [] + for core_tool in core_tools: + tools.append(AsyncToolboxTool(core_tool=core_tool)) return tools def load_tool( @@ -187,7 +172,6 @@ def load_tool( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, ) -> AsyncToolboxTool: raise NotImplementedError("Synchronous methods not supported by async client.") @@ -198,6 +182,6 @@ def load_toolset( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + strict: bool = False, ) -> list[AsyncToolboxTool]: raise NotImplementedError("Synchronous methods not supported by async client.") diff --git a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py index 40e21ee6..282341a8 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/async_tools.py @@ -12,22 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy -from typing import Any, Callable, TypeVar, Union -from warnings import warn +from typing import Any, Callable, Union -from aiohttp import ClientSession +from deprecated import deprecated from langchain_core.tools import BaseTool - -from .utils import ( - ToolSchema, - _find_auth_params, - _find_bound_params, - _invoke_tool, - _schema_to_model, -) - -T = TypeVar("T") +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool # This class is an internal implementation detail and is not exposed to the @@ -41,109 +30,28 @@ class AsyncToolboxTool(BaseTool): def __init__( self, - name: str, - schema: ToolSchema, - url: str, - session: ClientSession, - auth_token_getters: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + core_tool: ToolboxCoreTool, ) -> None: """ Initializes an AsyncToolboxTool instance. Args: - name: The name of the tool. - schema: The tool schema. - url: The base URL of the Toolbox service. - session: The HTTP client session. - auth_token_getters: A mapping of authentication source names to - functions that retrieve ID tokens. - bound_params: A mapping of parameter names to their bound - values. - strict: If True, raises a ValueError if any of the given bound - parameters is missing from the schema or requires - authentication. If False, only issues a warning. + core_tool: The underlying core async ToolboxTool instance. """ - # If the schema is not already a ToolSchema instance, we create one from - # its attributes. This allows flexibility in how the schema is provided, - # accepting both a ToolSchema object and a dictionary of schema - # attributes. - if not isinstance(schema, ToolSchema): - schema = ToolSchema(**schema) - - auth_params, non_auth_params = _find_auth_params(schema.parameters) - non_auth_bound_params, non_auth_non_bound_params = _find_bound_params( - non_auth_params, list(bound_params) - ) - - # Check if the user is trying to bind a param that is authenticated or - # is missing from the given schema. - auth_bound_params: list[str] = [] - missing_bound_params: list[str] = [] - for bound_param in bound_params: - if bound_param in [param.name for param in auth_params]: - auth_bound_params.append(bound_param) - elif bound_param not in [param.name for param in non_auth_params]: - missing_bound_params.append(bound_param) - - # Create error messages for any params that are found to be - # authenticated or missing. - messages: list[str] = [] - if auth_bound_params: - messages.append( - f"Parameter(s) {', '.join(auth_bound_params)} already authenticated and cannot be bound." - ) - if missing_bound_params: - messages.append( - f"Parameter(s) {', '.join(missing_bound_params)} missing and cannot be bound." - ) - - # Join any error messages and raise them as an error or warning, - # depending on the value of the strict flag. - if messages: - message = "\n\n".join(messages) - if strict: - raise ValueError(message) - warn(message) - - # Bind values for parameters present in the schema that don't require - # authentication. - bound_params = { - param_name: param_value - for param_name, param_value in bound_params.items() - if param_name in [param.name for param in non_auth_bound_params] - } - - # Update the tools schema to validate only the presence of parameters - # that neither require authentication nor are bound. - schema.parameters = non_auth_non_bound_params - # Due to how pydantic works, we must initialize the underlying # BaseTool class before assigning values to member variables. super().__init__( - name=name, - description=schema.description, - args_schema=_schema_to_model(model_name=name, schema=schema.parameters), + name=core_tool.__name__, + description=core_tool.__doc__, + args_schema=core_tool._pydantic_model, ) + self.__core_tool = core_tool - self.__name = name - self.__schema = schema - self.__url = url - self.__session = session - self.__auth_token_getters = auth_token_getters - self.__auth_params = auth_params - self.__bound_params = bound_params - - # Warn users about any missing authentication so they can add it before - # tool invocation. - self.__validate_auth(strict=False) - - def _run(self, **kwargs: Any) -> dict[str, Any]: + def _run(self, **kwargs: Any) -> str: raise NotImplementedError("Synchronous methods not supported by async tools.") - async def _arun(self, **kwargs: Any) -> dict[str, Any]: + async def _arun(self, **kwargs: Any) -> str: """ The coroutine that invokes the tool with the given arguments. @@ -154,140 +62,10 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]: A dictionary containing the parsed JSON response from the tool invocation. """ - - # If the tool had parameters that require authentication, then right - # before invoking that tool, we check whether all these required - # authentication sources have been registered or not. - self.__validate_auth() - - # Evaluate dynamic parameter values if any - evaluated_params = {} - for param_name, param_value in self.__bound_params.items(): - if callable(param_value): - evaluated_params[param_name] = param_value() - else: - evaluated_params[param_name] = param_value - - # Merge bound parameters with the provided arguments - kwargs.update(evaluated_params) - - return await _invoke_tool( - self.__url, self.__session, self.__name, kwargs, self.__auth_token_getters - ) - - def __validate_auth(self, strict: bool = True) -> None: - """ - Checks if a tool meets the authentication requirements. - - A tool is considered authenticated if all of its parameters meet at - least one of the following conditions: - - * The parameter has at least one registered authentication source. - * The parameter requires no authentication. - - Args: - strict: If True, raises a PermissionError if any required - authentication sources are not registered. If False, only issues - a warning. - - Raises: - PermissionError: If strict is True and any required authentication - sources are not registered. - """ - is_authenticated: bool = not self.__schema.authRequired - params_missing_auth: list[str] = [] - - # Check tool for at least 1 required auth source - for src in self.__schema.authRequired: - if src in self.__auth_token_getters: - is_authenticated = True - break - - # Check each parameter for at least 1 required auth source - for param in self.__auth_params: - if not param.authSources: - raise ValueError("Auth sources cannot be None.") - has_auth = False - for src in param.authSources: - - # Find first auth source that is specified - if src in self.__auth_token_getters: - has_auth = True - break - if not has_auth: - params_missing_auth.append(param.name) - - messages: list[str] = [] - - if not is_authenticated: - messages.append( - f"Tool {self.__name} requires authentication, but no valid authentication sources are registered. Please register the required sources before use." - ) - - if params_missing_auth: - messages.append( - f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication sources are registered. Please register the required sources before use." - ) - - if messages: - message = "\n\n".join(messages) - if strict: - raise PermissionError(message) - warn(message) - - def __create_copy( - self, - *, - auth_token_getters: dict[str, Callable[[], str]] = {}, - bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool, - ) -> "AsyncToolboxTool": - """ - Creates a copy of the current AsyncToolboxTool instance, allowing for - modification of auth tokens and bound params. - - This method enables the creation of new tool instances with inherited - properties from the current instance, while optionally updating the auth - tokens and bound params. This is useful for creating variations of the - tool with additional auth tokens or bound params without modifying the - original instance, ensuring immutability. - - Args: - auth_token_getters: A dictionary of auth source names to functions - that retrieve ID tokens. These tokens will be merged with the - existing auth tokens. - bound_params: A dictionary of parameter names to their - bound values or functions to retrieve the values. These params - will be merged with the existing bound params. - strict: If True, raises a ValueError if any of the given bound - parameters is missing from the schema or requires - authentication. If False, only issues a warning. - - Returns: - A new AsyncToolboxTool instance that is a deep copy of the current - instance, with added auth tokens or bound params. - """ - new_schema = deepcopy(self.__schema) - - # Reconstruct the complete parameter schema by merging the auth - # parameters back with the non-auth parameters. This is necessary to - # accurately validate the new combination of auth tokens and bound - # params in the constructor of the new AsyncToolboxTool instance, ensuring - # that any overlaps or conflicts are correctly identified and reported - # as errors or warnings, depending on the given `strict` flag. - new_schema.parameters += self.__auth_params - return AsyncToolboxTool( - name=self.__name, - schema=new_schema, - url=self.__url, - session=self.__session, - auth_token_getters={**self.__auth_token_getters, **auth_token_getters}, - bound_params={**self.__bound_params, **bound_params}, - strict=strict, - ) + return await self.__core_tool(**kwargs) def add_auth_token_getters( - self, auth_token_getters: dict[str, Callable[[], str]], strict: bool = True + self, auth_token_getters: dict[str, Callable[[], str]] ) -> "AsyncToolboxTool": """ Registers functions to retrieve ID tokens for the corresponding @@ -296,36 +74,21 @@ def add_auth_token_getters( Args: auth_token_getters: A dictionary of authentication source names to the functions that return corresponding ID token getters. - strict: If True, a ValueError is raised if any of the provided auth - parameters is already bound. If False, only a warning is issued. Returns: A new AsyncToolboxTool instance that is a deep copy of the current - instance, with added auth tokens. + instance, with added auth token getters. Raises: ValueError: If any of the provided auth parameters is already registered. - ValueError: If any of the provided auth parameters is already bound - and strict is True. """ - - # Check if the authentication source is already registered. - dupe_tokens: list[str] = [] - for auth_token, _ in auth_token_getters.items(): - if auth_token in self.__auth_token_getters: - dupe_tokens.append(auth_token) - - if dupe_tokens: - raise ValueError( - f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`." - ) - - return self.__create_copy(auth_token_getters=auth_token_getters, strict=strict) + new_core_tool = self.__core_tool.add_auth_token_getters(auth_token_getters) + return AsyncToolboxTool(core_tool=new_core_tool) def add_auth_token_getter( - self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + self, auth_source: str, get_id_token: Callable[[], str] ) -> "AsyncToolboxTool": """ Registers a function to retrieve an ID token for a given authentication @@ -334,24 +97,32 @@ def add_auth_token_getter( Args: auth_source: The name of the authentication source. get_id_token: A function that returns the ID token. - strict: If True, a ValueError is raised if the provided auth - parameter is already bound. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current - instance, with added auth token. + instance, with added auth token getter. Raises: ValueError: If the provided auth parameter is already registered. - ValueError: If the provided auth parameter is already bound and - strict is True. + """ - return self.add_auth_token_getters({auth_source: get_id_token}, strict=strict) + return self.add_auth_token_getters({auth_source: get_id_token}) + + @deprecated("Please use `add_auth_token_getters` instead.") + def add_auth_tokens( + self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True + ) -> "AsyncToolboxTool": + return self.add_auth_token_getters(auth_tokens) + + @deprecated("Please use `add_auth_token_getter` instead.") + def add_auth_token( + self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + ) -> "AsyncToolboxTool": + return self.add_auth_token_getter(auth_source, get_id_token) def bind_params( self, bound_params: dict[str, Union[Any, Callable[[], Any]]], - strict: bool = True, ) -> "AsyncToolboxTool": """ Registers values or functions to retrieve the value for the @@ -360,9 +131,6 @@ def bind_params( Args: bound_params: A dictionary of the bound parameter name to the value or function of the bound value. - strict: If True, a ValueError is raised if any of the provided bound - params is not defined in the tool's schema, or requires - authentication. If False, only a warning is issued. Returns: A new AsyncToolboxTool instance that is a deep copy of the current @@ -370,29 +138,14 @@ def bind_params( Raises: ValueError: If any of the provided bound params is already bound. - ValueError: if any of the provided bound params is not defined in - the tool's schema, or requires authentication, and strict is - True. """ - - # Check if the parameter is already bound. - dupe_params: list[str] = [] - for param_name, _ in bound_params.items(): - if param_name in self.__bound_params: - dupe_params.append(param_name) - - if dupe_params: - raise ValueError( - f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self.__name}`." - ) - - return self.__create_copy(bound_params=bound_params, strict=strict) + new_core_tool = self.__core_tool.bind_params(bound_params) + return AsyncToolboxTool(core_tool=new_core_tool) def bind_param( self, param_name: str, param_value: Union[Any, Callable[[], Any]], - strict: bool = True, ) -> "AsyncToolboxTool": """ Registers a value or a function to retrieve the value for a given bound @@ -402,9 +155,6 @@ def bind_param( param_name: The name of the bound parameter. param_value: The value of the bound parameter, or a callable that returns the value. - strict: If True, a ValueError is raised if the provided bound param - is not defined in the tool's schema, or requires authentication. - If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current @@ -412,7 +162,5 @@ def bind_param( Raises: ValueError: If the provided bound param is already bound. - ValueError: if the provided bound param is not defined in the tool's - schema, or requires authentication, and strict is True. """ - return self.bind_params({param_name: param_value}, strict) + return self.bind_params({param_name: param_value}) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/client.py b/packages/toolbox-langchain/src/toolbox_langchain/client.py index 3c75779c..e253756f 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/client.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/client.py @@ -12,22 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -from threading import Thread -from typing import Any, Awaitable, Callable, Optional, TypeVar, Union +from asyncio import wrap_future +from typing import Any, Callable, Optional, Union +from warnings import warn -from aiohttp import ClientSession +from toolbox_core.sync_client import ToolboxSyncClient as ToolboxCoreSyncClient +from toolbox_core.sync_tool import ToolboxSyncTool -from .async_client import AsyncToolboxClient from .tools import ToolboxTool -T = TypeVar("T") - class ToolboxClient: - __session: Optional[ClientSession] = None - __loop: Optional[asyncio.AbstractEventLoop] = None - __thread: Optional[Thread] = None def __init__( self, @@ -39,51 +34,7 @@ def __init__( Args: url: The base URL of the Toolbox service. """ - - # Running a loop in a background thread allows us to support async - # methods from non-async environments. - if ToolboxClient.__loop is None: - loop = asyncio.new_event_loop() - thread = Thread(target=loop.run_forever, daemon=True) - thread.start() - ToolboxClient.__thread = thread - ToolboxClient.__loop = loop - - async def __start_session() -> None: - - # Use a default session if none is provided. This leverages connection - # pooling for better performance by reusing a single session throughout - # the application's lifetime. - if ToolboxClient.__session is None: - ToolboxClient.__session = ClientSession() - - coro = __start_session() - - asyncio.run_coroutine_threadsafe(coro, ToolboxClient.__loop).result() - - if not ToolboxClient.__session: - raise ValueError("Session cannot be None.") - self.__async_client = AsyncToolboxClient(url, ToolboxClient.__session) - - def __run_as_sync(self, coro: Awaitable[T]) -> T: - """Run an async coroutine synchronously""" - if not self.__loop: - raise Exception( - "Cannot call synchronous methods before the background loop is initialized." - ) - return asyncio.run_coroutine_threadsafe(coro, self.__loop).result() - - async def __run_as_async(self, coro: Awaitable[T]) -> T: - """Run an async coroutine asynchronously""" - - # If a loop has not been provided, attempt to run in current thread. - if not self.__loop: - return await coro - - # Otherwise, run in the background thread. - return await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__loop) - ) + self.__core_client = ToolboxCoreSyncClient(url=url) async def aload_tool( self, @@ -92,7 +43,6 @@ async def aload_tool( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, ) -> ToolboxTool: """ Loads the tool with the given tool name from the Toolbox service. @@ -105,27 +55,42 @@ async def aload_tool( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. Returns: A tool loaded from the Toolbox. """ - async_tool = await self.__run_as_async( - self.__async_client.aload_tool( - tool_name, - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - strict, - ) - ) - - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - return ToolboxTool(async_tool, self.__loop, self.__thread) + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + core_tool = await wrap_future(self.__core_client.load_tool_future( + name=tool_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + )) + return ToolboxTool(core_tool=core_tool) async def aload_toolset( self, @@ -134,7 +99,7 @@ async def aload_toolset( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + strict: bool = False, ) -> list[ToolboxTool]: """ Loads tools from the Toolbox service, optionally filtered by toolset @@ -149,30 +114,51 @@ async def aload_toolset( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. + strict: If True, raises an error if *any* loaded tool instance fails + to utilize at least one provided parameter or auth token (if any + provided). If False (default), raises an error only if a + user-provided parameter or auth token cannot be applied to *any* + loaded tool across the set. Returns: A list of all tools loaded from the Toolbox. """ - async_tools = await self.__run_as_async( - self.__async_client.aload_toolset( - toolset_name, - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - strict, - ) - ) - - tools: list[ToolboxTool] = [] - - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - for async_tool in async_tools: - tools.append(ToolboxTool(async_tool, self.__loop, self.__thread)) + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + core_tools = await wrap_future(self.__core_client.load_toolset_future( + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=strict, + )) + + tools = [] + for core_tool in core_tools: + tools.append(ToolboxTool(core_tool=core_tool)) return tools def load_tool( @@ -182,7 +168,6 @@ def load_tool( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, ) -> ToolboxTool: """ Loads the tool with the given tool name from the Toolbox service. @@ -195,27 +180,42 @@ def load_tool( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. Returns: A tool loaded from the Toolbox. """ - async_tool = self.__run_as_sync( - self.__async_client.aload_tool( - tool_name, - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - strict, - ) + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + core_sync_tool = self.__core_client.load_tool( + name=tool_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, ) - - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - return ToolboxTool(async_tool, self.__loop, self.__thread) + return ToolboxTool(core_tool=core_sync_tool) def load_toolset( self, @@ -224,7 +224,7 @@ def load_toolset( auth_tokens: Optional[dict[str, Callable[[], str]]] = None, auth_headers: Optional[dict[str, Callable[[], str]]] = None, bound_params: dict[str, Union[Any, Callable[[], Any]]] = {}, - strict: bool = True, + strict: bool = False, ) -> list[ToolboxTool]: """ Loads tools from the Toolbox service, optionally filtered by toolset @@ -239,27 +239,49 @@ def load_toolset( auth_headers: Deprecated. Use `auth_token_getters` instead. bound_params: An optional mapping of parameter names to their bound values. - strict: If True, raises a ValueError if any of the given bound - parameters are missing from the schema or require - authentication. If False, only issues a warning. + strict: If True, raises an error if *any* loaded tool instance fails + to utilize at least one provided parameter or auth token (if any + provided). If False (default), raises an error only if a + user-provided parameter or auth token cannot be applied to *any* + loaded tool across the set. Returns: A list of all tools loaded from the Toolbox. """ - async_tools = self.__run_as_sync( - self.__async_client.aload_toolset( - toolset_name, - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - strict, - ) + if auth_tokens: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_tokens + + if auth_headers: + if auth_token_getters: + warn( + "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used.", + DeprecationWarning, + ) + else: + warn( + "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.", + DeprecationWarning, + ) + auth_token_getters = auth_headers + + core_sync_tools = self.__core_client.load_toolset( + name=toolset_name, + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=strict, ) - if not self.__loop or not self.__thread: - raise ValueError("Background loop or thread cannot be None.") - tools: list[ToolboxTool] = [] - for async_tool in async_tools: - tools.append(ToolboxTool(async_tool, self.__loop, self.__thread)) + tools = [] + for core_sync_tool in core_sync_tools: + tools.append(ToolboxTool(core_tool=core_sync_tool)) return tools diff --git a/packages/toolbox-langchain/src/toolbox_langchain/tools.py b/packages/toolbox-langchain/src/toolbox_langchain/tools.py index feb2a597..6cca388a 100644 --- a/packages/toolbox-langchain/src/toolbox_langchain/tools.py +++ b/packages/toolbox-langchain/src/toolbox_langchain/tools.py @@ -12,16 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -from asyncio import AbstractEventLoop -from threading import Thread -from typing import Any, Awaitable, Callable, TypeVar, Union +from asyncio import wrap_future +from typing import Any, Callable, Union +from deprecated import deprecated from langchain_core.tools import BaseTool - -from .async_tools import AsyncToolboxTool - -T = TypeVar("T") +from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool class ToolboxTool(BaseTool): @@ -32,59 +28,32 @@ class ToolboxTool(BaseTool): def __init__( self, - async_tool: AsyncToolboxTool, - loop: AbstractEventLoop, - thread: Thread, + core_tool: ToolboxCoreSyncTool, ) -> None: """ Initializes a ToolboxTool instance. Args: - async_tool: The underlying AsyncToolboxTool instance. - loop: The event loop used to run asynchronous tasks. - thread: The thread to run blocking operations in. + core_tool: The underlying core sync ToolboxTool instance. """ # Due to how pydantic works, we must initialize the underlying # BaseTool class before assigning values to member variables. super().__init__( - name=async_tool.name, - description=async_tool.description, - args_schema=async_tool.args_schema, + name=core_tool.__name__, + description=core_tool.__doc__, + args_schema=core_tool._async_tool._pydantic_model, ) + self.__core_tool = core_tool - self.__async_tool = async_tool - self.__loop = loop - self.__thread = thread - - def __run_as_sync(self, coro: Awaitable[T]) -> T: - """Run an async coroutine synchronously""" - if not self.__loop: - raise Exception( - "Cannot call synchronous methods before the background loop is initialized." - ) - return asyncio.run_coroutine_threadsafe(coro, self.__loop).result() + def _run(self, **kwargs: Any) -> str: + return self.__core_tool(**kwargs) - async def __run_as_async(self, coro: Awaitable[T]) -> T: - """Run an async coroutine asynchronously""" - - # If a loop has not been provided, attempt to run in current thread. - if not self.__loop: - return await coro - - # Otherwise, run in the background thread. - return await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self.__loop) - ) - - def _run(self, **kwargs: Any) -> dict[str, Any]: - return self.__run_as_sync(self.__async_tool._arun(**kwargs)) - - async def _arun(self, **kwargs: Any) -> dict[str, Any]: - return await self.__run_as_async(self.__async_tool._arun(**kwargs)) + async def _arun(self, **kwargs: Any) -> str: + return await wrap_future(self.__core_tool.call_future(**kwargs)) def add_auth_token_getters( - self, auth_token_getters: dict[str, Callable[[], str]], strict: bool = True + self, auth_token_getters: dict[str, Callable[[], str]] ) -> "ToolboxTool": """ Registers functions to retrieve ID tokens for the corresponding @@ -93,27 +62,20 @@ def add_auth_token_getters( Args: auth_token_getters: A dictionary of authentication source names to the functions that return corresponding ID token. - strict: If True, a ValueError is raised if any of the provided auth - parameters is already bound. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current - instance, with added auth tokens. + instance, with added auth token getters. Raises: ValueError: If any of the provided auth parameters is already registered. - ValueError: If any of the provided auth parameters is already bound - and strict is True. """ - return ToolboxTool( - self.__async_tool.add_auth_token_getters(auth_token_getters, strict), - self.__loop, - self.__thread, - ) + new_core_tool = self.__core_tool.add_auth_token_getters(auth_token_getters) + return ToolboxTool(core_tool=new_core_tool) def add_auth_token_getter( - self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + self, auth_source: str, get_id_token: Callable[[], str] ) -> "ToolboxTool": """ Registers a function to retrieve an ID token for a given authentication @@ -122,28 +84,31 @@ def add_auth_token_getter( Args: auth_source: The name of the authentication source. get_id_token: A function that returns the ID token. - strict: If True, a ValueError is raised if the provided auth - parameter is already bound. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current - instance, with added auth token. + instance, with added auth token getter. Raises: ValueError: If the provided auth parameter is already registered. - ValueError: If the provided auth parameter is already bound and - strict is True. """ - return ToolboxTool( - self.__async_tool.add_auth_token_getter(auth_source, get_id_token, strict), - self.__loop, - self.__thread, - ) + return self.add_auth_token_getters({auth_source: get_id_token}) + + @deprecated("Please use `add_auth_token_getters` instead.") + def add_auth_tokens( + self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True + ) -> "ToolboxTool": + return self.add_auth_token_getters(auth_tokens) + + @deprecated("Please use `add_auth_token_getter` instead.") + def add_auth_token( + self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True + ) -> "ToolboxTool": + return self.add_auth_token_getter(auth_source, get_id_token) def bind_params( self, bound_params: dict[str, Union[Any, Callable[[], Any]]], - strict: bool = True, ) -> "ToolboxTool": """ Registers values or functions to retrieve the value for the @@ -152,9 +117,6 @@ def bind_params( Args: bound_params: A dictionary of the bound parameter name to the value or function of the bound value. - strict: If True, a ValueError is raised if any of the provided bound - params is not defined in the tool's schema, or requires - authentication. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current @@ -162,21 +124,14 @@ def bind_params( Raises: ValueError: If any of the provided bound params is already bound. - ValueError: if any of the provided bound params is not defined in - the tool's schema, or require authentication, and strict is - True. """ - return ToolboxTool( - self.__async_tool.bind_params(bound_params, strict), - self.__loop, - self.__thread, - ) + new_core_tool = self.__core_tool.bind_params(bound_params) + return ToolboxTool(core_tool=new_core_tool) def bind_param( self, param_name: str, param_value: Union[Any, Callable[[], Any]], - strict: bool = True, ) -> "ToolboxTool": """ Registers a value or a function to retrieve the value for a given bound @@ -186,9 +141,6 @@ def bind_param( param_name: The name of the bound parameter. param_value: The value of the bound parameter, or a callable that returns the value. - strict: If True, a ValueError is raised if the provided bound - param is not defined in the tool's schema, or requires - authentication. If False, only a warning is issued. Returns: A new ToolboxTool instance that is a deep copy of the current @@ -196,11 +148,5 @@ def bind_param( Raises: ValueError: If the provided bound param is already bound. - ValueError: if the provided bound param is not defined in the tool's - schema, or requires authentication, and strict is True. """ - return ToolboxTool( - self.__async_tool.bind_param(param_name, param_value, strict), - self.__loop, - self.__thread, - ) + return self.bind_params({param_name: param_value}) diff --git a/packages/toolbox-langchain/src/toolbox_langchain/utils.py b/packages/toolbox-langchain/src/toolbox_langchain/utils.py deleted file mode 100644 index 985c7bfe..00000000 --- a/packages/toolbox-langchain/src/toolbox_langchain/utils.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from typing import Any, Callable, Optional, Type, cast -from warnings import warn - -from aiohttp import ClientSession -from deprecated import deprecated -from langchain_core.tools import ToolException -from pydantic import BaseModel, Field, create_model - - -class ParameterSchema(BaseModel): - """ - Schema for a tool parameter. - """ - - name: str - type: str - description: str - authSources: Optional[list[str]] = None - items: Optional["ParameterSchema"] = None - - -class ToolSchema(BaseModel): - """ - Schema for a tool. - """ - - description: str - parameters: list[ParameterSchema] - authRequired: list[str] = [] - - -class ManifestSchema(BaseModel): - """ - Schema for the Toolbox manifest. - """ - - serverVersion: str - tools: dict[str, ToolSchema] - - -async def _load_manifest(url: str, session: ClientSession) -> ManifestSchema: - """ - Asynchronously fetches and parses the JSON manifest schema from the given - URL. - - Args: - url: The URL to fetch the JSON from. - session: The HTTP client session. - - Returns: - The parsed Toolbox manifest. - - Raises: - json.JSONDecodeError: If the response is not valid JSON. - ValueError: If the response is not a valid manifest. - """ - async with session.get(url) as response: - # TODO: Remove as it masks error messages. - response.raise_for_status() - try: - # TODO: Simply use response.json() - parsed_json = json.loads(await response.text()) - except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Failed to parse JSON from {url}: {e}", e.doc, e.pos - ) from e - try: - return ManifestSchema(**parsed_json) - except ValueError as e: - raise ValueError(f"Invalid JSON data from {url}: {e}") from e - - -def _schema_to_model(model_name: str, schema: list[ParameterSchema]) -> Type[BaseModel]: - """ - Converts the given manifest schema to a Pydantic BaseModel class. - - Args: - model_name: The name of the model to create. - schema: The schema to convert. - - Returns: - A Pydantic BaseModel class. - """ - field_definitions = {} - for field in schema: - field_definitions[field.name] = cast( - Any, - ( - _parse_type(field), - Field(description=field.description), - ), - ) - - return create_model(model_name, **field_definitions) - - -def _parse_type(schema_: ParameterSchema) -> Any: - """ - Converts a schema type to a JSON type. - - Args: - schema_: The ParameterSchema to convert. - - Returns: - A valid JSON type. - - Raises: - ValueError: If the given type is not supported. - """ - type_ = schema_.type - - if type_ == "string": - return str - elif type_ == "integer": - return int - elif type_ == "float": - return float - elif type_ == "boolean": - return bool - elif type_ == "array": - if isinstance(schema_, ParameterSchema) and schema_.items: - return list[_parse_type(schema_.items)] # type: ignore - else: - raise ValueError(f"Schema missing field items") - else: - raise ValueError(f"Unsupported schema type: {type_}") - - -@deprecated("Please use `_get_auth_tokens` instead.") -def _get_auth_headers(id_token_getters: dict[str, Callable[[], str]]) -> dict[str, str]: - """ - Deprecated. Use `_get_auth_tokens` instead. - """ - return _get_auth_tokens(id_token_getters) - - -def _get_auth_tokens(id_token_getters: dict[str, Callable[[], str]]) -> dict[str, str]: - """ - Gets ID tokens for the given auth sources in the getters map and returns - tokens to be included in tool invocation. - - Args: - id_token_getters: A dict that maps auth source names to the functions - that return its ID token. - - Returns: - A dictionary of tokens to be included in the tool invocation. - """ - auth_tokens = {} - for auth_source, get_id_token in id_token_getters.items(): - auth_tokens[f"{auth_source}_token"] = get_id_token() - return auth_tokens - - -async def _invoke_tool( - url: str, - session: ClientSession, - tool_name: str, - data: dict, - id_token_getters: dict[str, Callable[[], str]], -) -> dict: - """ - Asynchronously makes an API call to the Toolbox service to invoke a tool. - - Args: - url: The base URL of the Toolbox service. - session: The HTTP client session. - tool_name: The name of the tool to invoke. - data: The input data for the tool. - id_token_getters: A dict that maps auth source names to the functions - that return its ID token. - - Returns: - A dictionary containing the parsed JSON response from the tool - invocation. - - Raises: - ToolException: If the Toolbox service returns an error. - """ - url = f"{url}/api/tool/{tool_name}/invoke" - auth_tokens = _get_auth_tokens(id_token_getters) - - # ID tokens contain sensitive user information (claims). Transmitting these - # over HTTP exposes the data to interception and unauthorized access. Always - # use HTTPS to ensure secure communication and protect user privacy. - if auth_tokens and not url.startswith("https://"): - warn( - "Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication." - ) - - async with session.post( - url, - json=data, - headers=auth_tokens, - ) as response: - ret = await response.json() - if "error" in ret: - raise ToolException(ret) - return ret.get("result", ret) - - -def _find_auth_params( - params: list[ParameterSchema], -) -> tuple[list[ParameterSchema], list[ParameterSchema]]: - """ - Separates parameters into those that are authenticated and those that are not. - - Args: - params: A list of ParameterSchema objects. - - Returns: - A tuple containing two lists: - - auth_params: A list of ParameterSchema objects that require authentication. - - non_auth_params: A list of ParameterSchema objects that do not require authentication. - """ - _auth_params: list[ParameterSchema] = [] - _non_auth_params: list[ParameterSchema] = [] - - for param in params: - if param.authSources: - _auth_params.append(param) - else: - _non_auth_params.append(param) - - return (_auth_params, _non_auth_params) - - -def _find_bound_params( - params: list[ParameterSchema], bound_params: list[str] -) -> tuple[list[ParameterSchema], list[ParameterSchema]]: - """ - Separates parameters into those that are bound and those that are not. - - Args: - params: A list of ParameterSchema objects. - bound_params: A list of parameter names that are bound. - - Returns: - A tuple containing two lists: - - bound_params: A list of ParameterSchema objects whose names are in the bound_params list. - - non_bound_params: A list of ParameterSchema objects whose names are not in the bound_params list. - """ - - _bound_params: list[ParameterSchema] = [] - _non_bound_params: list[ParameterSchema] = [] - - for param in params: - if param.name in bound_params: - _bound_params.append(param) - else: - _non_bound_params.append(param) - - return (_bound_params, _non_bound_params) diff --git a/packages/toolbox-langchain/tests/test_async_client.py b/packages/toolbox-langchain/tests/test_async_client.py index 25ad78eb..988d3974 100644 --- a/packages/toolbox-langchain/tests/test_async_client.py +++ b/packages/toolbox-langchain/tests/test_async_client.py @@ -17,10 +17,14 @@ import pytest from aiohttp import ClientSession +from toolbox_core.client import ToolboxClient as ToolboxCoreClient +from toolbox_core.protocol import ManifestSchema +from toolbox_core.protocol import ParameterSchema as CoreParameterSchema +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool +from toolbox_core.utils import params_to_pydantic_model from toolbox_langchain.async_client import AsyncToolboxClient from toolbox_langchain.async_tools import AsyncToolboxTool -from toolbox_langchain.utils import ManifestSchema URL = "http://test_url" MANIFEST_JSON = { @@ -60,123 +64,200 @@ def manifest_schema(self): def mock_session(self): return AsyncMock(spec=ClientSession) + @pytest.fixture + def mock_core_client_instance(self, manifest_schema, mock_session): + mock = AsyncMock(spec=ToolboxCoreClient) + + async def mock_load_tool_impl(name, auth_token_getters, bound_params): + tool_schema_dict = MANIFEST_JSON["tools"].get(name) + if not tool_schema_dict: + raise ValueError(f"Tool '{name}' not in mock manifest_dict") + + core_params = [ + CoreParameterSchema(**p) for p in tool_schema_dict["parameters"] + ] + # Return a mock that looks like toolbox_core.tool.ToolboxTool + core_tool_mock = AsyncMock(spec=ToolboxCoreTool) + core_tool_mock.__name__ = name + core_tool_mock.__doc__ = tool_schema_dict["description"] + core_tool_mock._pydantic_model = params_to_pydantic_model(name, core_params) + # Add other necessary attributes or method mocks if AsyncToolboxTool uses them + return core_tool_mock + + mock.load_tool = AsyncMock(side_effect=mock_load_tool_impl) + + async def mock_load_toolset_impl( + name, auth_token_getters, bound_params, strict + ): + core_tools_list = [] + for tool_name_iter, tool_schema_dict in MANIFEST_JSON["tools"].items(): + core_params = [ + CoreParameterSchema(**p) for p in tool_schema_dict["parameters"] + ] + core_tool_mock = AsyncMock(spec=ToolboxCoreTool) + core_tool_mock.__name__ = tool_name_iter + core_tool_mock.__doc__ = tool_schema_dict["description"] + core_tool_mock._pydantic_model = params_to_pydantic_model( + tool_name_iter, core_params + ) + core_tools_list.append(core_tool_mock) + return core_tools_list + + mock.load_toolset = AsyncMock(side_effect=mock_load_toolset_impl) + # Mock the session attribute if it's directly accessed by AsyncToolboxClient tests + mock._ToolboxClient__session = mock_session + return mock + @pytest.fixture() - def mock_client(self, mock_session): - return AsyncToolboxClient(URL, session=mock_session) + def mock_client(self, mock_session, mock_core_client_instance): + # Patch the ToolboxCoreClient constructor used by AsyncToolboxClient + with patch( + "toolbox_langchain.async_client.ToolboxCoreClient", + return_value=mock_core_client_instance, + ): + client = AsyncToolboxClient(URL, session=mock_session) + # Ensure the mocked core client is used + client._AsyncToolboxClient__core_client = mock_core_client_instance + return client async def test_create_with_existing_session(self, mock_client, mock_session): - assert mock_client._AsyncToolboxClient__session == mock_session + # AsyncToolboxClient stores the core_client, which stores the session + assert ( + mock_client._AsyncToolboxClient__core_client._ToolboxClient__session + == mock_session + ) - @patch("toolbox_langchain.async_client._load_manifest") async def test_aload_tool( - self, mock_load_manifest, mock_client, mock_session, manifest_schema + self, + mock_client, + manifest_schema, # mock_session removed as it's part of mock_core_client_instance ): tool_name = "test_tool_1" - mock_load_manifest.return_value = manifest_schema + # manifest_schema is used by mock_core_client_instance fixture to provide tool details tool = await mock_client.aload_tool(tool_name) - mock_load_manifest.assert_called_once_with( - f"{URL}/api/tool/{tool_name}", mock_session + # Assert that the core client's load_tool was called correctly + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, auth_token_getters={}, bound_params={} ) assert isinstance(tool, AsyncToolboxTool) - assert tool.name == tool_name + assert ( + tool.name == tool_name + ) # AsyncToolboxTool gets its name from the core_tool - @patch("toolbox_langchain.async_client._load_manifest") async def test_aload_tool_auth_headers_deprecated( - self, mock_load_manifest, mock_client, manifest_schema + self, mock_client, manifest_schema ): tool_name = "test_tool_1" - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + auth_lambda = lambda: "Bearer token" # Define lambda once with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_tool( - tool_name, auth_headers={"Authorization": lambda: "Bearer token"} + tool_name, + auth_headers={"Authorization": auth_lambda}, # Use the defined lambda ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) assert "auth_headers" in str(w[-1].message) - @patch("toolbox_langchain.async_client._load_manifest") + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, + auth_token_getters={"Authorization": auth_lambda}, + bound_params={}, + ) + async def test_aload_tool_auth_headers_and_tokens( - self, mock_load_manifest, mock_client, manifest_schema + self, mock_client, manifest_schema ): tool_name = "test_tool_1" - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + auth_getters = {"test": lambda: "token"} + auth_headers_lambda = lambda: "Bearer token" # Define lambda once + with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_tool( tool_name, - auth_headers={"Authorization": lambda: "Bearer token"}, - auth_token_getters={"test": lambda: "token"}, + auth_headers={ + "Authorization": auth_headers_lambda + }, # Use defined lambda + auth_token_getters=auth_getters, ) - assert len(w) == 1 + assert ( + len(w) == 1 + ) # Only one warning because auth_token_getters takes precedence assert issubclass(w[-1].category, DeprecationWarning) - assert "auth_headers" in str(w[-1].message) + assert "auth_headers" in str(w[-1].message) # Warning for auth_headers + + mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with( + name=tool_name, auth_token_getters=auth_getters, bound_params={} + ) - @patch("toolbox_langchain.async_client._load_manifest") async def test_aload_toolset( - self, mock_load_manifest, mock_client, mock_session, manifest_schema + self, mock_client, manifest_schema # mock_session removed ): - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest tools = await mock_client.aload_toolset() - mock_load_manifest.assert_called_once_with(f"{URL}/api/toolset/", mock_session) - assert len(tools) == 2 + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, auth_token_getters={}, bound_params={}, strict=False + ) + assert len(tools) == 2 # Based on MANIFEST_JSON for tool in tools: assert isinstance(tool, AsyncToolboxTool) assert tool.name in ["test_tool_1", "test_tool_2"] - @patch("toolbox_langchain.async_client._load_manifest") async def test_aload_toolset_with_toolset_name( - self, mock_load_manifest, mock_client, mock_session, manifest_schema + self, mock_client, manifest_schema # mock_session removed ): - toolset_name = "test_toolset_1" - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + toolset_name = "test_toolset_1" # This name isn't in MANIFEST_JSON, but load_toolset mock doesn't filter by it tools = await mock_client.aload_toolset(toolset_name=toolset_name) - mock_load_manifest.assert_called_once_with( - f"{URL}/api/toolset/{toolset_name}", mock_session + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=toolset_name, auth_token_getters={}, bound_params={}, strict=False ) assert len(tools) == 2 for tool in tools: assert isinstance(tool, AsyncToolboxTool) assert tool.name in ["test_tool_1", "test_tool_2"] - @patch("toolbox_langchain.async_client._load_manifest") async def test_aload_toolset_auth_headers_deprecated( - self, mock_load_manifest, mock_client, manifest_schema + self, mock_client, manifest_schema ): - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + auth_lambda = lambda: "Bearer token" # Define lambda once with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_toolset( - auth_headers={"Authorization": lambda: "Bearer token"} + auth_headers={"Authorization": auth_lambda} # Use defined lambda ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) assert "auth_headers" in str(w[-1].message) + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, + auth_token_getters={"Authorization": auth_lambda}, + bound_params={}, + strict=False, + ) - @patch("toolbox_langchain.async_client._load_manifest") async def test_aload_toolset_auth_headers_and_tokens( - self, mock_load_manifest, mock_client, manifest_schema + self, mock_client, manifest_schema ): - mock_manifest = manifest_schema - mock_load_manifest.return_value = mock_manifest + auth_getters = {"test": lambda: "token"} + auth_headers_lambda = lambda: "Bearer token" # Define lambda once with catch_warnings(record=True) as w: simplefilter("always") await mock_client.aload_toolset( - auth_headers={"Authorization": lambda: "Bearer token"}, - auth_token_getters={"test": lambda: "token"}, + auth_headers={ + "Authorization": auth_headers_lambda + }, # Use defined lambda + auth_token_getters=auth_getters, ) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning) assert "auth_headers" in str(w[-1].message) + mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with( + name=None, auth_token_getters=auth_getters, bound_params={}, strict=False + ) async def test_load_tool_not_implemented(self, mock_client): with pytest.raises(NotImplementedError) as excinfo: diff --git a/packages/toolbox-langchain/tests/test_async_tools.py b/packages/toolbox-langchain/tests/test_async_tools.py index e23aee85..96bd7660 100644 --- a/packages/toolbox-langchain/tests/test_async_tools.py +++ b/packages/toolbox-langchain/tests/test_async_tools.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import types # For MappingProxyType from unittest.mock import AsyncMock, Mock, patch import pytest import pytest_asyncio from pydantic import ValidationError +from toolbox_core.protocol import ParameterSchema as CoreParameterSchema +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool from toolbox_langchain.async_tools import AsyncToolboxTool @@ -24,7 +27,7 @@ @pytest.mark.asyncio class TestAsyncToolboxTool: @pytest.fixture - def tool_schema(self): + def tool_schema_dict(self): return { "description": "Test Tool Description", "parameters": [ @@ -34,9 +37,10 @@ def tool_schema(self): } @pytest.fixture - def auth_tool_schema(self): + def auth_tool_schema_dict(self): return { "description": "Test Tool Description", + "authRequired": ["test-auth-source"], "parameters": [ { "name": "param1", @@ -48,133 +52,193 @@ def auth_tool_schema(self): ], } + def _create_core_tool_from_dict( + self, session, name, schema_dict, url, initial_auth_getters=None + ): + core_params_schemas = [ + CoreParameterSchema(**p) for p in schema_dict["parameters"] + ] + + tool_constructor_params = [] + required_authn_for_core = {} + for p_schema in core_params_schemas: + if p_schema.authSources: + required_authn_for_core[p_schema.name] = p_schema.authSources + else: + tool_constructor_params.append(p_schema) + + return ToolboxCoreTool( + session=session, + base_url=url, + name=name, + description=schema_dict["description"], + params=tool_constructor_params, + required_authn_params=types.MappingProxyType(required_authn_for_core), + required_authz_tokens=schema_dict.get("authRequired", []), + auth_service_token_getters=types.MappingProxyType( + initial_auth_getters or {} + ), + bound_params=types.MappingProxyType({}), + client_headers=types.MappingProxyType({}), + ) + @pytest_asyncio.fixture @patch("aiohttp.ClientSession") - async def toolbox_tool(self, MockClientSession, tool_schema): + async def toolbox_tool(self, MockClientSession, tool_schema_dict): mock_session = MockClientSession.return_value - mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() - mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( - return_value={"result": "test-result"} - ) - tool = AsyncToolboxTool( + mock_response = mock_session.post.return_value.__aenter__.return_value + mock_response.raise_for_status = Mock() + mock_response.json = AsyncMock(return_value={"result": "test-result"}) + mock_response.status = 200 # *** Fix: Set status for the mock response *** + + core_tool_instance = self._create_core_tool_from_dict( + session=mock_session, name="test_tool", - schema=tool_schema, + schema_dict=tool_schema_dict, url="http://test_url", - session=mock_session, ) + tool = AsyncToolboxTool(core_tool=core_tool_instance) return tool @pytest_asyncio.fixture @patch("aiohttp.ClientSession") - async def auth_toolbox_tool(self, MockClientSession, auth_tool_schema): + async def auth_toolbox_tool(self, MockClientSession, auth_tool_schema_dict): mock_session = MockClientSession.return_value - mock_session.post.return_value.__aenter__.return_value.raise_for_status = Mock() - mock_session.post.return_value.__aenter__.return_value.json = AsyncMock( - return_value={"result": "test-result"} + mock_response = mock_session.post.return_value.__aenter__.return_value + mock_response.raise_for_status = Mock() + mock_response.json = AsyncMock(return_value={"result": "test-result"}) + mock_response.status = 200 # *** Fix: Set status for the mock response *** + + core_tool_instance = self._create_core_tool_from_dict( + session=mock_session, + name="test_tool", + schema_dict=auth_tool_schema_dict, + url="https://test-url", ) - with pytest.warns( - UserWarning, - match=r"Parameter\(s\) `param1` of tool test_tool require authentication", - ): - tool = AsyncToolboxTool( - name="test_tool", - schema=auth_tool_schema, - url="https://test-url", - session=mock_session, - ) + tool = AsyncToolboxTool(core_tool=core_tool_instance) return tool @patch("aiohttp.ClientSession") - async def test_toolbox_tool_init(self, MockClientSession, tool_schema): + async def test_toolbox_tool_init(self, MockClientSession, tool_schema_dict): mock_session = MockClientSession.return_value - tool = AsyncToolboxTool( + mock_response = mock_session.post.return_value.__aenter__.return_value + mock_response.status = 200 + core_tool_instance = self._create_core_tool_from_dict( + session=mock_session, name="test_tool", - schema=tool_schema, + schema_dict=tool_schema_dict, url="https://test-url", - session=mock_session, ) + tool = AsyncToolboxTool(core_tool=core_tool_instance) assert tool.name == "test_tool" - assert tool.description == "Test Tool Description" + assert tool.description == core_tool_instance.__doc__ @pytest.mark.parametrize( - "params, expected_bound_params", + "params_to_bind", [ - ({"param1": "bound-value"}, {"param1": "bound-value"}), - ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), - ( - {"param1": "bound-value", "param2": 123}, - {"param1": "bound-value", "param2": 123}, - ), + ({"param1": "bound-value"}), + ({"param1": lambda: "bound-value"}), + ({"param1": "bound-value", "param2": 123}), ], ) - async def test_toolbox_tool_bind_params( - self, toolbox_tool, params, expected_bound_params - ): - tool = toolbox_tool.bind_params(params) - for key, value in expected_bound_params.items(): - if callable(value): - assert value() == tool._AsyncToolboxTool__bound_params[key]() - else: - assert value == tool._AsyncToolboxTool__bound_params[key] - - @pytest.mark.parametrize("strict", [True, False]) - async def test_toolbox_tool_bind_params_invalid(self, toolbox_tool, strict): - if strict: - with pytest.raises(ValueError) as e: - tool = toolbox_tool.bind_params( - {"param3": "bound-value"}, strict=strict - ) - assert "Parameter(s) param3 missing and cannot be bound." in str(e.value) - else: - with pytest.warns(UserWarning) as record: - tool = toolbox_tool.bind_params( - {"param3": "bound-value"}, strict=strict - ) - assert len(record) == 1 - assert "Parameter(s) param3 missing and cannot be bound." in str( - record[0].message + async def test_toolbox_tool_bind_params(self, toolbox_tool, params_to_bind): + original_core_tool = toolbox_tool._AsyncToolboxTool__core_tool + with patch.object( + original_core_tool, "bind_params", wraps=original_core_tool.bind_params + ) as mock_core_bind_params: + new_langchain_tool = toolbox_tool.bind_params(params_to_bind) + mock_core_bind_params.assert_called_once_with(params_to_bind) + assert isinstance( + new_langchain_tool._AsyncToolboxTool__core_tool, ToolboxCoreTool + ) + new_core_tool_signature_params = ( + new_langchain_tool._AsyncToolboxTool__core_tool.__signature__.parameters ) + for bound_param_name in params_to_bind.keys(): + assert bound_param_name not in new_core_tool_signature_params + + async def test_toolbox_tool_bind_params_invalid(self, toolbox_tool): + with pytest.raises( + ValueError, match="unable to bind parameters: no parameter named param3" + ): + toolbox_tool.bind_params({"param3": "bound-value"}) async def test_toolbox_tool_bind_params_duplicate(self, toolbox_tool): tool = toolbox_tool.bind_params({"param1": "bound-value"}) - with pytest.raises(ValueError) as e: - tool = tool.bind_params({"param1": "bound-value"}) - assert "Parameter(s) `param1` already bound in tool `test_tool`." in str( - e.value - ) + with pytest.raises( + ValueError, + match="cannot re-bind parameter: parameter 'param1' is already bound", + ): + tool.bind_params({"param1": "bound-value"}) async def test_toolbox_tool_bind_params_invalid_params(self, auth_toolbox_tool): - with pytest.raises(ValueError) as e: + auth_core_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool + # Verify that 'param1' is not in the list of bindable parameters for the core tool + # because it requires authentication. + assert "param1" not in [p.name for p in auth_core_tool._ToolboxTool__params] + with pytest.raises( + ValueError, match="unable to bind parameters: no parameter named param1" + ): auth_toolbox_tool.bind_params({"param1": "bound-value"}) - assert "Parameter(s) param1 already authenticated and cannot be bound." in str( - e.value + + async def test_toolbox_tool_add_valid_auth_token_getter(self, auth_toolbox_tool): + get_token_lambda = lambda: "test-token-value" + original_core_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool + with patch.object( + original_core_tool, + "add_auth_token_getters", + wraps=original_core_tool.add_auth_token_getters, + ) as mock_core_add_getters: + tool = auth_toolbox_tool.add_auth_token_getters( + {"test-auth-source": get_token_lambda} + ) + mock_core_add_getters.assert_called_once_with( + {"test-auth-source": get_token_lambda} + ) + core_tool_after_add = tool._AsyncToolboxTool__core_tool + assert ( + "test-auth-source" + in core_tool_after_add._ToolboxTool__auth_service_token_getters + ) + assert ( + core_tool_after_add._ToolboxTool__auth_service_token_getters[ + "test-auth-source" + ] + is get_token_lambda + ) + assert not core_tool_after_add._ToolboxTool__required_authn_params.get( + "param1" + ) + assert ( + "test-auth-source" + not in core_tool_after_add._ToolboxTool__required_authz_tokens + ) + + async def test_toolbox_tool_add_unused_auth_token_getter_raises_error( + self, auth_toolbox_tool + ): + unused_lambda = lambda: "another-token" + with pytest.raises(ValueError) as excinfo: + auth_toolbox_tool.add_auth_token_getters( + {"another-auth-source": unused_lambda} + ) + assert ( + "Authentication source(s) `another-auth-source` unused by tool `test_tool`" + in str(excinfo.value) ) - @pytest.mark.parametrize( - "auth_token_getters, expected_auth_token_getters", - [ - ( - {"test-auth-source": lambda: "test-token"}, - {"test-auth-source": lambda: "test-token"}, - ), - ( + valid_lambda = lambda: "test-token" + with pytest.raises(ValueError) as excinfo_mixed: + auth_toolbox_tool.add_auth_token_getters( { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, - { - "test-auth-source": lambda: "test-token", - "another-auth-source": lambda: "another-token", - }, - ), - ], - ) - async def test_toolbox_tool_add_auth_token_getters( - self, auth_toolbox_tool, auth_token_getters, expected_auth_token_getters - ): - tool = auth_toolbox_tool.add_auth_token_getters(auth_token_getters) - for source, getter in expected_auth_token_getters.items(): - assert tool._AsyncToolboxTool__auth_token_getters[source]() == getter() + "test-auth-source": valid_lambda, + "another-auth-source": unused_lambda, + } + ) + assert ( + "Authentication source(s) `another-auth-source` unused by tool `test_tool`" + in str(excinfo_mixed.value) + ) async def test_toolbox_tool_add_auth_token_getters_duplicate( self, auth_toolbox_tool @@ -182,45 +246,44 @@ async def test_toolbox_tool_add_auth_token_getters_duplicate( tool = auth_toolbox_tool.add_auth_token_getters( {"test-auth-source": lambda: "test-token"} ) - with pytest.raises(ValueError) as e: - tool = tool.add_auth_token_getters( - {"test-auth-source": lambda: "test-token"} - ) - assert ( - "Authentication source(s) `test-auth-source` already registered in tool `test_tool`." - in str(e.value) - ) + with pytest.raises( + ValueError, + match="Authentication source\\(s\\) `test-auth-source` already registered in tool `test_tool`\\.", + ): + tool.add_auth_token_getters({"test-auth-source": lambda: "test-token"}) - async def test_toolbox_tool_validate_auth_strict(self, auth_toolbox_tool): - with pytest.raises(PermissionError) as e: - auth_toolbox_tool._AsyncToolboxTool__validate_auth(strict=True) - assert "Parameter(s) `param1` of tool test_tool require authentication" in str( - e.value - ) + async def test_toolbox_tool_call_requires_auth_strict(self, auth_toolbox_tool): + with pytest.raises( + PermissionError, + match="One or more of the following authn services are required to invoke this tool: test-auth-source", + ): + await auth_toolbox_tool.ainvoke({"param2": 123}) async def test_toolbox_tool_call(self, toolbox_tool): result = await toolbox_tool.ainvoke({"param1": "test-value", "param2": 123}) assert result == "test-result" - toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + core_tool = toolbox_tool._AsyncToolboxTool__core_tool + core_tool._ToolboxTool__session.post.assert_called_once_with( "http://test_url/api/tool/test_tool/invoke", json={"param1": "test-value", "param2": 123}, headers={}, ) @pytest.mark.parametrize( - "bound_param, expected_value", + "bound_param_map, expected_value", [ ({"param1": "bound-value"}, "bound-value"), ({"param1": lambda: "dynamic-value"}, "dynamic-value"), ], ) async def test_toolbox_tool_call_with_bound_params( - self, toolbox_tool, bound_param, expected_value + self, toolbox_tool, bound_param_map, expected_value ): - tool = toolbox_tool.bind_params(bound_param) + tool = toolbox_tool.bind_params(bound_param_map) result = await tool.ainvoke({"param2": 123}) assert result == "test-result" - toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + core_tool = tool._AsyncToolboxTool__core_tool + core_tool._ToolboxTool__session.post.assert_called_once_with( "http://test_url/api/tool/test_tool/invoke", json={"param1": expected_value, "param2": 123}, headers={}, @@ -232,29 +295,53 @@ async def test_toolbox_tool_call_with_auth_tokens(self, auth_toolbox_tool): ) result = await tool.ainvoke({"param2": 123}) assert result == "test-result" - auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( + core_tool = tool._AsyncToolboxTool__core_tool + core_tool._ToolboxTool__session.post.assert_called_once_with( "https://test-url/api/tool/test_tool/invoke", json={"param2": 123}, headers={"test-auth-source_token": "test-token"}, ) - async def test_toolbox_tool_call_with_auth_tokens_insecure(self, auth_toolbox_tool): + async def test_toolbox_tool_call_with_auth_tokens_insecure( + self, auth_toolbox_tool, auth_tool_schema_dict + ): + core_tool_of_auth_tool = auth_toolbox_tool._AsyncToolboxTool__core_tool + mock_session = core_tool_of_auth_tool._ToolboxTool__session + with pytest.warns( UserWarning, match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", ): - auth_toolbox_tool._AsyncToolboxTool__url = "http://test-url" - tool = auth_toolbox_tool.add_auth_token_getters( - {"test-auth-source": lambda: "test-token"} - ) - result = await tool.ainvoke({"param2": 123}) - assert result == "test-result" - auth_toolbox_tool._AsyncToolboxTool__session.post.assert_called_once_with( - "http://test-url/api/tool/test_tool/invoke", - json={"param2": 123}, - headers={"test-auth-source_token": "test-token"}, + insecure_core_tool = self._create_core_tool_from_dict( + session=mock_session, + name="test_tool", + schema_dict=auth_tool_schema_dict, + url="http://test-url", ) + insecure_auth_langchain_tool = AsyncToolboxTool(core_tool=insecure_core_tool) + + tool_with_getter = insecure_auth_langchain_tool.add_auth_token_getters( + {"test-auth-source": lambda: "test-token"} + ) + result = await tool_with_getter.ainvoke({"param2": 123}) + assert result == "test-result" + + modified_core_tool_in_new_tool = tool_with_getter._AsyncToolboxTool__core_tool + assert ( + modified_core_tool_in_new_tool._ToolboxTool__base_url == "http://test-url" + ) + assert ( + modified_core_tool_in_new_tool._ToolboxTool__url + == "http://test-url/api/tool/test_tool/invoke" + ) + + modified_core_tool_in_new_tool._ToolboxTool__session.post.assert_called_once_with( + "http://test-url/api/tool/test_tool/invoke", + json={"param2": 123}, + headers={"test-auth-source_token": "test-token"}, + ) + async def test_toolbox_tool_call_with_invalid_input(self, toolbox_tool): with pytest.raises(ValidationError) as e: await toolbox_tool.ainvoke({"param1": 123, "param2": "invalid"}) diff --git a/packages/toolbox-langchain/tests/test_client.py b/packages/toolbox-langchain/tests/test_client.py index 62999019..d7eb62a8 100644 --- a/packages/toolbox-langchain/tests/test_client.py +++ b/packages/toolbox-langchain/tests/test_client.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest from pydantic import BaseModel +from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool # For spec +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool # For spec from toolbox_langchain.client import ToolboxClient from toolbox_langchain.tools import ToolboxTool @@ -28,232 +30,292 @@ class TestToolboxClient: def toolbox_client(self): client = ToolboxClient(URL) assert isinstance(client, ToolboxClient) - assert client._ToolboxClient__async_client is not None + assert client._ToolboxClient__core_client is not None + assert client._ToolboxClient__core_client._async_client is not None + assert client._ToolboxClient__core_client._loop is not None + assert client._ToolboxClient__core_client._loop.is_running() + assert client._ToolboxClient__core_client._thread is not None + assert client._ToolboxClient__core_client._thread.is_alive() + return client - # Check that the background loop was created and started - assert client._ToolboxClient__loop is not None - assert client._ToolboxClient__loop.is_running() + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") + def test_load_tool(self, mock_core_load_tool, toolbox_client): + mock_core_tool_instance = Mock( + spec=ToolboxCoreSyncTool + ) # Spec with Core Sync Tool + mock_core_tool_instance.__name__ = "mock-core-sync-tool" + mock_core_tool_instance.__doc__ = "mock core sync description" - return client + mock_underlying_async_tool = Mock( + spec=ToolboxCoreTool + ) # Core Async Tool for pydantic model + mock_underlying_async_tool._pydantic_model = BaseModel + mock_core_tool_instance._async_tool = mock_underlying_async_tool + + mock_core_load_tool.return_value = mock_core_tool_instance + + langchain_tool = toolbox_client.load_tool("test_tool") + + assert isinstance(langchain_tool, ToolboxTool) + assert langchain_tool.name == mock_core_tool_instance.__name__ + assert langchain_tool.description == mock_core_tool_instance.__doc__ + assert langchain_tool.args_schema == mock_underlying_async_tool._pydantic_model + + mock_core_load_tool.assert_called_once_with( + name="test_tool", auth_token_getters={}, bound_params={} + ) + + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") + def test_load_toolset(self, mock_core_load_toolset, toolbox_client): + mock_core_tool_instance1 = Mock(spec=ToolboxCoreSyncTool) + mock_core_tool_instance1.__name__ = "mock-core-sync-tool-0" + mock_core_tool_instance1.__doc__ = "desc 0" + mock_async_tool0 = Mock(spec=ToolboxCoreTool) + mock_async_tool0._pydantic_model = BaseModel + mock_core_tool_instance1._async_tool = mock_async_tool0 + + mock_core_tool_instance2 = Mock(spec=ToolboxCoreSyncTool) + mock_core_tool_instance2.__name__ = "mock-core-sync-tool-1" + mock_core_tool_instance2.__doc__ = "desc 1" + mock_async_tool1 = Mock(spec=ToolboxCoreTool) + mock_async_tool1._pydantic_model = BaseModel + mock_core_tool_instance2._async_tool = mock_async_tool1 + + mock_core_load_toolset.return_value = [ + mock_core_tool_instance1, + mock_core_tool_instance2, + ] + + langchain_tools = toolbox_client.load_toolset() + assert len(langchain_tools) == 2 + assert isinstance(langchain_tools[0], ToolboxTool) + assert isinstance(langchain_tools[1], ToolboxTool) + assert langchain_tools[0].name == "mock-core-sync-tool-0" + assert langchain_tools[1].name == "mock-core-sync-tool-1" - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") - def test_load_tool(self, mock_aload_tool, toolbox_client): - mock_tool = Mock(spec=ToolboxTool) - mock_tool.name = "mock-tool" - mock_tool.description = "mock description" - mock_tool.args_schema = BaseModel - mock_aload_tool.return_value = mock_tool - - tool = toolbox_client.load_tool("test_tool") - - assert tool.name == mock_tool.name - assert tool.description == mock_tool.description - assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with("test_tool", {}, None, None, {}, True) - - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") - def test_load_toolset(self, mock_aload_toolset, toolbox_client): - mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] - mock_tools[0].name = "mock-tool-0" - mock_tools[0].description = "mock description 0" - mock_tools[0].args_schema = BaseModel - mock_tools[1].name = "mock-tool-1" - mock_tools[1].description = "mock description 1" - mock_tools[1].args_schema = BaseModel - mock_aload_toolset.return_value = mock_tools - - tools = toolbox_client.load_toolset() - assert len(tools) == len(mock_tools) - assert all( - a.name == b.name - and a.description == b.description - and a.args_schema == b.args_schema - for a, b in zip(tools, mock_tools) + mock_core_load_toolset.assert_called_once_with( + name=None, auth_token_getters={}, bound_params={}, strict=False ) - mock_aload_toolset.assert_called_once_with(None, {}, None, None, {}, True) @pytest.mark.asyncio - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") - async def test_aload_tool(self, mock_aload_tool, toolbox_client): - mock_tool = Mock(spec=ToolboxTool) - mock_tool.name = "mock-tool" - mock_tool.description = "mock description" - mock_tool.args_schema = BaseModel - mock_aload_tool.return_value = mock_tool - - tool = await toolbox_client.aload_tool("test_tool") - assert tool.name == mock_tool.name - assert tool.description == mock_tool.description - assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with("test_tool", {}, None, None, {}, True) + @patch("toolbox_core.client.ToolboxClient.load_tool") + async def test_aload_tool(self, mock_core_aload_tool, toolbox_client): + mock_core_tool_instance = AsyncMock( + spec=ToolboxCoreTool + ) # *** Use AsyncMock for async method return *** + mock_core_tool_instance.__name__ = "mock-core-async-tool" + mock_core_tool_instance.__doc__ = "mock core async description" + mock_core_tool_instance._pydantic_model = BaseModel + mock_core_aload_tool.return_value = mock_core_tool_instance + + langchain_tool = await toolbox_client.aload_tool("test_tool") + + assert isinstance(langchain_tool, ToolboxTool) + assert langchain_tool.name == mock_core_tool_instance.__name__ + assert langchain_tool.description == mock_core_tool_instance.__doc__ + + toolbox_client._ToolboxClient__core_client._async_client.load_tool.assert_called_once_with( + name="test_tool", auth_token_getters={}, bound_params={} + ) @pytest.mark.asyncio - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") - async def test_aload_toolset(self, mock_aload_toolset, toolbox_client): - mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] - mock_tools[0].name = "mock-tool-0" - mock_tools[0].description = "mock description 0" - mock_tools[0].args_schema = BaseModel - mock_tools[1].name = "mock-tool-1" - mock_tools[1].description = "mock description 1" - mock_tools[1].args_schema = BaseModel - mock_aload_toolset.return_value = mock_tools - - tools = await toolbox_client.aload_toolset() - assert len(tools) == len(mock_tools) - assert all( - a.name == b.name - and a.description == b.description - and a.args_schema == b.args_schema - for a, b in zip(tools, mock_tools) + @patch("toolbox_core.client.ToolboxClient.load_toolset") + async def test_aload_toolset(self, mock_core_aload_toolset, toolbox_client): + mock_core_tool_instance1 = AsyncMock( + spec=ToolboxCoreTool + ) # *** Use AsyncMock *** + mock_core_tool_instance1.__name__ = "mock-core-async-tool-0" + mock_core_tool_instance1.__doc__ = "desc 0" + mock_core_tool_instance1._pydantic_model = BaseModel + + mock_core_tool_instance2 = AsyncMock( + spec=ToolboxCoreTool + ) # *** Use AsyncMock *** + mock_core_tool_instance2.__name__ = "mock-core-async-tool-1" + mock_core_tool_instance2.__doc__ = "desc 1" + mock_core_tool_instance2._pydantic_model = BaseModel + + mock_core_aload_toolset.return_value = [ + mock_core_tool_instance1, + mock_core_tool_instance2, + ] + + langchain_tools = await toolbox_client.aload_toolset() + assert len(langchain_tools) == 2 + assert isinstance(langchain_tools[0], ToolboxTool) + assert isinstance(langchain_tools[1], ToolboxTool) + + toolbox_client._ToolboxClient__core_client._async_client.load_toolset.assert_called_once_with( + name=None, auth_token_getters={}, bound_params={}, strict=False ) - mock_aload_toolset.assert_called_once_with(None, {}, None, None, {}, True) - - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") - def test_load_tool_with_args(self, mock_aload_tool, toolbox_client): - mock_tool = Mock(spec=ToolboxTool) - mock_tool.name = "mock-tool" - mock_tool.description = "mock description" - mock_tool.args_schema = BaseModel - mock_aload_tool.return_value = mock_tool + + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool") + def test_load_tool_with_args(self, mock_core_load_tool, toolbox_client): + mock_core_tool_instance = Mock(spec=ToolboxCoreSyncTool) + mock_core_tool_instance.__name__ = "mock-tool" + mock_async_tool = Mock(spec=ToolboxCoreTool) + mock_async_tool._pydantic_model = BaseModel + mock_core_tool_instance._async_tool = mock_async_tool + mock_core_load_tool.return_value = mock_core_tool_instance + auth_token_getters = {"token_getter1": lambda: "value1"} - auth_tokens = {"token1": lambda: "value2"} - auth_headers = {"header1": lambda: "value3"} + auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} + auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} bound_params = {"param1": "value4"} - tool = toolbox_client.load_tool( - "test_tool_name", + # Test case where auth_token_getters takes precedence + with pytest.warns(DeprecationWarning) as record: + tool = toolbox_client.load_tool( + "test_tool_name", + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + ) + # Expect two warnings: one for auth_tokens, one for auth_headers + assert len(record) == 2 + messages = [str(r.message) for r in record] + assert any("auth_tokens` is deprecated" in m for m in messages) + assert any("auth_headers` is deprecated" in m for m in messages) + + assert isinstance(tool, ToolboxTool) + mock_core_load_tool.assert_called_with( # Use called_with for flexibility if called multiple times in setup + name="test_tool_name", auth_token_getters=auth_token_getters, - auth_tokens=auth_tokens, - auth_headers=auth_headers, bound_params=bound_params, - strict=False, ) + mock_core_load_tool.reset_mock() # Reset for next test case + + # Test case where auth_tokens is used (auth_token_getters is None) + with pytest.warns(DeprecationWarning, match="auth_tokens` is deprecated"): + toolbox_client.load_tool( + "test_tool_name_2", + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, # This will also warn + bound_params=bound_params, + ) + mock_core_load_tool.assert_called_with( + name="test_tool_name_2", + auth_token_getters=auth_tokens_deprecated, # auth_tokens becomes auth_token_getters + bound_params=bound_params, + ) + mock_core_load_tool.reset_mock() - assert tool.name == mock_tool.name - assert tool.description == mock_tool.description - assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with( - "test_tool_name", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, + # Test case where auth_headers is used (auth_token_getters and auth_tokens are None) + with pytest.warns(DeprecationWarning, match="auth_headers` is deprecated"): + toolbox_client.load_tool( + "test_tool_name_3", + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + ) + mock_core_load_tool.assert_called_with( + name="test_tool_name_3", + auth_token_getters=auth_headers_deprecated, # auth_headers becomes auth_token_getters + bound_params=bound_params, ) - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") - def test_load_toolset_with_args(self, mock_aload_toolset, toolbox_client): - mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] - mock_tools[0].name = "mock-tool-0" - mock_tools[0].description = "mock description 0" - mock_tools[0].args_schema = BaseModel - mock_tools[1].name = "mock-tool-1" - mock_tools[1].description = "mock description 1" - mock_tools[1].args_schema = BaseModel - mock_aload_toolset.return_value = mock_tools + @patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset") + def test_load_toolset_with_args(self, mock_core_load_toolset, toolbox_client): + mock_core_tool_instance = Mock(spec=ToolboxCoreSyncTool) + mock_core_tool_instance.__name__ = "mock-tool-0" + mock_async_tool = Mock(spec=ToolboxCoreTool) + mock_async_tool._pydantic_model = BaseModel + mock_core_tool_instance._async_tool = mock_async_tool + mock_core_load_toolset.return_value = [mock_core_tool_instance] auth_token_getters = {"token_getter1": lambda: "value1"} - auth_tokens = {"token1": lambda: "value2"} - auth_headers = {"header1": lambda: "value3"} + auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} + auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} bound_params = {"param1": "value4"} - tools = toolbox_client.load_toolset( - toolset_name="my_toolset", + with pytest.warns(DeprecationWarning) as record: # Expect 2 warnings + tools = toolbox_client.load_toolset( + toolset_name="my_toolset", + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + strict=False, + ) + assert len(record) == 2 + messages = [str(r.message) for r in record] + assert any("auth_tokens` is deprecated" in m for m in messages) + assert any("auth_headers` is deprecated" in m for m in messages) + + assert len(tools) == 1 + mock_core_load_toolset.assert_called_with( + name="my_toolset", auth_token_getters=auth_token_getters, - auth_tokens=auth_tokens, - auth_headers=auth_headers, bound_params=bound_params, strict=False, ) - assert len(tools) == len(mock_tools) - assert all( - a.name == b.name - and a.description == b.description - and a.args_schema == b.args_schema - for a, b in zip(tools, mock_tools) - ) - mock_aload_toolset.assert_called_once_with( - "my_toolset", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, - ) - @pytest.mark.asyncio - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_tool") - async def test_aload_tool_with_args(self, mock_aload_tool, toolbox_client): - mock_tool = Mock(spec=ToolboxTool) - mock_tool.name = "mock-tool" - mock_tool.description = "mock description" - mock_tool.args_schema = BaseModel - mock_aload_tool.return_value = mock_tool + @patch("toolbox_core.client.ToolboxClient.load_tool") + async def test_aload_tool_with_args(self, mock_core_aload_tool, toolbox_client): + mock_core_tool_instance = AsyncMock(spec=ToolboxCoreTool) + mock_core_tool_instance.__name__ = "mock-tool" + mock_core_tool_instance._pydantic_model = BaseModel + mock_core_aload_tool.return_value = mock_core_tool_instance auth_token_getters = {"token_getter1": lambda: "value1"} - auth_tokens = {"token1": lambda: "value2"} - auth_headers = {"header1": lambda: "value3"} + auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} + auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} bound_params = {"param1": "value4"} - tool = await toolbox_client.aload_tool( - "test_tool", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, - ) - assert tool.name == mock_tool.name - assert tool.description == mock_tool.description - assert tool.args_schema == mock_tool.args_schema - mock_aload_tool.assert_called_once_with( - "test_tool", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, + with pytest.warns(DeprecationWarning) as record: # Expect 2 warnings + tool = await toolbox_client.aload_tool( + "test_tool", + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + ) + assert len(record) == 2 + messages = [str(r.message) for r in record] + assert any("auth_tokens` is deprecated" in m for m in messages) + assert any("auth_headers` is deprecated" in m for m in messages) + + assert isinstance(tool, ToolboxTool) + toolbox_client._ToolboxClient__core_client._async_client.load_tool.assert_called_with( + name="test_tool", + auth_token_getters=auth_token_getters, + bound_params=bound_params, ) @pytest.mark.asyncio - @patch("toolbox_langchain.client.AsyncToolboxClient.aload_toolset") - async def test_aload_toolset_with_args(self, mock_aload_toolset, toolbox_client): - mock_tools = [Mock(spec=ToolboxTool), Mock(spec=ToolboxTool)] - mock_tools[0].name = "mock-tool-0" - mock_tools[0].description = "mock description 0" - mock_tools[0].args_schema = BaseModel - mock_tools[1].name = "mock-tool-1" - mock_tools[1].description = "mock description 1" - mock_tools[1].args_schema = BaseModel - mock_aload_toolset.return_value = mock_tools + @patch("toolbox_core.client.ToolboxClient.load_toolset") + async def test_aload_toolset_with_args( + self, mock_core_aload_toolset, toolbox_client + ): + mock_core_tool_instance = AsyncMock(spec=ToolboxCoreTool) + mock_core_tool_instance.__name__ = "mock-tool-0" + mock_core_tool_instance._pydantic_model = BaseModel + mock_core_aload_toolset.return_value = [mock_core_tool_instance] auth_token_getters = {"token_getter1": lambda: "value1"} - auth_tokens = {"token1": lambda: "value2"} - auth_headers = {"header1": lambda: "value3"} + auth_tokens_deprecated = {"token_deprecated": lambda: "value_dep"} + auth_headers_deprecated = {"header_deprecated": lambda: "value_head_dep"} bound_params = {"param1": "value4"} - tools = await toolbox_client.aload_toolset( - "my_toolset", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, - ) - assert len(tools) == len(mock_tools) - assert all( - a.name == b.name - and a.description == b.description - and a.args_schema == b.args_schema - for a, b in zip(tools, mock_tools) - ) - mock_aload_toolset.assert_called_once_with( - "my_toolset", - auth_token_getters, - auth_tokens, - auth_headers, - bound_params, - False, + with pytest.warns(DeprecationWarning) as record: # Expect 2 warnings + tools = await toolbox_client.aload_toolset( + "my_toolset", + auth_token_getters=auth_token_getters, + auth_tokens=auth_tokens_deprecated, + auth_headers=auth_headers_deprecated, + bound_params=bound_params, + strict=False, + ) + assert len(record) == 2 + messages = [str(r.message) for r in record] + assert any("auth_tokens` is deprecated" in m for m in messages) + assert any("auth_headers` is deprecated" in m for m in messages) + + assert len(tools) == 1 + toolbox_client._ToolboxClient__core_client._async_client.load_toolset.assert_called_with( + name="my_toolset", + auth_token_getters=auth_token_getters, + bound_params=bound_params, + strict=False, ) diff --git a/packages/toolbox-langchain/tests/test_e2e.py b/packages/toolbox-langchain/tests/test_e2e.py index 214ea305..12002717 100644 --- a/packages/toolbox-langchain/tests/test_e2e.py +++ b/packages/toolbox-langchain/tests/test_e2e.py @@ -36,7 +36,6 @@ import pytest import pytest_asyncio -from langchain_core.tools import ToolException from pydantic import ValidationError from toolbox_langchain.client import ToolboxClient @@ -54,7 +53,7 @@ def toolbox(self): @pytest_asyncio.fixture(scope="function") async def get_n_rows_tool(self, toolbox): tool = await toolbox.aload_tool("get-n-rows") - assert tool._ToolboxTool__async_tool._AsyncToolboxTool__name == "get-n-rows" + assert tool._ToolboxTool__core_tool.__name__ == "get-n-rows" return tool #### Basic e2e tests @@ -71,7 +70,7 @@ async def test_aload_toolset_specific( toolset = await toolbox.aload_toolset(toolset_name) assert len(toolset) == expected_length for tool in toolset: - name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + name = tool._ToolboxTool__core_tool.__name__ assert name in expected_tools async def test_aload_toolset_all(self, toolbox): @@ -85,7 +84,7 @@ async def test_aload_toolset_all(self, toolbox): "get-row-by-content-auth", ] for tool in toolset: - name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + name = tool._ToolboxTool__core_tool.__name__ assert name in tool_names async def test_run_tool_async(self, get_n_rows_tool): @@ -114,11 +113,14 @@ async def test_run_tool_wrong_param_type(self, get_n_rows_tool): @pytest.mark.asyncio async def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): """Tests running a tool that doesn't require auth, with auth provided.""" - tool = await toolbox.aload_tool( - "get-row-by-id", auth_token_getters={"my-test-auth": lambda: auth_token2} - ) - response = await tool.ainvoke({"id": "2"}) - assert "row2" in response + with pytest.raises( + ValueError, + match="Validation failed for tool 'get-row-by-id': unused auth tokens: my-test-auth.", + ): + await toolbox.aload_tool( + "get-row-by-id", + auth_token_getters={"my-test-auth": lambda: auth_token2}, + ) async def test_run_tool_no_auth(self, toolbox): """Tests running a tool requiring auth without providing auth.""" @@ -127,7 +129,7 @@ async def test_run_tool_no_auth(self, toolbox): ) with pytest.raises( PermissionError, - match="Tool get-row-by-id-auth requires authentication, but no valid authentication sources are registered. Please register the required sources before use.", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): await tool.ainvoke({"id": "2"}) @@ -138,8 +140,8 @@ async def test_run_tool_wrong_auth(self, toolbox, auth_token2): ) auth_tool = tool.add_auth_token_getter("my-test-auth", lambda: auth_token2) with pytest.raises( - ToolException, - match="{'status': 'Unauthorized', 'error': 'tool invocation not authorized. Please make sure your specify correct auth headers'}", + Exception, + match="tool invocation not authorized. Please make sure your specify correct auth headers", ): await auth_tool.ainvoke({"id": "2"}) @@ -157,7 +159,7 @@ async def test_run_tool_param_auth_no_auth(self, toolbox): tool = await toolbox.aload_tool("get-row-by-email-auth") with pytest.raises( PermissionError, - match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): await tool.ainvoke({"email": ""}) @@ -179,8 +181,8 @@ async def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): auth_token_getters={"my-test-auth": lambda: auth_token1}, ) with pytest.raises( - ToolException, - match="{'status': 'Bad Request', 'error': 'provided parameters were invalid: error parsing authenticated parameter \"data\": no field named row_data in claims'}", + Exception, + match='provided parameters were invalid: error parsing authenticated parameter "data": no field named row_data in claims', ): await tool.ainvoke({}) @@ -196,7 +198,7 @@ def toolbox(self): @pytest.fixture(scope="function") def get_n_rows_tool(self, toolbox): tool = toolbox.load_tool("get-n-rows") - assert tool._ToolboxTool__async_tool._AsyncToolboxTool__name == "get-n-rows" + assert tool._ToolboxTool__core_tool.__name__ == "get-n-rows" return tool #### Basic e2e tests @@ -213,7 +215,7 @@ def test_load_toolset_specific( toolset = toolbox.load_toolset(toolset_name) assert len(toolset) == expected_length for tool in toolset: - name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + name = tool._ToolboxTool__core_tool.__name__ assert name in expected_tools def test_aload_toolset_all(self, toolbox): @@ -227,7 +229,7 @@ def test_aload_toolset_all(self, toolbox): "get-row-by-content-auth", ] for tool in toolset: - name = tool._ToolboxTool__async_tool._AsyncToolboxTool__name + name = tool._ToolboxTool__core_tool.__name__ assert name in tool_names @pytest.mark.asyncio @@ -256,11 +258,14 @@ def test_run_tool_wrong_param_type(self, get_n_rows_tool): #### Auth tests def test_run_tool_unauth_with_auth(self, toolbox, auth_token2): """Tests running a tool that doesn't require auth, with auth provided.""" - tool = toolbox.load_tool( - "get-row-by-id", auth_token_getters={"my-test-auth": lambda: auth_token2} - ) - response = tool.invoke({"id": "2"}) - assert "row2" in response + with pytest.raises( + ValueError, + match="Validation failed for tool 'get-row-by-id': unused auth tokens: my-test-auth.", + ): + toolbox.load_tool( + "get-row-by-id", + auth_token_getters={"my-test-auth": lambda: auth_token2}, + ) def test_run_tool_no_auth(self, toolbox): """Tests running a tool requiring auth without providing auth.""" @@ -269,7 +274,7 @@ def test_run_tool_no_auth(self, toolbox): ) with pytest.raises( PermissionError, - match="Tool get-row-by-id-auth requires authentication, but no valid authentication sources are registered. Please register the required sources before use.", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): tool.invoke({"id": "2"}) @@ -280,8 +285,8 @@ def test_run_tool_wrong_auth(self, toolbox, auth_token2): ) auth_tool = tool.add_auth_token_getter("my-test-auth", lambda: auth_token2) with pytest.raises( - ToolException, - match="{'status': 'Unauthorized', 'error': 'tool invocation not authorized. Please make sure your specify correct auth headers'}", + Exception, + match="tool invocation not authorized. Please make sure your specify correct auth headers", ): auth_tool.invoke({"id": "2"}) @@ -299,7 +304,7 @@ def test_run_tool_param_auth_no_auth(self, toolbox): tool = toolbox.load_tool("get-row-by-email-auth") with pytest.raises( PermissionError, - match="Parameter\(s\) `email` of tool get-row-by-email-auth require authentication\, but no valid authentication sources are registered\. Please register the required sources before use\.", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): tool.invoke({"email": ""}) @@ -321,7 +326,7 @@ def test_run_tool_param_auth_no_field(self, toolbox, auth_token1): auth_token_getters={"my-test-auth": lambda: auth_token1}, ) with pytest.raises( - ToolException, - match="{'status': 'Bad Request', 'error': 'provided parameters were invalid: error parsing authenticated parameter \"data\": no field named row_data in claims'}", + Exception, + match='provided parameters were invalid: error parsing authenticated parameter "data": no field named row_data in claims', ): tool.invoke({}) diff --git a/packages/toolbox-langchain/tests/test_tools.py b/packages/toolbox-langchain/tests/test_tools.py index 751005af..5560cf99 100644 --- a/packages/toolbox-langchain/tests/test_tools.py +++ b/packages/toolbox-langchain/tests/test_tools.py @@ -16,17 +16,17 @@ import pytest from pydantic import BaseModel +from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool +from toolbox_core.tool import ToolboxTool as ToolboxCoreTool -from toolbox_langchain.async_tools import AsyncToolboxTool from toolbox_langchain.tools import ToolboxTool class TestToolboxTool: @pytest.fixture - def tool_schema(self): + def tool_schema_dict(self): return { "description": "Test Tool Description", - "name": "test_tool", "parameters": [ {"name": "param1", "type": "string", "description": "Param 1"}, {"name": "param2", "type": "integer", "description": "Param 2"}, @@ -34,10 +34,10 @@ def tool_schema(self): } @pytest.fixture - def auth_tool_schema(self): + def auth_tool_schema_dict(self): return { - "description": "Test Tool Description", - "name": "test_tool", + "description": "Test Auth Tool Description", + "authRequired": ["test-auth-source"], "parameters": [ { "name": "param1", @@ -50,62 +50,66 @@ def auth_tool_schema(self): } @pytest.fixture(scope="function") - def mock_async_tool(self, tool_schema): - mock_async_tool = Mock(spec=AsyncToolboxTool) - mock_async_tool.name = "test_tool" - mock_async_tool.description = "test description" - mock_async_tool.args_schema = BaseModel - mock_async_tool._AsyncToolboxTool__name = "test_tool" - mock_async_tool._AsyncToolboxTool__schema = tool_schema - mock_async_tool._AsyncToolboxTool__url = "http://test_url" - mock_async_tool._AsyncToolboxTool__session = Mock() - mock_async_tool._AsyncToolboxTool__auth_token_getters = {} - mock_async_tool._AsyncToolboxTool__bound_params = {} - return mock_async_tool + def mock_core_async_tool(self, tool_schema_dict): + mock = Mock(spec=ToolboxCoreTool) + mock.__name__ = "test_tool" + mock.__doc__ = tool_schema_dict["description"] + mock._pydantic_model = BaseModel + return mock @pytest.fixture(scope="function") - def mock_async_auth_tool(self, auth_tool_schema): - mock_async_tool = Mock(spec=AsyncToolboxTool) - mock_async_tool.name = "test_tool" - mock_async_tool.description = "test description" - mock_async_tool.args_schema = BaseModel - mock_async_tool._AsyncToolboxTool__name = "test_tool" - mock_async_tool._AsyncToolboxTool__schema = auth_tool_schema - mock_async_tool._AsyncToolboxTool__url = "http://test_url" - mock_async_tool._AsyncToolboxTool__session = Mock() - mock_async_tool._AsyncToolboxTool__auth_token_getters = {} - mock_async_tool._AsyncToolboxTool__bound_params = {} - return mock_async_tool + def mock_core_async_auth_tool(self, auth_tool_schema_dict): + mock = Mock(spec=ToolboxCoreTool) + mock.__name__ = "test_auth_tool" + mock.__doc__ = auth_tool_schema_dict["description"] + mock._pydantic_model = BaseModel + return mock @pytest.fixture - def toolbox_tool(self, mock_async_tool): - return ToolboxTool( - async_tool=mock_async_tool, - loop=Mock(), - thread=Mock(), - ) + def mock_core_tool(self, mock_core_async_tool): + sync_mock = Mock(spec=ToolboxCoreSyncTool) + sync_mock.__name__ = mock_core_async_tool.__name__ + sync_mock.__doc__ = mock_core_async_tool.__doc__ + sync_mock._async_tool = mock_core_async_tool + sync_mock.add_auth_token_getters = Mock(return_value=sync_mock) + sync_mock.bind_params = Mock(return_value=sync_mock) + sync_mock.bind_param = Mock( + return_value=sync_mock + ) # Keep this if bind_param exists on core, otherwise remove + sync_mock.__call__ = Mock(return_value="mocked_sync_call_result") + return sync_mock @pytest.fixture - def auth_toolbox_tool(self, mock_async_auth_tool): - return ToolboxTool( - async_tool=mock_async_auth_tool, - loop=Mock(), - thread=Mock(), - ) + def mock_core_sync_auth_tool(self, mock_core_async_auth_tool): + sync_mock = Mock(spec=ToolboxCoreSyncTool) + sync_mock.__name__ = mock_core_async_auth_tool.__name__ + sync_mock.__doc__ = mock_core_async_auth_tool.__doc__ + sync_mock._async_tool = mock_core_async_auth_tool + sync_mock.add_auth_token_getters = Mock(return_value=sync_mock) + sync_mock.bind_params = Mock(return_value=sync_mock) + sync_mock.bind_param = Mock( + return_value=sync_mock + ) # Keep this if bind_param exists on core + sync_mock.__call__ = Mock(return_value="mocked_auth_sync_call_result") + return sync_mock - def test_toolbox_tool_init(self, mock_async_tool): - tool = ToolboxTool( - async_tool=mock_async_tool, - loop=Mock(), - thread=Mock(), - ) - async_tool = tool._ToolboxTool__async_tool - assert async_tool.name == mock_async_tool.name - assert async_tool.description == mock_async_tool.description - assert async_tool.args_schema == mock_async_tool.args_schema + @pytest.fixture + def toolbox_tool(self, mock_core_tool): + return ToolboxTool(core_tool=mock_core_tool) + + @pytest.fixture + def auth_toolbox_tool(self, mock_core_sync_auth_tool): + return ToolboxTool(core_tool=mock_core_sync_auth_tool) + + def test_toolbox_tool_init(self, mock_core_tool): + tool = ToolboxTool(core_tool=mock_core_tool) + core_tool_in_tool = tool._ToolboxTool__core_tool + assert core_tool_in_tool.__name__ == mock_core_tool.__name__ + assert core_tool_in_tool.__doc__ == mock_core_tool.__doc__ + assert tool.args_schema == mock_core_tool._async_tool._pydantic_model @pytest.mark.parametrize( - "params, expected_bound_params", + "params, expected_bound_params_on_core", [ ({"param1": "bound-value"}, {"param1": "bound-value"}), ({"param1": lambda: "bound-value"}, {"param1": lambda: "bound-value"}), @@ -118,44 +122,33 @@ def test_toolbox_tool_init(self, mock_async_tool): def test_toolbox_tool_bind_params( self, params, - expected_bound_params, + expected_bound_params_on_core, toolbox_tool, - mock_async_tool, + mock_core_tool, ): - mock_async_tool._AsyncToolboxTool__bound_params = expected_bound_params - mock_async_tool.bind_params.return_value = mock_async_tool - - tool = toolbox_tool.bind_params(params) - mock_async_tool.bind_params.assert_called_once_with(params, True) - assert isinstance(tool, ToolboxTool) - - for key, value in expected_bound_params.items(): - async_tool_bound_param_val = ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__bound_params[key] - ) - if callable(value): - assert value() == async_tool_bound_param_val() - else: - assert value == async_tool_bound_param_val - - def test_toolbox_tool_bind_param(self, mock_async_tool, toolbox_tool): - expected_bound_param = {"param1": "bound-value"} - mock_async_tool._AsyncToolboxTool__bound_params = expected_bound_param - mock_async_tool.bind_param.return_value = mock_async_tool - - tool = toolbox_tool.bind_param("param1", "bound-value") - mock_async_tool.bind_param.assert_called_once_with( - "param1", "bound-value", True + mock_core_tool.bind_params.return_value = mock_core_tool + new_langchain_tool = toolbox_tool.bind_params(params) + mock_core_tool.bind_params.assert_called_once_with(params) + assert isinstance(new_langchain_tool, ToolboxTool) + assert ( + new_langchain_tool._ToolboxTool__core_tool + == mock_core_tool.bind_params.return_value ) + def test_toolbox_tool_bind_param(self, toolbox_tool, mock_core_tool): + # ToolboxTool.bind_param calls core_tool.bind_params + mock_core_tool.bind_params.return_value = mock_core_tool + new_langchain_tool = toolbox_tool.bind_param("param1", "bound-value") + # *** Fix: Assert that bind_params is called on the core tool *** + mock_core_tool.bind_params.assert_called_once_with({"param1": "bound-value"}) + assert isinstance(new_langchain_tool, ToolboxTool) assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__bound_params - == expected_bound_param + new_langchain_tool._ToolboxTool__core_tool + == mock_core_tool.bind_params.return_value ) - assert isinstance(tool, ToolboxTool) @pytest.mark.parametrize( - "auth_token_getters, expected_auth_token_getters", + "auth_token_getters, expected_auth_getters_on_core", [ ( {"test-auth-source": lambda: "test-token"}, @@ -176,63 +169,44 @@ def test_toolbox_tool_bind_param(self, mock_async_tool, toolbox_tool): def test_toolbox_tool_add_auth_token_getters( self, auth_token_getters, - expected_auth_token_getters, - mock_async_auth_tool, + expected_auth_getters_on_core, auth_toolbox_tool, + mock_core_sync_auth_tool, ): - auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters = ( - expected_auth_token_getters + mock_core_sync_auth_tool.add_auth_token_getters.return_value = ( + mock_core_sync_auth_tool ) - auth_toolbox_tool._ToolboxTool__async_tool.add_auth_token_getters.return_value = ( - mock_async_auth_tool + new_langchain_tool = auth_toolbox_tool.add_auth_token_getters( + auth_token_getters ) - - tool = auth_toolbox_tool.add_auth_token_getters(auth_token_getters) - mock_async_auth_tool.add_auth_token_getters.assert_called_once_with( - auth_token_getters, True + mock_core_sync_auth_tool.add_auth_token_getters.assert_called_once_with( + auth_token_getters + ) + assert isinstance(new_langchain_tool, ToolboxTool) + assert ( + new_langchain_tool._ToolboxTool__core_tool + == mock_core_sync_auth_tool.add_auth_token_getters.return_value ) - for source, getter in expected_auth_token_getters.items(): - assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters[ - source - ]() - == getter() - ) - assert isinstance(tool, ToolboxTool) def test_toolbox_tool_add_auth_token_getter( - self, mock_async_auth_tool, auth_toolbox_tool + self, auth_toolbox_tool, mock_core_sync_auth_tool ): get_id_token = lambda: "test-token" - expected_auth_token_getters = {"test-auth-source": get_id_token} - auth_toolbox_tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters = ( - expected_auth_token_getters - ) - auth_toolbox_tool._ToolboxTool__async_tool.add_auth_token_getter.return_value = ( - mock_async_auth_tool + # ToolboxTool.add_auth_token_getter calls core_tool.add_auth_token_getters + mock_core_sync_auth_tool.add_auth_token_getters.return_value = ( + mock_core_sync_auth_tool ) - tool = auth_toolbox_tool.add_auth_token_getter("test-auth-source", get_id_token) - mock_async_auth_tool.add_auth_token_getter.assert_called_once_with( - "test-auth-source", get_id_token, True + new_langchain_tool = auth_toolbox_tool.add_auth_token_getter( + "test-auth-source", get_id_token ) - assert ( - tool._ToolboxTool__async_tool._AsyncToolboxTool__auth_token_getters[ - "test-auth-source" - ]() - == "test-token" - ) - assert isinstance(tool, ToolboxTool) - - def test_toolbox_tool_validate_auth_strict(self, auth_toolbox_tool): - auth_toolbox_tool._ToolboxTool__async_tool._arun = Mock( - side_effect=PermissionError( - "Parameter(s) `param1` of tool test_tool require authentication" - ) + # *** Fix: Assert that add_auth_token_getters is called on the core tool *** + mock_core_sync_auth_tool.add_auth_token_getters.assert_called_once_with( + {"test-auth-source": get_id_token} ) - with pytest.raises(PermissionError) as e: - auth_toolbox_tool._run() - assert "Parameter(s) `param1` of tool test_tool require authentication" in str( - e.value + assert isinstance(new_langchain_tool, ToolboxTool) + assert ( + new_langchain_tool._ToolboxTool__core_tool + == mock_core_sync_auth_tool.add_auth_token_getters.return_value ) diff --git a/packages/toolbox-langchain/tests/test_utils.py b/packages/toolbox-langchain/tests/test_utils.py deleted file mode 100644 index 488a6aef..00000000 --- a/packages/toolbox-langchain/tests/test_utils.py +++ /dev/null @@ -1,290 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import json -import re -import warnings -from unittest.mock import AsyncMock, Mock, patch - -import aiohttp -import pytest -from pydantic import BaseModel - -from toolbox_langchain.utils import ( - ParameterSchema, - _get_auth_headers, - _invoke_tool, - _load_manifest, - _parse_type, - _schema_to_model, -) - -URL = "https://my-toolbox.com/test" -MOCK_MANIFEST = """ -{ - "serverVersion": "0.0.1", - "tools": { - "test_tool": { - "summary": "Test Tool", - "description": "This is a test tool.", - "parameters": [ - { - "name": "param1", - "type": "string", - "description": "Parameter 1" - }, - { - "name": "param2", - "type": "integer", - "description": "Parameter 2" - } - ] - } - } -} -""" - - -class TestUtils: - @pytest.fixture(scope="module") - def mock_manifest(self): - return aiohttp.ClientResponse( - method="GET", - url=aiohttp.client.URL(URL), - writer=None, - continue100=None, - timer=None, - request_info=None, - traces=None, - session=None, - loop=asyncio.get_event_loop(), - ) - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.get") - async def test_load_manifest(self, mock_get, mock_manifest): - mock_manifest.raise_for_status = Mock() - mock_manifest.text = AsyncMock(return_value=MOCK_MANIFEST) - - mock_get.return_value = mock_manifest - session = aiohttp.ClientSession() - manifest = await _load_manifest(URL, session) - await session.close() - mock_get.assert_called_once_with(URL) - - assert manifest.serverVersion == "0.0.1" - assert len(manifest.tools) == 1 - - tool = manifest.tools["test_tool"] - assert tool.description == "This is a test tool." - assert tool.parameters == [ - ParameterSchema(name="param1", type="string", description="Parameter 1"), - ParameterSchema(name="param2", type="integer", description="Parameter 2"), - ] - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.get") - async def test_load_manifest_invalid_json(self, mock_get, mock_manifest): - mock_manifest.raise_for_status = Mock() - mock_manifest.text = AsyncMock(return_value="{ invalid manifest") - mock_get.return_value = mock_manifest - - with pytest.raises(Exception) as e: - session = aiohttp.ClientSession() - await _load_manifest(URL, session) - - mock_get.assert_called_once_with(URL) - assert isinstance(e.value, json.JSONDecodeError) - assert ( - str(e.value) - == "Failed to parse JSON from https://my-toolbox.com/test: Expecting property name enclosed in double quotes: line 1 column 3 (char 2): line 1 column 3 (char 2)" - ) - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.get") - async def test_load_manifest_invalid_manifest(self, mock_get, mock_manifest): - mock_manifest.raise_for_status = Mock() - mock_manifest.text = AsyncMock(return_value='{ "something": "invalid" }') - mock_get.return_value = mock_manifest - - with pytest.raises(Exception) as e: - session = aiohttp.ClientSession() - await _load_manifest(URL, session) - - mock_get.assert_called_once_with(URL) - assert isinstance(e.value, ValueError) - assert re.match( - r"Invalid JSON data from https://my-toolbox.com/test: 2 validation errors for ManifestSchema\nserverVersion\n Field required \[type=missing, input_value={'something': 'invalid'}, input_type=dict]\n For further information visit https://errors.pydantic.dev/\d+\.\d+/v/missing\ntools\n Field required \[type=missing, input_value={'something': 'invalid'}, input_type=dict]\n For further information visit https://errors.pydantic.dev/\d+\.\d+/v/missing", - str(e.value), - ) - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.get") - async def test_load_manifest_api_error(self, mock_get, mock_manifest): - error = aiohttp.ClientError("Simulated HTTP Error") - mock_manifest.raise_for_status = Mock() - mock_manifest.text = AsyncMock(side_effect=error) - mock_get.return_value = mock_manifest - - with pytest.raises(aiohttp.ClientError) as exc_info: - session = aiohttp.ClientSession() - await _load_manifest(URL, session) - mock_get.assert_called_once_with(URL) - assert exc_info.value == error - - def test_schema_to_model(self): - schema = [ - ParameterSchema(name="param1", type="string", description="Parameter 1"), - ParameterSchema(name="param2", type="integer", description="Parameter 2"), - ] - model = _schema_to_model("TestModel", schema) - assert issubclass(model, BaseModel) - - assert model.model_fields["param1"].annotation == str - assert model.model_fields["param1"].description == "Parameter 1" - assert model.model_fields["param2"].annotation == int - assert model.model_fields["param2"].description == "Parameter 2" - - def test_schema_to_model_empty(self): - model = _schema_to_model("TestModel", []) - assert issubclass(model, BaseModel) - assert len(model.model_fields) == 0 - - @pytest.mark.parametrize( - "parameter_schema, expected_type", - [ - (ParameterSchema(name="foo", description="bar", type="string"), str), - (ParameterSchema(name="foo", description="bar", type="integer"), int), - (ParameterSchema(name="foo", description="bar", type="float"), float), - (ParameterSchema(name="foo", description="bar", type="boolean"), bool), - ( - ParameterSchema( - name="foo", - description="bar", - type="array", - items=ParameterSchema( - name="foo", description="bar", type="integer" - ), - ), - list[int], - ), - ], - ) - def test_parse_type(self, parameter_schema, expected_type): - assert _parse_type(parameter_schema) == expected_type - - @pytest.mark.parametrize( - "fail_parameter_schema", - [ - (ParameterSchema(name="foo", description="bar", type="invalid")), - ( - ParameterSchema( - name="foo", - description="bar", - type="array", - items=ParameterSchema( - name="foo", description="bar", type="invalid" - ), - ) - ), - ], - ) - def test_parse_type_invalid(self, fail_parameter_schema): - with pytest.raises(ValueError): - _parse_type(fail_parameter_schema) - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.post") - async def test_invoke_tool(self, mock_post): - mock_response = Mock() - mock_response.raise_for_status = Mock() - mock_response.json = AsyncMock(return_value={"key": "value"}) - mock_post.return_value.__aenter__.return_value = mock_response - - result = await _invoke_tool( - "http://localhost:5000", - aiohttp.ClientSession(), - "tool_name", - {"input": "data"}, - {}, - ) - - mock_post.assert_called_once_with( - "http://localhost:5000/api/tool/tool_name/invoke", - json={"input": "data"}, - headers={}, - ) - assert result == {"key": "value"} - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.post") - async def test_invoke_tool_unsecure_with_auth(self, mock_post): - mock_response = Mock() - mock_response.raise_for_status = Mock() - mock_response.json = AsyncMock(return_value={"key": "value"}) - mock_post.return_value.__aenter__.return_value = mock_response - - with pytest.warns( - UserWarning, - match="Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.", - ): - result = await _invoke_tool( - "http://localhost:5000", - aiohttp.ClientSession(), - "tool_name", - {"input": "data"}, - {"my_test_auth": lambda: "fake_id_token"}, - ) - - mock_post.assert_called_once_with( - "http://localhost:5000/api/tool/tool_name/invoke", - json={"input": "data"}, - headers={"my_test_auth_token": "fake_id_token"}, - ) - assert result == {"key": "value"} - - @pytest.mark.asyncio - @patch("aiohttp.ClientSession.post") - async def test_invoke_tool_secure_with_auth(self, mock_post): - session = aiohttp.ClientSession() - mock_response = Mock() - mock_response.raise_for_status = Mock() - mock_response.json = AsyncMock(return_value={"key": "value"}) - mock_post.return_value.__aenter__.return_value = mock_response - - with warnings.catch_warnings(): - warnings.simplefilter("error") - result = await _invoke_tool( - "https://localhost:5000", - session, - "tool_name", - {"input": "data"}, - {"my_test_auth": lambda: "fake_id_token"}, - ) - - mock_post.assert_called_once_with( - "https://localhost:5000/api/tool/tool_name/invoke", - json={"input": "data"}, - headers={"my_test_auth_token": "fake_id_token"}, - ) - assert result == {"key": "value"} - - def test_get_auth_headers_deprecation_warning(self): - """Test _get_auth_headers deprecation warning.""" - with pytest.warns( - DeprecationWarning, - match=r"Call to deprecated function \(or staticmethod\) _get_auth_headers\. \(Please use `_get_auth_tokens` instead\.\)$", - ): - _get_auth_headers({"auth_source1": lambda: "test_token"})