Skip to content

Commit 996545e

Browse files
authored
chore: Make type hints of tool and client classes more accurate and specific (#249)
1 parent 8fb9762 commit 996545e

File tree

5 files changed

+116
-46
lines changed

5 files changed

+116
-46
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
# limitations under the License.
1414

1515

16-
import types
17-
from typing import Any, Callable, Coroutine, Mapping, Optional, Union
16+
from types import MappingProxyType
17+
from typing import Any, Awaitable, Callable, Mapping, Optional, Union
1818

1919
from aiohttp import ClientSession
2020

@@ -38,7 +38,9 @@ def __init__(
3838
self,
3939
url: str,
4040
session: Optional[ClientSession] = None,
41-
client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None,
41+
client_headers: Optional[
42+
Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]
43+
] = None,
4244
):
4345
"""
4446
Initializes the ToolboxClient.
@@ -64,15 +66,23 @@ def __parse_tool(
6466
self,
6567
name: str,
6668
schema: ToolSchema,
67-
auth_token_getters: dict[str, Callable[[], str]],
68-
all_bound_params: Mapping[str, Union[Callable[[], Any], Any]],
69-
client_headers: Mapping[str, Union[Callable, Coroutine, str]],
69+
auth_token_getters: Mapping[
70+
str, Union[Callable[[], str], Callable[[], Awaitable[str]]]
71+
],
72+
all_bound_params: Mapping[
73+
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
74+
],
75+
client_headers: Mapping[
76+
str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]
77+
],
7078
) -> tuple[ToolboxTool, set[str], set[str]]:
7179
"""Internal helper to create a callable tool from its schema."""
7280
# sort into reg, authn, and bound params
7381
params = []
7482
authn_params: dict[str, list[str]] = {}
75-
bound_params: dict[str, Callable[[], str]] = {}
83+
bound_params: dict[
84+
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
85+
] = {}
7686
for p in schema.parameters:
7787
if p.authSources: # authn parameter
7888
authn_params[p.name] = p.authSources
@@ -94,11 +104,11 @@ def __parse_tool(
94104
description=schema.description,
95105
# create a read-only values to prevent mutation
96106
params=tuple(params),
97-
required_authn_params=types.MappingProxyType(authn_params),
107+
required_authn_params=MappingProxyType(authn_params),
98108
required_authz_tokens=authz_tokens,
99-
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
100-
bound_params=types.MappingProxyType(bound_params),
101-
client_headers=types.MappingProxyType(client_headers),
109+
auth_service_token_getters=MappingProxyType(auth_token_getters),
110+
bound_params=MappingProxyType(bound_params),
111+
client_headers=MappingProxyType(client_headers),
102112
)
103113

104114
used_bound_keys = set(bound_params.keys())
@@ -140,8 +150,12 @@ async def close(self):
140150
async def load_tool(
141151
self,
142152
name: str,
143-
auth_token_getters: dict[str, Callable[[], str]] = {},
144-
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
153+
auth_token_getters: Mapping[
154+
str, Union[Callable[[], str], Callable[[], Awaitable[str]]]
155+
] = {},
156+
bound_params: Mapping[
157+
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
158+
] = {},
145159
) -> ToolboxTool:
146160
"""
147161
Asynchronously loads a tool from the server.
@@ -213,8 +227,12 @@ async def load_tool(
213227
async def load_toolset(
214228
self,
215229
name: Optional[str] = None,
216-
auth_token_getters: dict[str, Callable[[], str]] = {},
217-
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
230+
auth_token_getters: Mapping[
231+
str, Union[Callable[[], str], Callable[[], Awaitable[str]]]
232+
] = {},
233+
bound_params: Mapping[
234+
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
235+
] = {},
218236
strict: bool = False,
219237
) -> list[ToolboxTool]:
220238
"""
@@ -309,7 +327,8 @@ async def load_toolset(
309327
return tools
310328

311329
def add_headers(
312-
self, headers: Mapping[str, Union[Callable, Coroutine, str]]
330+
self,
331+
headers: Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]],
313332
) -> None:
314333
"""
315334
Add headers to be included in each request sent through this client.

packages/toolbox-core/src/toolbox_core/sync_client.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import asyncio
1717
from threading import Thread
18-
from typing import Any, Callable, Coroutine, Mapping, Optional, Union
18+
from typing import Any, Awaitable, Callable, Mapping, Optional, Union
1919

2020
from .client import ToolboxClient
2121
from .sync_tool import ToolboxSyncTool
@@ -35,7 +35,9 @@ class ToolboxSyncClient:
3535
def __init__(
3636
self,
3737
url: str,
38-
client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None,
38+
client_headers: Optional[
39+
Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]
40+
] = None,
3941
):
4042
"""
4143
Initializes the ToolboxSyncClient.
@@ -75,8 +77,12 @@ def close(self):
7577
def load_tool(
7678
self,
7779
name: str,
78-
auth_token_getters: dict[str, Callable[[], str]] = {},
79-
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
80+
auth_token_getters: Mapping[
81+
str, Union[Callable[[], str], Callable[[], Awaitable[str]]]
82+
] = {},
83+
bound_params: Mapping[
84+
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
85+
] = {},
8086
) -> ToolboxSyncTool:
8187
"""
8288
Synchronously loads a tool from the server.
@@ -108,8 +114,12 @@ def load_tool(
108114
def load_toolset(
109115
self,
110116
name: Optional[str] = None,
111-
auth_token_getters: dict[str, Callable[[], str]] = {},
112-
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
117+
auth_token_getters: Mapping[
118+
str, Union[Callable[[], str], Callable[[], Awaitable[str]]]
119+
] = {},
120+
bound_params: Mapping[
121+
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
122+
] = {},
113123
strict: bool = False,
114124
) -> list[ToolboxSyncTool]:
115125
"""
@@ -148,7 +158,10 @@ def load_toolset(
148158
]
149159

150160
def add_headers(
151-
self, headers: Mapping[str, Union[Callable, Coroutine, str]]
161+
self,
162+
headers: Mapping[
163+
str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]
164+
],
152165
) -> None:
153166
"""
154167
Add headers to be included in each request sent through this client.

packages/toolbox-core/src/toolbox_core/sync_tool.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from asyncio import AbstractEventLoop
1818
from inspect import Signature
1919
from threading import Thread
20-
from typing import Any, Callable, Coroutine, Mapping, Sequence, Union
20+
from typing import Any, Awaitable, Callable, Mapping, Sequence, Union
2121

2222
from .protocol import ParameterSchema
2323
from .tool import ToolboxTool
@@ -102,19 +102,25 @@ def _params(self) -> Sequence[ParameterSchema]:
102102
return self.__async_tool._params
103103

104104
@property
105-
def _bound_params(self) -> Mapping[str, Union[Callable[[], Any], Any]]:
105+
def _bound_params(
106+
self,
107+
) -> Mapping[str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]]:
106108
return self.__async_tool._bound_params
107109

