Skip to content

chore: Improve type hints for async bound_params, auth_token_getters, and client_headers #249

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 35 additions & 16 deletions packages/toolbox-core/src/toolbox_core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.


import types
from typing import Any, Callable, Coroutine, Mapping, Optional, Union
from types import MappingProxyType
from typing import Any, Awaitable, Callable, Mapping, Optional, Union

from aiohttp import ClientSession

Expand All @@ -38,7 +38,9 @@ def __init__(
self,
url: str,
session: Optional[ClientSession] = None,
client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None,
client_headers: Optional[
Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]
] = None,
):
"""
Initializes the ToolboxClient.
Expand All @@ -64,15 +66,23 @@ def __parse_tool(
self,
name: str,
schema: ToolSchema,
auth_token_getters: dict[str, Callable[[], str]],
all_bound_params: Mapping[str, Union[Callable[[], Any], Any]],
client_headers: Mapping[str, Union[Callable, Coroutine, str]],
auth_token_getters: Mapping[
str, Union[Callable[[], str], Callable[[], Awaitable[str]]]
],
all_bound_params: Mapping[
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
],
client_headers: Mapping[
str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]
],
) -> tuple[ToolboxTool, set[str], set[str]]:
"""Internal helper to create a callable tool from its schema."""
# sort into reg, authn, and bound params
params = []
authn_params: dict[str, list[str]] = {}
bound_params: dict[str, Callable[[], str]] = {}
bound_params: dict[
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
] = {}
for p in schema.parameters:
if p.authSources: # authn parameter
authn_params[p.name] = p.authSources
Expand All @@ -94,11 +104,11 @@ def __parse_tool(
description=schema.description,
# create a read-only values to prevent mutation
params=tuple(params),
required_authn_params=types.MappingProxyType(authn_params),
required_authn_params=MappingProxyType(authn_params),
required_authz_tokens=authz_tokens,
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
bound_params=types.MappingProxyType(bound_params),
client_headers=types.MappingProxyType(client_headers),
auth_service_token_getters=MappingProxyType(auth_token_getters),
bound_params=MappingProxyType(bound_params),
client_headers=MappingProxyType(client_headers),
)

used_bound_keys = set(bound_params.keys())
Expand Down Expand Up @@ -140,8 +150,12 @@ async def close(self):
async def load_tool(
self,
name: str,
auth_token_getters: dict[str, Callable[[], str]] = {},
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
auth_token_getters: Mapping[
str, Union[Callable[[], str], Callable[[], Awaitable[str]]]
] = {},
bound_params: Mapping[
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
] = {},
) -> ToolboxTool:
"""
Asynchronously loads a tool from the server.
Expand Down Expand Up @@ -213,8 +227,12 @@ async def load_tool(
async def load_toolset(
self,
name: Optional[str] = None,
auth_token_getters: dict[str, Callable[[], str]] = {},
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
auth_token_getters: Mapping[
str, Union[Callable[[], str], Callable[[], Awaitable[str]]]
] = {},
bound_params: Mapping[
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
] = {},
strict: bool = False,
) -> list[ToolboxTool]:
"""
Expand Down Expand Up @@ -309,7 +327,8 @@ async def load_toolset(
return tools

def add_headers(
self, headers: Mapping[str, Union[Callable, Coroutine, str]]
self,
headers: Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]],
) -> None:
"""
Add headers to be included in each request sent through this client.
Expand Down
27 changes: 20 additions & 7 deletions packages/toolbox-core/src/toolbox_core/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import asyncio
from threading import Thread
from typing import Any, Callable, Coroutine, Mapping, Optional, Union
from typing import Any, Awaitable, Callable, Mapping, Optional, Union

from .client import ToolboxClient
from .sync_tool import ToolboxSyncTool
Expand All @@ -35,7 +35,9 @@ class ToolboxSyncClient:
def __init__(
self,
url: str,
client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None,
client_headers: Optional[
Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]
] = None,
):
"""
Initializes the ToolboxSyncClient.
Expand Down Expand Up @@ -75,8 +77,12 @@ def close(self):
def load_tool(
self,
name: str,
auth_token_getters: dict[str, Callable[[], str]] = {},
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
auth_token_getters: Mapping[
str, Union[Callable[[], str], Callable[[], Awaitable[str]]]
] = {},
bound_params: Mapping[
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
] = {},
) -> ToolboxSyncTool:
"""
Synchronously loads a tool from the server.
Expand Down Expand Up @@ -108,8 +114,12 @@ def load_tool(
def load_toolset(
self,
name: Optional[str] = None,
auth_token_getters: dict[str, Callable[[], str]] = {},
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
auth_token_getters: Mapping[
str, Union[Callable[[], str], Callable[[], Awaitable[str]]]
] = {},
bound_params: Mapping[
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
] = {},
strict: bool = False,
) -> list[ToolboxSyncTool]:
"""
Expand Down Expand Up @@ -148,7 +158,10 @@ def load_toolset(
]

def add_headers(
self, headers: Mapping[str, Union[Callable, Coroutine, str]]
self,
headers: Mapping[
str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]
],
) -> None:
"""
Add headers to be included in each request sent through this client.
Expand Down
29 changes: 21 additions & 8 deletions packages/toolbox-core/src/toolbox_core/sync_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from asyncio import AbstractEventLoop
from inspect import Signature
from threading import Thread
from typing import Any, Callable, Coroutine, Mapping, Sequence, Union
from typing import Any, Awaitable, Callable, Mapping, Sequence, Union

from .protocol import ParameterSchema
from .tool import ToolboxTool
Expand Down Expand Up @@ -102,19 +102,25 @@ def _params(self) -> Sequence[ParameterSchema]:
return self.__async_tool._params

@property
def _bound_params(self) -> Mapping[str, Union[Callable[[], Any], Any]]:
def _bound_params(
self,
) -> Mapping[str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]]:
return self.__async_tool._bound_params

