|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 |
|
16 |
| -import types |
17 | 16 | from typing import Any, Callable, Mapping, Optional, Union
|
18 | 17 |
|
19 | 18 | from aiohttp import ClientSession
|
20 | 19 |
|
21 | 20 | from .protocol import ManifestSchema, ToolSchema
|
22 | 21 | from .tool import ToolboxTool, identify_required_authn_params
|
23 |
| - |
| 22 | +from .utils import parse_tool |
24 | 23 |
|
25 | 24 | class ToolboxClient:
|
26 | 25 | """
|
@@ -55,42 +54,6 @@ def __init__(
|
55 | 54 | session = ClientSession()
|
56 | 55 | self.__session = session
|
57 | 56 |
|
58 |
| - def __parse_tool( |
59 |
| - self, |
60 |
| - name: str, |
61 |
| - schema: ToolSchema, |
62 |
| - auth_token_getters: dict[str, Callable[[], str]], |
63 |
| - all_bound_params: Mapping[str, Union[Callable[[], Any], Any]], |
64 |
| - ) -> ToolboxTool: |
65 |
| - """Internal helper to create a callable tool from its schema.""" |
66 |
| - # sort into reg, authn, and bound params |
67 |
| - params = [] |
68 |
| - authn_params: dict[str, list[str]] = {} |
69 |
| - bound_params: dict[str, Callable[[], str]] = {} |
70 |
| - for p in schema.parameters: |
71 |
| - if p.authSources: # authn parameter |
72 |
| - authn_params[p.name] = p.authSources |
73 |
| - elif p.name in all_bound_params: # bound parameter |
74 |
| - bound_params[p.name] = all_bound_params[p.name] |
75 |
| - else: # regular parameter |
76 |
| - params.append(p) |
77 |
| - |
78 |
| - authn_params = identify_required_authn_params( |
79 |
| - authn_params, auth_token_getters.keys() |
80 |
| - ) |
81 |
| - |
82 |
| - tool = ToolboxTool( |
83 |
| - session=self.__session, |
84 |
| - base_url=self.__base_url, |
85 |
| - name=name, |
86 |
| - description=schema.description, |
87 |
| - params=params, |
88 |
| - # create a read-only values for the maps to prevent mutation |
89 |
| - required_authn_params=types.MappingProxyType(authn_params), |
90 |
| - auth_service_token_getters=types.MappingProxyType(auth_token_getters), |
91 |
| - bound_params=types.MappingProxyType(bound_params), |
92 |
| - ) |
93 |
| - return tool |
94 | 57 |
|
95 | 58 | async def __aenter__(self):
|
96 | 59 | """
|
@@ -163,8 +126,8 @@ async def load_tool(
|
163 | 126 | if name not in manifest.tools:
|
164 | 127 | # TODO: Better exception
|
165 | 128 | raise Exception(f"Tool '{name}' not found!")
|
166 |
| - tool = self.__parse_tool( |
167 |
| - name, manifest.tools[name], auth_token_getters, bound_params |
| 129 | + tool = parse_tool( |
| 130 | + self.__session, self.__base_url, name, manifest.tools[name], auth_token_getters, bound_params |
168 | 131 | )
|
169 | 132 |
|
170 | 133 | return tool
|
@@ -199,7 +162,7 @@ async def load_toolset(
|
199 | 162 |
|
200 | 163 | # parse each tools name and schema into a list of ToolboxTools
|
201 | 164 | tools = [
|
202 |
| - self.__parse_tool(n, s, auth_token_getters, bound_params) |
| 165 | + parse_tool(self.__session, self.__base_url, n, s, auth_token_getters, bound_params) |
203 | 166 | for n, s in manifest.tools.items()
|
204 | 167 | ]
|
205 | 168 | return tools
|
0 commit comments