108110
@property
109111
def _required_auth_params(self) -> Mapping[str, list[str]]:
110112
return self.__async_tool._required_auth_params
111113

112114
@property
113-
def _auth_service_token_getters(self) -> Mapping[str, Callable[[], str]]:
115+
def _auth_service_token_getters(
116+
self,
117+
) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]]:
114118
return self.__async_tool._auth_service_token_getters
115119

116120
@property
117-
def _client_headers(self) -> Mapping[str, Union[Callable, Coroutine, str]]:
121+
def _client_headers(
122+
self,
123+
) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]:
118124
return self.__async_tool._client_headers
119125

120126
def __call__(self, *args: Any, **kwargs: Any) -> str:
@@ -136,7 +142,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> str:
136142

137143
def add_auth_token_getters(
138144
self,
139-
auth_token_getters: Mapping[str, Callable[[], str]],
145+
auth_token_getters: Mapping[
146+
str, Union[Callable[[], str], Callable[[], Awaitable[str]]]
147+
],
140148
) -> "ToolboxSyncTool":
141149
"""
142150
Registers auth token getter functions that are used for AuthServices
@@ -159,7 +167,9 @@ def add_auth_token_getters(
159167
return ToolboxSyncTool(new_async_tool, self.__loop, self.__thread)
160168

161169
def add_auth_token_getter(
162-
self, auth_source: str, get_id_token: Callable[[], str]
170+
self,
171+
auth_source: str,
172+
get_id_token: Union[Callable[[], str], Callable[[], Awaitable[str]]],
163173
) -> "ToolboxSyncTool":
164174
"""
165175
Registers an auth token getter function that is used for AuthService
@@ -181,7 +191,10 @@ def add_auth_token_getter(
181191
return self.add_auth_token_getters({auth_source: get_id_token})
182192

183193
def bind_params(
184-
self, bound_params: Mapping[str, Union[Callable[[], Any], Any]]
194+
self,
195+
bound_params: Mapping[
196+
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
197+
],
185198
) -> "ToolboxSyncTool":
186199
"""
187200
Binds parameters to values or callables that produce values.
@@ -204,7 +217,7 @@ def bind_params(
204217
def bind_param(
205218
self,
206219
param_name: str,
207-
param_value: Union[Callable[[], Any], Any],
220+
param_value: Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any],
208221
) -> "ToolboxSyncTool":
209222
"""
210223
Binds a parameter to the value or callable that produce the value.

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import copy
1616
from inspect import Signature
1717
from types import MappingProxyType
18-
from typing import Any, Callable, Coroutine, Mapping, Optional, Sequence, Union
18+
from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union
1919
from warnings import warn
2020

2121
from aiohttp import ClientSession
@@ -51,9 +51,15 @@ def __init__(
5151
params: Sequence[ParameterSchema],
5252
required_authn_params: Mapping[str, list[str]],
5353
required_authz_tokens: Sequence[str],
54-
auth_service_token_getters: Mapping[str, Callable[[], str]],
55-
bound_params: Mapping[str, Union[Callable[[], Any], Any]],
56-
client_headers: Mapping[str, Union[Callable, Coroutine, str]],
54+
auth_service_token_getters: Mapping[
55+
str, Union[Callable[[], str], Callable[[], Awaitable[str]]]
56+
],
57+
bound_params: Mapping[
58+
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
59+
],
60+
client_headers: Mapping[
61+
str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]
62+
],
5763
):
5864
"""
5965
Initializes a callable that will trigger the tool invocation through the
@@ -143,19 +149,25 @@ def _params(self) -> Sequence[ParameterSchema]:
143149
return copy.deepcopy(self.__params)
144150

145151
@property
146-
def _bound_params(self) -> Mapping[str, Union[Callable[[], Any], Any]]:
152+
def _bound_params(
153+
self,
154+
) -> Mapping[str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]]:
147155
return MappingProxyType(self.__bound_parameters)
148156

149157
@property
150158
def _required_auth_params(self) -> Mapping[str, list[str]]:
151159
return MappingProxyType(self.__required_authn_params)
152160

153161
@property
154-
def _auth_service_token_getters(self) -> Mapping[str, Callable[[], str]]:
162+
def _auth_service_token_getters(
163+
self,
164+
) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]]:
155165
return MappingProxyType(self.__auth_service_token_getters)
156166

157167
@property
158-
def _client_headers(self) -> Mapping[str, Union[Callable, Coroutine, str]]:
168+
def _client_headers(
169+
self,
170+
) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]:
159171
return MappingProxyType(self.__client_headers)
160172

161173
def __copy(
@@ -167,9 +179,15 @@ def __copy(
167179
params: Optional[Sequence[ParameterSchema]] = None,
168180
required_authn_params: Optional[Mapping[str, list[str]]] = None,
169181
required_authz_tokens: Optional[Sequence[str]] = None,
170-
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
171-
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
172-
client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None,
182+
auth_service_token_getters: Optional[
183+
Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]]
184+
] = None,
185+
bound_params: Optional[
186+
Mapping[str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]]
187+
] = None,
188+
client_headers: Optional[
189+
Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]
190+
] = None,
173191
) -> "ToolboxTool":
174192
"""
175193
Creates a copy of the ToolboxTool, overriding specific fields.
@@ -278,7 +296,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
278296