@property
def _required_auth_params(self) -> Mapping[str, list[str]]:
return self.__async_tool._required_auth_params

@property
def _auth_service_token_getters(self) -> Mapping[str, Callable[[], str]]:
def _auth_service_token_getters(
self,
) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]]:
return self.__async_tool._auth_service_token_getters

@property
def _client_headers(self) -> Mapping[str, Union[Callable, Coroutine, str]]:
def _client_headers(
self,
) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]:
return self.__async_tool._client_headers

def __call__(self, *args: Any, **kwargs: Any) -> str:
Expand All @@ -136,7 +142,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> str:

def add_auth_token_getters(
self,
auth_token_getters: Mapping[str, Callable[[], str]],
auth_token_getters: Mapping[
str, Union[Callable[[], str], Callable[[], Awaitable[str]]]
],
) -> "ToolboxSyncTool":
"""
Registers auth token getter functions that are used for AuthServices
Expand All @@ -159,7 +167,9 @@ def add_auth_token_getters(
return ToolboxSyncTool(new_async_tool, self.__loop, self.__thread)

def add_auth_token_getter(
self, auth_source: str, get_id_token: Callable[[], str]
self,
auth_source: str,
get_id_token: Union[Callable[[], str], Callable[[], Awaitable[str]]],
) -> "ToolboxSyncTool":
"""
Registers an auth token getter function that is used for AuthService
Expand All @@ -181,7 +191,10 @@ def add_auth_token_getter(
return self.add_auth_token_getters({auth_source: get_id_token})

def bind_params(
self, bound_params: Mapping[str, Union[Callable[[], Any], Any]]
self,
bound_params: Mapping[
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
],
) -> "ToolboxSyncTool":
"""
Binds parameters to values or callables that produce values.
Expand All @@ -204,7 +217,7 @@ def bind_params(
def bind_param(
self,
param_name: str,
param_value: Union[Callable[[], Any], Any],
param_value: Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any],
) -> "ToolboxSyncTool":
"""
Binds a parameter to the value or callable that produce the value.
Expand Down
53 changes: 39 additions & 14 deletions packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import copy
from inspect import Signature
from types import MappingProxyType
from typing import Any, Callable, Coroutine, Mapping, Optional, Sequence, Union
from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union
from warnings import warn

from aiohttp import ClientSession
Expand Down Expand Up @@ -51,9 +51,15 @@ def __init__(
params: Sequence[ParameterSchema],
required_authn_params: Mapping[str, list[str]],
required_authz_tokens: Sequence[str],
auth_service_token_getters: Mapping[str, Callable[[], str]],
bound_params: Mapping[str, Union[Callable[[], Any], Any]],
client_headers: Mapping[str, Union[Callable, Coroutine, str]],
auth_service_token_getters: Mapping[
str, Union[Callable[[], str], Callable[[], Awaitable[str]]]
],
bound_params: Mapping[
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
],
client_headers: Mapping[
str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]
],
):
"""
Initializes a callable that will trigger the tool invocation through the
Expand Down Expand Up @@ -143,19 +149,25 @@ def _params(self) -> Sequence[ParameterSchema]:
return copy.deepcopy(self.__params)

@property
def _bound_params(self) -> Mapping[str, Union[Callable[[], Any], Any]]:
def _bound_params(
self,
) -> Mapping[str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]]:
return MappingProxyType(self.__bound_parameters)

@property
def _required_auth_params(self) -> Mapping[str, list[str]]:
return MappingProxyType(self.__required_authn_params)

@property
def _auth_service_token_getters(self) -> Mapping[str, Callable[[], str]]:
def _auth_service_token_getters(
self,
) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]]:
return MappingProxyType(self.__auth_service_token_getters)

@property
def _client_headers(self) -> Mapping[str, Union[Callable, Coroutine, str]]:
def _client_headers(
self,
) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]:
return MappingProxyType(self.__client_headers)

def __copy(
Expand All @@ -167,9 +179,15 @@ def __copy(
params: Optional[Sequence[ParameterSchema]] = None,
required_authn_params: Optional[Mapping[str, list[str]]] = None,
required_authz_tokens: Optional[Sequence[str]] = None,
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None,
auth_service_token_getters: Optional[
Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]]
] = None,
bound_params: Optional[
Mapping[str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]]
] = None,
client_headers: Optional[
Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]
] = None,
) -> "ToolboxTool":
"""
Creates a copy of the ToolboxTool, overriding specific fields.
Expand Down Expand Up @@ -278,7 +296,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:

