Skip to content

Commit 5f1b2b0

Browse files
committed
feat: add authenticated parameters support
1 parent 4f992d8 commit 5f1b2b0

File tree

3 files changed

+279
-25
lines changed

3 files changed

+279
-25
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
from typing import Optional
14+
import types
15+
from typing import Any, Callable, Optional
1616

1717
from aiohttp import ClientSession
1818

1919
from .protocol import ManifestSchema, ToolSchema
20-
from .tool import ToolboxTool
20+
from .tool import ToolboxTool, filter_required_authn_params
2121

2222

2323
class ToolboxClient:
@@ -53,14 +53,34 @@ def __init__(
5353
session = ClientSession()
5454
self.__session = session
5555

56-
def __parse_tool(self, name: str, schema: ToolSchema) -> ToolboxTool:
56+
def __parse_tool(
57+
self,
58+
name: str,
59+
schema: ToolSchema,
60+
auth_token_getters: dict[str, Callable[[], str]],
61+
) -> ToolboxTool:
5762
"""Internal helper to create a callable tool from its schema."""
63+
# sort into authenticated and reg params
64+
params = []
65+
authn_params: dict[str, list[str]] = {}
66+
auth_sources: set[str] = set()
67+
for p in schema.parameters:
68+
if not p.authSources:
69+
params.append(p)
70+
else:
71+
authn_params[p.name] = p.authSources
72+
auth_sources.update(p.authSources)
73+
74+
authn_params = filter_required_authn_params(authn_params, auth_sources)
75+
5876
tool = ToolboxTool(
5977
session=self.__session,
6078
base_url=self.__base_url,
6179
name=name,
6280
desc=schema.description,
63-
params=[p.to_param() for p in schema.parameters],
81+
params=[p.to_param() for p in params],
82+
required_authn_params=types.MappingProxyType(authn_params),
83+
auth_service_token_getters=auth_token_getters,
6484
)
6585
return tool
6686

@@ -99,6 +119,7 @@ async def close(self):
99119
async def load_tool(
100120
self,
101121
name: str,
122+
auth_service_tokens: dict[str, Callable[[], str]] = {},
102123
) -> ToolboxTool:
103124
"""
104125
Asynchronously loads a tool from the server.
@@ -127,13 +148,14 @@ async def load_tool(
127148
if name not in manifest.tools:
128149
# TODO: Better exception
129150
raise Exception(f"Tool '{name}' not found!")
130-
tool = self.__parse_tool(name, manifest.tools[name])
151+
tool = self.__parse_tool(name, manifest.tools[name], auth_service_tokens)
131152

132153
return tool
133154

134155
async def load_toolset(
135156
self,
136157
name: str,
158+
auth_token_getters: dict[str, Callable[[], str]] = {},
137159
) -> list[ToolboxTool]:
138160
"""
139161
Asynchronously fetches a toolset and loads all tools defined within it.
@@ -152,5 +174,8 @@ async def load_toolset(
152174
manifest: ManifestSchema = ManifestSchema(**json)
153175

154176
# parse each tools name and schema into a list of ToolboxTools
155-
tools = [self.__parse_tool(n, s) for n, s in manifest.tools.items()]
177+
tools = [
178+
self.__parse_tool(n, s, auth_token_getters)
179+
for n, s in manifest.tools.items()
180+
]
156181
return tools

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 146 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
# limitations under the License.
1414

1515

16+
import types
17+
from collections import defaultdict
1618
from inspect import Parameter, Signature
17-
from typing import Any
19+
from typing import Any, Callable, DefaultDict, Iterable, Mapping, Optional, Sequence
1820

1921
from aiohttp import ClientSession
22+
from pytest import Session
2023

2124

2225
class ToolboxTool:
@@ -32,20 +35,19 @@ class ToolboxTool:
3235
and `inspect` work as expected.
3336
"""
3437

35-
__url: str
36-
__session: ClientSession
37-
__signature__: Signature
38-
3938
def __init__(
4039
self,
4140
session: ClientSession,
4241
base_url: str,
4342
name: str,
4443
desc: str,
45-
params: list[Parameter],
44+
params: Sequence[Parameter],
45+
required_authn_params: Mapping[str, list[str]],
46+
auth_service_token_getters: Mapping[str, Callable[[], str]],
4647
):
4748
"""
48-
Initializes a callable that will trigger the tool invocation through the Toolbox server.
49+
Initializes a callable that will trigger the tool invocation through the
50+
Toolbox server.
4951
5052
Args:
5153
session: The `aiohttp.ClientSession` used for making API requests.
@@ -54,19 +56,69 @@ def __init__(
5456
desc: The description of the remote tool (used as its docstring).
5557
params: A list of `inspect.Parameter` objects defining the tool's
5658
arguments and their types/defaults.
59+
required_authn_params: A dict of required authenticated parameters that
60+
need a auth_service_token_getter set for them yet.
61+
auth_service_tokens: A dict of authService -> token (or callables that
62+
produce a token)
5763
"""
5864

5965
# used to invoke the toolbox API
60-
self.__session = session
66+
self.__session: ClientSession = session
67+
self.__base_url: str = base_url
6168
self.__url = f"{base_url}/api/tool/{name}/invoke"
6269

63-
# the following properties are set to help anyone that might inspect it determine
70+
self.__desc = desc
71+
self.__params = params
72+
73+
# the following properties are set to help anyone that might inspect it determine usage
6474
self.__name__ = name
6575
self.__doc__ = desc
6676
self.__signature__ = Signature(parameters=params, return_annotation=str)
6777
self.__annotations__ = {p.name: p.annotation for p in params}
6878
# TODO: self.__qualname__ ??
6979

80+
# map of parameter name to auth service required by it
81+
self.__required_authn_params = required_authn_params
82+
# map of authService -> token_getter
83+
self.__auth_service_token_getters = auth_service_token_getters
84+
85+
def __copy(
86+
self,
87+
session: Optional[ClientSession] = None,
88+
base_url: Optional[str] = None,
89+
name: Optional[str] = None,
90+
desc: Optional[str] = None,
91+
params: Optional[list[Parameter]] = None,
92+
required_authn_params: Optional[Mapping[str, list[str]]] = None,
93+
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
94+
):
95+
"""
96+
Creates a copy of the ToolboxTool, overriding specific fields.
97+
98+
Args:
99+
session: The `aiohttp.ClientSession` used for making API requests.
100+
base_url: The base URL of the Toolbox server API.
101+
name: The name of the remote tool.
102+
desc: The description of the remote tool (used as its docstring).
103+
params: A list of `inspect.Parameter` objects defining the tool's
104+
arguments and their types/defaults.
105+
required_authn_params: A dict of required authenticated parameters that need
106+
a auth_service_token_getter set for them yet.
107+
auth_service_token_getters: A dict of authService -> token (or callables
108+
that produce a token)
109+
110+
"""
111+
return ToolboxTool(
112+
session=session or self.__session,
113+
base_url=base_url or self.__base_url,
114+
name=name or self.__name__,
115+
desc=desc or self.__desc,
116+
params=params or self.__params,
117+
required_authn_params=required_authn_params or self.__required_authn_params,
118+
auth_service_token_getters=auth_service_token_getters
119+
or self.__auth_service_token_getters,
120+
)
121+
70122
async def __call__(self, *args: Any, **kwargs: Any) -> str:
71123
"""
72124
Asynchronously calls the remote tool with the provided arguments.
@@ -81,16 +133,96 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
81133
Returns:
82134
The string result returned by the remote tool execution.
83135
"""
136+
137+
# check if any auth services need to be specified yet
138+
if len(self.__required_authn_params) > 0:
139+
req_auth_services = set(l for l in self.__required_authn_params.keys())
140+
raise Exception(
141+
f"One of more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}"
142+
)
143+
144+
# validate inputs to this call using the signature
84145
all_args = self.__signature__.bind(*args, **kwargs)
85146
all_args.apply_defaults() # Include default values if not provided
86147
payload = all_args.arguments
87148

149+
# create headers for auth services
150+
headers = {}
151+
for auth_service, token_getter in self.__auth_service_token_getters.items():
152+
headers[f"{auth_service}_token"] = token_getter()
153+
88154
async with self.__session.post(
89155
self.__url,
90156
json=payload,
157+
headers=headers,
91158
) as resp:
92-
ret = await resp.json()
93-
if "error" in ret:
94-
# TODO: better error
95-
raise Exception(ret["error"])
96-
return ret.get("result", ret)
159+
body = await resp.json()
160+
if resp.status < 200 or resp.status >= 300:
161+
err = body.get("error", f"unexpected status from server: {resp.status}")
162+
raise Exception(err)
163+
return body.get("result", body)
164+
165+
def add_auth_token_getters(
166+
self,
167+
auth_token_getters: Mapping[str, Callable[[], str]],
168+
) -> "ToolboxTool":
169+
"""
170+
Registers a auth token getter function that is used for AuthServices when tools
171+
are invoked.
172+
173+
Args:
174+
auth_token_getters: A mapping of authentication service names to
175+
callables that return the corresponding authentication token.
176+
177+
Returns:
178+
A new ToolboxTool instance with the specified authentication token
179+
getters registered.
180+
"""
181+
182+
# throw an error if the authentication source is already registered
183+
dupes = auth_token_getters.keys() & self.__auth_service_token_getters.keys()
184+
if dupes:
185+
raise ValueError(
186+
f"Authentication source(s) `{', '.join(dupes)}` already registered in tool `{self.__name__}`."
187+
)
188+
189+
# create a read-only updated value for new_getters
190+
new_getters = types.MappingProxyType(
191+
dict(self.__auth_service_token_getters, **auth_token_getters)
192+
)
193+
# create a read-only updated for params that are still required
194+
new_req_authn_params = types.MappingProxyType(
195+
filter_required_authn_params(
196+
self.__required_authn_params, auth_token_getters.keys()
197+
)
198+
)
199+
200+
return self.__copy(
201+
auth_service_token_getters=new_getters,
202+
required_authn_params=new_req_authn_params,
203+
)
204+
205+
206+
def filter_required_authn_params(
207+
req_authn_params: Mapping[str, list[str]], auth_services: Iterable[str]
208+
) -> dict[str, list[str]]:
209+
"""
210+
Utility function for reducing 'req_authn_params' to a subset of parameters that aren't supplied by a least one service in auth_services.
211+
212+
Args:
213+
req_authn_params: A mapping of parameter names to sets of required
214+
authentication services.
215+
auth_services: An iterable of authentication service names for which
216+
token getters are available.
217+
218+
Returns:
219+
A new dictionary representing the subset of required authentication
220+
parameters that are not covered by the provided `auth_services`.
221+
"""
222+
req_params = {}
223+
for param, services in req_authn_params.items():
224+
# if we don't have a token_getter for any of the services required by the param, the param is still required
225+
required = not any(s in services for s in auth_services)
226+
if required:
227+
req_params[param] = services
228+
return req_params

0 commit comments

Comments
 (0)