279297
def add_auth_token_getters(
280298
self,
281-
auth_token_getters: Mapping[str, Callable[[], str]],
299+
auth_token_getters: Mapping[
300+
str, Union[Callable[[], str], Callable[[], Awaitable[str]]]
301+
],
282302
) -> "ToolboxTool":
283303
"""
284304
Registers auth token getter functions that are used for AuthServices
@@ -347,7 +367,9 @@ def add_auth_token_getters(
347367
)
348368

349369
def add_auth_token_getter(
350-
self, auth_source: str, get_id_token: Callable[[], str]
370+
self,
371+
auth_source: str,
372+
get_id_token: Union[Callable[[], str], Callable[[], Awaitable[str]]],
351373
) -> "ToolboxTool":
352374
"""
353375
Registers an auth token getter function that is used for AuthService
@@ -369,7 +391,10 @@ def add_auth_token_getter(
369391
return self.add_auth_token_getters({auth_source: get_id_token})
370392

371393
def bind_params(
372-
self, bound_params: Mapping[str, Union[Callable[[], Any], Any]]
394+
self,
395+
bound_params: Mapping[
396+
str, Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any]
397+
],
373398
) -> "ToolboxTool":
374399
"""
375400
Binds parameters to values or callables that produce values.
@@ -413,7 +438,7 @@ def bind_params(
413438
def bind_param(
414439
self,
415440
param_name: str,
416-
param_value: Union[Callable[[], Any], Any],
441+
param_value: Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any],
417442
) -> "ToolboxTool":
418443
"""
419444
Binds a parameter to the value or callable that produce the value.

packages/toolbox-core/src/toolbox_core/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def params_to_pydantic_model(
122122

123123

124124
async def resolve_value(
125-
source: Union[Callable[[], Awaitable[Any]], Callable[[], Any], Any],
125+
source: Union[Callable[[], Any], Callable[[], Awaitable[Any]], Any],
126126
) -> Any:
127127
"""
128128
Asynchronously or synchronously resolves a given source to its value.

0 commit comments

Comments
 (0)