def add_auth_token_getters(
self,
auth_token_getters: Mapping[str, Callable[[], str]],
auth_token_getters: Mapping[
str, Union[Callable[[], str], Callable[[], Awaitable[str]]]
],
) -> "ToolboxTool":
"""
Registers auth token getter functions that are used for AuthServices
Expand Down Expand Up @@ -347,7 +367,9 @@ def add_auth_token_getters(
)

def add_auth_token_getter(
self, auth_source: str, get_id_token: Callable[[], str]
self,
auth_source: str,
get_id_token: Union[Callable[[], str], Callable[[], Awaitable[str]]],
) -> "ToolboxTool":
"""
Registers an auth token getter function that is used for AuthService
Expand All @@ -369,7 +391,10 @@ def add_auth_token_getter(
return self.add_auth_token_getters({auth_source: get_id_token})

def bind_params(
self, bound_params: Mapping[str, Union[Callable[[], Any], Any]]
self,
bound_params: Mapping[
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
],
) -> "ToolboxTool":
"""
Binds parameters to values or callables that produce values.
Expand Down Expand Up @@ -413,7 +438,7 @@ def bind_params(
def bind_param(
self,
param_name: str,
param_value: Union[Callable[[], Any], Any],
param_value: Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any],
) -> "ToolboxTool":
"""
Binds a parameter to the value or callable that produce the value.
Expand Down
2 changes: 1 addition & 1 deletion packages/toolbox-core/src/toolbox_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def params_to_pydantic_model(


async def resolve_value(
source: Union[Callable[[], Awaitable[Any]], Callable[[], Any], Any],
source: Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any],
) -> Any:
"""
Asynchronously or synchronously resolves a given source to its value.
Expand Down