diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index df8c2743..de4e4107 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -13,8 +13,8 @@ # limitations under the License. -import types -from typing import Any, Callable, Coroutine, Mapping, Optional, Union +from types import MappingProxyType +from typing import Any, Awaitable, Callable, Mapping, Optional, Union from aiohttp import ClientSession @@ -38,7 +38,9 @@ def __init__( self, url: str, session: Optional[ClientSession] = None, - client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None, + client_headers: Optional[ + Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]] + ] = None, ): """ Initializes the ToolboxClient. @@ -64,15 +66,23 @@ def __parse_tool( self, name: str, schema: ToolSchema, - auth_token_getters: dict[str, Callable[[], str]], - all_bound_params: Mapping[str, Union[Callable[[], Any], Any]], - client_headers: Mapping[str, Union[Callable, Coroutine, str]], + auth_token_getters: Mapping[ + str, Union[Callable[[], str], Callable[[], Awaitable[str]]] + ], + all_bound_params: Mapping[ + str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any] + ], + client_headers: Mapping[ + str, Union[Callable[[], str], Callable[[], Awaitable[str]], str] + ], ) -> tuple[ToolboxTool, set[str], set[str]]: """Internal helper to create a callable tool from its schema.""" # sort into reg, authn, and bound params params = [] authn_params: dict[str, list[str]] = {} - bound_params: dict[str, Callable[[], str]] = {} + bound_params: dict[ + str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any] + ] = {} for p in schema.parameters: if p.authSources: # authn parameter authn_params[p.name] = p.authSources @@ -94,11 +104,11 @@ def __parse_tool( description=schema.description, # create a read-only values to prevent mutation params=tuple(params), - required_authn_params=types.MappingProxyType(authn_params), + required_authn_params=MappingProxyType(authn_params), required_authz_tokens=authz_tokens, - auth_service_token_getters=types.MappingProxyType(auth_token_getters), - bound_params=types.MappingProxyType(bound_params), - client_headers=types.MappingProxyType(client_headers), + auth_service_token_getters=MappingProxyType(auth_token_getters), + bound_params=MappingProxyType(bound_params), + client_headers=MappingProxyType(client_headers), ) used_bound_keys = set(bound_params.keys()) @@ -140,8 +150,12 @@ async def close(self): async def load_tool( self, name: str, - auth_token_getters: dict[str, Callable[[], str]] = {}, - bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, + auth_token_getters: Mapping[ + str, Union[Callable[[], str], Callable[[], Awaitable[str]]] + ] = {}, + bound_params: Mapping[ + str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any] + ] = {}, ) -> ToolboxTool: """ Asynchronously loads a tool from the server. @@ -213,8 +227,12 @@ async def load_tool( async def load_toolset( self, name: Optional[str] = None, - auth_token_getters: dict[str, Callable[[], str]] = {}, - bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, + auth_token_getters: Mapping[ + str, Union[Callable[[], str], Callable[[], Awaitable[str]]] + ] = {}, + bound_params: Mapping[ + str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any] + ] = {}, strict: bool = False, ) -> list[ToolboxTool]: """ @@ -309,7 +327,8 @@ async def load_toolset( return tools def add_headers( - self, headers: Mapping[str, Union[Callable, Coroutine, str]] + self, + headers: Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]], ) -> None: """ Add headers to be included in each request sent through this client. diff --git a/packages/toolbox-core/src/toolbox_core/sync_client.py b/packages/toolbox-core/src/toolbox_core/sync_client.py index 5e0c2d41..312a27ad 100644 --- a/packages/toolbox-core/src/toolbox_core/sync_client.py +++ b/packages/toolbox-core/src/toolbox_core/sync_client.py @@ -15,7 +15,7 @@ import asyncio from threading import Thread -from typing import Any, Callable, Coroutine, Mapping, Optional, Union +from typing import Any, Awaitable, Callable, Mapping, Optional, Union from .client import ToolboxClient from .sync_tool import ToolboxSyncTool @@ -35,7 +35,9 @@ class ToolboxSyncClient: def __init__( self, url: str, - client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None, + client_headers: Optional[ + Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]] + ] = None, ): """ Initializes the ToolboxSyncClient. @@ -75,8 +77,12 @@ def close(self): def load_tool( self, name: str, - auth_token_getters: dict[str, Callable[[], str]] = {}, - bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, + auth_token_getters: Mapping[ + str, Union[Callable[[], str], Callable[[], Awaitable[str]]] + ] = {}, + bound_params: Mapping[ + str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any] + ] = {}, ) -> ToolboxSyncTool: """ Synchronously loads a tool from the server. @@ -108,8 +114,12 @@ def load_tool( def load_toolset( self, name: Optional[str] = None, - auth_token_getters: dict[str, Callable[[], str]] = {}, - bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {}, + auth_token_getters: Mapping[ + str, Union[Callable[[], str], Callable[[], Awaitable[str]]] + ] = {}, + bound_params: Mapping[ + str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any] + ] = {}, strict: bool = False, ) -> list[ToolboxSyncTool]: """ @@ -148,7 +158,10 @@ def load_toolset( ] def add_headers( - self, headers: Mapping[str, Union[Callable, Coroutine, str]] + self, + headers: Mapping[ + str, Union[Callable[[], str], Callable[[], Awaitable[str]], str] + ], ) -> None: """ Add headers to be included in each request sent through this client. diff --git a/packages/toolbox-core/src/toolbox_core/sync_tool.py b/packages/toolbox-core/src/toolbox_core/sync_tool.py index 5a545d9f..565bdc8c 100644 --- a/packages/toolbox-core/src/toolbox_core/sync_tool.py +++ b/packages/toolbox-core/src/toolbox_core/sync_tool.py @@ -17,7 +17,7 @@ from asyncio import AbstractEventLoop from inspect import Signature from threading import Thread -from typing import Any, Callable, Coroutine, Mapping, Sequence, Union +from typing import Any, Awaitable, Callable, Mapping, Sequence, Union from .protocol import ParameterSchema from .tool import ToolboxTool @@ -102,7 +102,9 @@ def _params(self) -> Sequence[ParameterSchema]: return self.__async_tool._params @property - def _bound_params(self) -> Mapping[str, Union[Callable[[], Any], Any]]: + def _bound_params( + self, + ) -> Mapping[str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]]: return self.__async_tool._bound_params @property @@ -110,11 +112,15 @@ def _required_auth_params(self) -> Mapping[str, list[str]]: return self.__async_tool._required_auth_params @property - def _auth_service_token_getters(self) -> Mapping[str, Callable[[], str]]: + def _auth_service_token_getters( + self, + ) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]]: return self.__async_tool._auth_service_token_getters @property - def _client_headers(self) -> Mapping[str, Union[Callable, Coroutine, str]]: + def _client_headers( + self, + ) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]: return self.__async_tool._client_headers def __call__(self, *args: Any, **kwargs: Any) -> str: @@ -136,7 +142,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> str: def add_auth_token_getters( self, - auth_token_getters: Mapping[str, Callable[[], str]], + auth_token_getters: Mapping[ + str, Union[Callable[[], str], Callable[[], Awaitable[str]]] + ], ) -> "ToolboxSyncTool": """ Registers auth token getter functions that are used for AuthServices @@ -159,7 +167,9 @@ def add_auth_token_getters( return ToolboxSyncTool(new_async_tool, self.__loop, self.__thread) def add_auth_token_getter( - self, auth_source: str, get_id_token: Callable[[], str] + self, + auth_source: str, + get_id_token: Union[Callable[[], str], Callable[[], Awaitable[str]]], ) -> "ToolboxSyncTool": """ Registers an auth token getter function that is used for AuthService @@ -181,7 +191,10 @@ def add_auth_token_getter( return self.add_auth_token_getters({auth_source: get_id_token}) def bind_params( - self, bound_params: Mapping[str, Union[Callable[[], Any], Any]] + self, + bound_params: Mapping[ + str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any] + ], ) -> "ToolboxSyncTool": """ Binds parameters to values or callables that produce values. @@ -204,7 +217,7 @@ def bind_params( def bind_param( self, param_name: str, - param_value: Union[Callable[[], Any], Any], + param_value: Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any], ) -> "ToolboxSyncTool": """ Binds a parameter to the value or callable that produce the value. diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 64dd689a..ebfa8358 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -15,7 +15,7 @@ import copy from inspect import Signature from types import MappingProxyType -from typing import Any, Callable, Coroutine, Mapping, Optional, Sequence, Union +from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union from warnings import warn from aiohttp import ClientSession @@ -51,9 +51,15 @@ def __init__( params: Sequence[ParameterSchema], required_authn_params: Mapping[str, list[str]], required_authz_tokens: Sequence[str], - auth_service_token_getters: Mapping[str, Callable[[], str]], - bound_params: Mapping[str, Union[Callable[[], Any], Any]], - client_headers: Mapping[str, Union[Callable, Coroutine, str]], + auth_service_token_getters: Mapping[ + str, Union[Callable[[], str], Callable[[], Awaitable[str]]] + ], + bound_params: Mapping[ + str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any] + ], + client_headers: Mapping[ + str, Union[Callable[[], str], Callable[[], Awaitable[str]], str] + ], ): """ Initializes a callable that will trigger the tool invocation through the @@ -143,7 +149,9 @@ def _params(self) -> Sequence[ParameterSchema]: return copy.deepcopy(self.__params) @property - def _bound_params(self) -> Mapping[str, Union[Callable[[], Any], Any]]: + def _bound_params( + self, + ) -> Mapping[str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]]: return MappingProxyType(self.__bound_parameters) @property @@ -151,11 +159,15 @@ def _required_auth_params(self) -> Mapping[str, list[str]]: return MappingProxyType(self.__required_authn_params) @property - def _auth_service_token_getters(self) -> Mapping[str, Callable[[], str]]: + def _auth_service_token_getters( + self, + ) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]]: return MappingProxyType(self.__auth_service_token_getters) @property - def _client_headers(self) -> Mapping[str, Union[Callable, Coroutine, str]]: + def _client_headers( + self, + ) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]: return MappingProxyType(self.__client_headers) def __copy( @@ -167,9 +179,15 @@ def __copy( params: Optional[Sequence[ParameterSchema]] = None, required_authn_params: Optional[Mapping[str, list[str]]] = None, required_authz_tokens: Optional[Sequence[str]] = None, - auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, - bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None, - client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None, + auth_service_token_getters: Optional[ + Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]] + ] = None, + bound_params: Optional[ + Mapping[str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]] + ] = None, + client_headers: Optional[ + Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]] + ] = None, ) -> "ToolboxTool": """ Creates a copy of the ToolboxTool, overriding specific fields. @@ -278,7 +296,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: def add_auth_token_getters( self, - auth_token_getters: Mapping[str, Callable[[], str]], + auth_token_getters: Mapping[ + str, Union[Callable[[], str], Callable[[], Awaitable[str]]] + ], ) -> "ToolboxTool": """ Registers auth token getter functions that are used for AuthServices @@ -347,7 +367,9 @@ def add_auth_token_getters( ) def add_auth_token_getter( - self, auth_source: str, get_id_token: Callable[[], str] + self, + auth_source: str, + get_id_token: Union[Callable[[], str], Callable[[], Awaitable[str]]], ) -> "ToolboxTool": """ Registers an auth token getter function that is used for AuthService @@ -369,7 +391,10 @@ def add_auth_token_getter( return self.add_auth_token_getters({auth_source: get_id_token}) def bind_params( - self, bound_params: Mapping[str, Union[Callable[[], Any], Any]] + self, + bound_params: Mapping[ + str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any] + ], ) -> "ToolboxTool": """ Binds parameters to values or callables that produce values. @@ -413,7 +438,7 @@ def bind_params( def bind_param( self, param_name: str, - param_value: Union[Callable[[], Any], Any], + param_value: Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any], ) -> "ToolboxTool": """ Binds a parameter to the value or callable that produce the value. diff --git a/packages/toolbox-core/src/toolbox_core/utils.py b/packages/toolbox-core/src/toolbox_core/utils.py index de3b728d..615a23ec 100644 --- a/packages/toolbox-core/src/toolbox_core/utils.py +++ b/packages/toolbox-core/src/toolbox_core/utils.py @@ -122,7 +122,7 @@ def params_to_pydantic_model( async def resolve_value( - source: Union[Callable[[], Awaitable[Any]], Callable[[], Any], Any], + source: Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any], ) -> Any: """ Asynchronously or synchronously resolves a given source to its value.