Skip to content

Commit 10087a1

Browse files
authored
feat(toolbox-core): add authenticated parameters support (#119)
1 parent 713c4dc commit 10087a1

File tree

4 files changed

+301
-25
lines changed

4 files changed

+301
-25
lines changed

packages/toolbox-core/pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ test = [
4444
"isort==6.0.1",
4545
"mypy==1.15.0",
4646
"pytest==8.3.5",
47-
"pytest-aioresponses==0.3.0"
47+
"pytest-aioresponses==0.3.0",
48+
"pytest-asyncio==0.25.3",
4849
]
4950
[build-system]
5051
requires = ["setuptools"]

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
from typing import Optional
14+
import re
15+
import types
16+
from typing import Any, Callable, Optional
1617

1718
from aiohttp import ClientSession
1819

1920
from .protocol import ManifestSchema, ToolSchema
20-
from .tool import ToolboxTool
21+
from .tool import ToolboxTool, identify_required_authn_params
2122

2223

2324
class ToolboxClient:
@@ -53,14 +54,37 @@ def __init__(
5354
session = ClientSession()
5455
self.__session = session
5556

56-
def __parse_tool(self, name: str, schema: ToolSchema) -> ToolboxTool:
57+
def __parse_tool(
58+
self,
59+
name: str,
60+
schema: ToolSchema,
61+
auth_token_getters: dict[str, Callable[[], str]],
62+
) -> ToolboxTool:
5763
"""Internal helper to create a callable tool from its schema."""
64+
# sort into authenticated and reg params
65+
params = []
66+
authn_params: dict[str, list[str]] = {}
67+
auth_sources: set[str] = set()
68+
for p in schema.parameters:
69+
if not p.authSources:
70+
params.append(p)
71+
else:
72+
authn_params[p.name] = p.authSources
73+
auth_sources.update(p.authSources)
74+
75+
authn_params = identify_required_authn_params(
76+
authn_params, auth_token_getters.keys()
77+
)
78+
5879
tool = ToolboxTool(
5980
session=self.__session,
6081
base_url=self.__base_url,
6182
name=name,
6283
desc=schema.description,
63-
params=[p.to_param() for p in schema.parameters],
84+
params=[p.to_param() for p in params],
85+
# create a read-only values for the maps to prevent mutation
86+
required_authn_params=types.MappingProxyType(authn_params),
87+
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
6488
)
6589
return tool
6690

@@ -99,6 +123,7 @@ async def close(self):
99123
async def load_tool(
100124
self,
101125
name: str,
126+
auth_token_getters: dict[str, Callable[[], str]] = {},
102127
) -> ToolboxTool:
103128
"""
104129
Asynchronously loads a tool from the server.
@@ -109,6 +134,8 @@ async def load_tool(
109134
110135
Args:
111136
name: The unique name or identifier of the tool to load.
137+
auth_token_getters: A mapping of authentication service names to
138+
callables that return the corresponding authentication token.
112139
113140
Returns:
114141
ToolboxTool: A callable object representing the loaded tool, ready
@@ -127,19 +154,23 @@ async def load_tool(
127154
if name not in manifest.tools:
128155
# TODO: Better exception
129156
raise Exception(f"Tool '{name}' not found!")
130-
tool = self.__parse_tool(name, manifest.tools[name])
157+
tool = self.__parse_tool(name, manifest.tools[name], auth_token_getters)
131158

132159
return tool
133160

134161
async def load_toolset(
135162
self,
136163
name: str,
164+
auth_token_getters: dict[str, Callable[[], str]] = {},
137165
) -> list[ToolboxTool]:
138166
"""
139167
Asynchronously fetches a toolset and loads all tools defined within it.
140168
141169
Args:
142170
name: Name of the toolset to load tools.
171+
auth_token_getters: A mapping of authentication service names to
172+
callables that return the corresponding authentication token.
173+
143174
144175
Returns:
145176
list[ToolboxTool]: A list of callables, one for each tool defined
@@ -152,5 +183,8 @@ async def load_toolset(
152183
manifest: ManifestSchema = ManifestSchema(**json)
153184

154185
# parse each tools name and schema into a list of ToolboxTools
155-
tools = [self.__parse_tool(n, s) for n, s in manifest.tools.items()]
186+
tools = [
187+
self.__parse_tool(n, s, auth_token_getters)
188+
for n, s in manifest.tools.items()
189+
]
156190
return tools

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 157 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
# limitations under the License.
1414

1515

16+
import types
17+
from collections import defaultdict
1618
from inspect import Parameter, Signature
17-
from typing import Any
19+
from typing import Any, Callable, DefaultDict, Iterable, Mapping, Optional, Sequence
1820

1921
from aiohttp import ClientSession
22+
from pytest import Session
2023

2124

2225
class ToolboxTool:
@@ -32,20 +35,19 @@ class ToolboxTool:
3235
and `inspect` work as expected.
3336
"""
3437

35-
__url: str
36-
__session: ClientSession
37-
__signature__: Signature
38-
3938
def __init__(
4039
self,
4140
session: ClientSession,
4241
base_url: str,
4342
name: str,
4443
desc: str,
45-
params: list[Parameter],
44+
params: Sequence[Parameter],
45+
required_authn_params: Mapping[str, list[str]],
46+
auth_service_token_getters: Mapping[str, Callable[[], str]],
4647
):
4748
"""
48-
Initializes a callable that will trigger the tool invocation through the Toolbox server.
49+
Initializes a callable that will trigger the tool invocation through the
50+
Toolbox server.
4951
5052
Args:
5153
session: The `aiohttp.ClientSession` used for making API requests.
@@ -54,19 +56,73 @@ def __init__(
5456
desc: The description of the remote tool (used as its docstring).
5557
params: A list of `inspect.Parameter` objects defining the tool's
5658
arguments and their types/defaults.
59+
required_authn_params: A dict of required authenticated parameters to a list
60+
of services that provide values for them.
61+
auth_service_token_getters: A dict of authService -> token (or callables that
62+
produce a token)
5763
"""
5864

5965
# used to invoke the toolbox API
60-
self.__session = session
66+
self.__session: ClientSession = session
67+
self.__base_url: str = base_url
6168
self.__url = f"{base_url}/api/tool/{name}/invoke"
6269

63-
# the following properties are set to help anyone that might inspect it determine
70+
self.__desc = desc
71+
self.__params = params
72+
73+
# the following properties are set to help anyone that might inspect it determine usage
6474
self.__name__ = name
6575
self.__doc__ = desc
6676
self.__signature__ = Signature(parameters=params, return_annotation=str)
6777
self.__annotations__ = {p.name: p.annotation for p in params}
6878
# TODO: self.__qualname__ ??
6979

80+
# map of parameter name to auth service required by it
81+
self.__required_authn_params = required_authn_params
82+
# map of authService -> token_getter
83+
self.__auth_service_token_getters = auth_service_token_getters
84+
85+
def __copy(
86+
self,
87+
session: Optional[ClientSession] = None,
88+
base_url: Optional[str] = None,
89+
name: Optional[str] = None,
90+
desc: Optional[str] = None,
91+
params: Optional[list[Parameter]] = None,
92+
required_authn_params: Optional[Mapping[str, list[str]]] = None,
93+
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
94+
) -> "ToolboxTool":
95+
"""
96+
Creates a copy of the ToolboxTool, overriding specific fields.
97+
98+
Args:
99+
session: The `aiohttp.ClientSession` used for making API requests.
100+
base_url: The base URL of the Toolbox server API.
101+
name: The name of the remote tool.
102+
desc: The description of the remote tool (used as its docstring).
103+
params: A list of `inspect.Parameter` objects defining the tool's
104+
arguments and their types/defaults.
105+
required_authn_params: A dict of required authenticated parameters that need
106+
a auth_service_token_getter set for them yet.
107+
auth_service_token_getters: A dict of authService -> token (or callables
108+
that produce a token)
109+
110+
"""
111+
check = lambda val, default: val if val is not None else default
112+
return ToolboxTool(
113+
session=check(session, self.__session),
114+
base_url=check(base_url, self.__base_url),
115+
name=check(name, self.__name__),
116+
desc=check(desc, self.__desc),
117+
params=check(params, self.__params),
118+
required_authn_params=check(
119+
required_authn_params, self.__required_authn_params
120+
),
121+
auth_service_token_getters=check(
122+
auth_service_token_getters, self.__auth_service_token_getters
123+
),
124+
)
125+
70126
async def __call__(self, *args: Any, **kwargs: Any) -> str:
71127
"""
72128
Asynchronously calls the remote tool with the provided arguments.
@@ -81,16 +137,103 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
81137
Returns:
82138
The string result returned by the remote tool execution.
83139
"""
140+
141+
# check if any auth services need to be specified yet
142+
if len(self.__required_authn_params) > 0:
143+
# Gather all the required auth services into a set
144+
req_auth_services = set()
145+
for s in self.__required_authn_params.values():
146+
req_auth_services.update(s)
147+
raise Exception(
148+
f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}"
149+
)
150+
151+
# validate inputs to this call using the signature
84152
all_args = self.__signature__.bind(*args, **kwargs)
85153
all_args.apply_defaults() # Include default values if not provided
86154
payload = all_args.arguments
87155

156+
# create headers for auth services
157+
headers = {}
158+
for auth_service, token_getter in self.__auth_service_token_getters.items():
159+
headers[f"{auth_service}_token"] = token_getter()
160+
88161
async with self.__session.post(
89162
self.__url,
90163
json=payload,
164+
headers=headers,
91165
) as resp:
92-
ret = await resp.json()
93-
if "error" in ret:
94-
# TODO: better error
95-
raise Exception(ret["error"])
96-
return ret.get("result", ret)
166+
body = await resp.json()
167+
if resp.status < 200 or resp.status >= 300:
168+
err = body.get("error", f"unexpected status from server: {resp.status}")
169+
raise Exception(err)
170+
return body.get("result", body)
171+
172+
def add_auth_token_getters(
173+
self,
174+
auth_token_getters: Mapping[str, Callable[[], str]],
175+
) -> "ToolboxTool":
176+
"""
177+
Registers an auth token getter function that is used for AuthServices when tools
178+
are invoked.
179+
180+
Args:
181+
auth_token_getters: A mapping of authentication service names to
182+
callables that return the corresponding authentication token.
183+
184+
Returns:
185+
A new ToolboxTool instance with the specified authentication token
186+
getters registered.
187+
"""
188+
189+
# throw an error if the authentication source is already registered
190+
existing_services = self.__auth_service_token_getters.keys()
191+
incoming_services = auth_token_getters.keys()
192+
duplicates = existing_services & incoming_services
193+
if duplicates:
194+
raise ValueError(
195+
f"Authentication source(s) `{', '.join(duplicates)}` already registered in tool `{self.__name__}`."
196+
)
197+
198+
# create a read-only updated value for new_getters
199+
new_getters = types.MappingProxyType(
200+
dict(self.__auth_service_token_getters, **auth_token_getters)
201+
)
202+
# create a read-only updated for params that are still required
203+
new_req_authn_params = types.MappingProxyType(
204+
identify_required_authn_params(
205+
self.__required_authn_params, auth_token_getters.keys()
206+
)
207+
)
208+
209+
return self.__copy(
210+
auth_service_token_getters=new_getters,
211+
required_authn_params=new_req_authn_params,
212+
)
213+
214+
215+
def identify_required_authn_params(
216+
req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str]
217+
) -> dict[str, list[str]]:
218+
"""
219+
Identifies authentication parameters that are still required; or not covered by
220+
the provided `auth_service_names`.
221+
222+
Args:
223+
req_authn_params: A mapping of parameter names to sets of required
224+
authentication services.
225+
auth_service_names: An iterable of authentication service names for which
226+
token getters are available.
227+
228+
Returns:
229+
A new dictionary representing the subset of required authentication
230+
parameters that are not covered by the provided `auth_service_names`.
231+
"""
232+
required_params = {} # params that are still required with provided auth_services
233+
for param, services in req_authn_params.items():
234+
# if we don't have a token_getter for any of the services required by the param,
235+
# the param is still required
236+
required = not any(s in services for s in auth_service_names)
237+
if required:
238+
required_params[param] = services
239+
return required_params

0 commit comments

Comments
 (0)