Skip to content

Commit ffb8e90

Browse files
committed
chore[toolbox-core]: Move util functions to a separate file in toolbox core
Also includes unit test cases for the new file.
1 parent c80fadc commit ffb8e90

File tree

3 files changed

+362
-87
lines changed

3 files changed

+362
-87
lines changed

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 7 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,28 @@
1313
# limitations under the License.
1414

1515

16-
import asyncio
1716
import types
1817
from inspect import Signature
1918
from typing import (
2019
Any,
21-
Awaitable,
2220
Callable,
23-
Iterable,
2421
Mapping,
2522
Optional,
2623
Sequence,
27-
Type,
2824
Union,
29-
cast,
3025
)
3126

3227
from aiohttp import ClientSession
33-
from pydantic import BaseModel, Field, create_model
3428

3529
from toolbox_core.protocol import ParameterSchema
3630

31+
from .utils import (
32+
create_docstring,
33+
identify_required_authn_params,
34+
params_to_pydantic_model,
35+
resolve_value,
36+
)
37+
3738

3839
class ToolboxTool:
3940
"""
@@ -271,84 +272,3 @@ def bind_parameters(
271272
params=new_params,
272273
bound_params=types.MappingProxyType(all_bound_params),
273274
)
274-
275-
276-
def create_docstring(description: str, params: Sequence[ParameterSchema]) -> str:
277-
"""Convert tool description and params into its function docstring"""
278-
docstring = description
279-
if not params:
280-
return docstring
281-
docstring += "\n\nArgs:"
282-
for p in params:
283-
docstring += (
284-
f"\n {p.name} ({p.to_param().annotation.__name__}): {p.description}"
285-
)
286-
return docstring
287-
288-
289-
def identify_required_authn_params(
290-
req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str]
291-
) -> dict[str, list[str]]:
292-
"""
293-
Identifies authentication parameters that are still required; because they
294-
not covered by the provided `auth_service_names`.
295-
296-
Args:
297-
req_authn_params: A mapping of parameter names to sets of required
298-
authentication services.
299-
auth_service_names: An iterable of authentication service names for which
300-
token getters are available.
301-
302-
Returns:
303-
A new dictionary representing the subset of required authentication parameters
304-
that are not covered by the provided `auth_services`.
305-
"""
306-
required_params = {} # params that are still required with provided auth_services
307-
for param, services in req_authn_params.items():
308-
# if we don't have a token_getter for any of the services required by the param,
309-
# the param is still required
310-
required = not any(s in services for s in auth_service_names)
311-
if required:
312-
required_params[param] = services
313-
return required_params
314-
315-
316-
def params_to_pydantic_model(
317-
tool_name: str, params: Sequence[ParameterSchema]
318-
) -> Type[BaseModel]:
319-
"""Converts the given parameters to a Pydantic BaseModel class."""
320-
field_definitions = {}
321-
for field in params:
322-
field_definitions[field.name] = cast(
323-
Any,
324-
(
325-
field.to_param().annotation,
326-
Field(description=field.description),
327-
),
328-
)
329-
return create_model(tool_name, **field_definitions)
330-
331-
332-
async def resolve_value(
333-
source: Union[Callable[[], Awaitable[Any]], Callable[[], Any], Any],
334-
) -> Any:
335-
"""
336-
Asynchronously or synchronously resolves a given source to its value.
337-
338-
If the `source` is a coroutine function, it will be awaited.
339-
If the `source` is a regular callable, it will be called.
340-
Otherwise (if it's not a callable), the `source` itself is returned directly.
341-
342-
Args:
343-
source: The value, a callable returning a value, or a callable
344-
returning an awaitable value.
345-
346-
Returns:
347-
The resolved value.
348-
"""
349-
350-
if asyncio.iscoroutinefunction(source):
351-
return await source()
352-
elif callable(source):
353-
return source()
354-
return source
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import asyncio
17+
from typing import (
18+
Any,
19+
Awaitable,
20+
Callable,
21+
Iterable,
22+
Mapping,
23+
Sequence,
24+
Type,
25+
Union,
26+
cast,
27+
)
28+
29+
from pydantic import BaseModel, Field, create_model
30+
31+
from toolbox_core.protocol import ParameterSchema
32+
33+
34+
def create_docstring(description: str, params: Sequence[ParameterSchema]) -> str:
35+
"""Convert tool description and params into its function docstring"""
36+
docstring = description
37+
if not params:
38+
return docstring
39+
docstring += "\n\nArgs:"
40+
for p in params:
41+
docstring += (
42+
f"\n {p.name} ({p.to_param().annotation.__name__}): {p.description}"
43+
)
44+
return docstring
45+
46+
47+
def identify_required_authn_params(
48+
req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str]
49+
) -> dict[str, list[str]]:
50+
"""
51+
Identifies authentication parameters that are still required; because they
52+
not covered by the provided `auth_service_names`.
53+
54+
Args:
55+
req_authn_params: A mapping of parameter names to sets of required
56+
authentication services.
57+
auth_service_names: An iterable of authentication service names for which
58+
token getters are available.
59+
60+
Returns:
61+
A new dictionary representing the subset of required authentication parameters
62+
that are not covered by the provided `auth_services`.
63+
"""
64+
required_params = {} # params that are still required with provided auth_services
65+
for param, services in req_authn_params.items():
66+
# if we don't have a token_getter for any of the services required by the param,
67+
# the param is still required
68+
required = not any(s in services for s in auth_service_names)
69+
if required:
70+
required_params[param] = services
71+
return required_params
72+
73+
74+
def params_to_pydantic_model(
75+
tool_name: str, params: Sequence[ParameterSchema]
76+
) -> Type[BaseModel]:
77+
"""Converts the given parameters to a Pydantic BaseModel class."""
78+
field_definitions = {}
79+
for field in params:
80+
field_definitions[field.name] = cast(
81+
Any,
82+
(
83+
field.to_param().annotation,
84+
Field(description=field.description),
85+
),
86+
)
87+
return create_model(tool_name, **field_definitions)
88+
89+
90+
async def resolve_value(
91+
source: Union[Callable[[], Awaitable[Any]], Callable[[], Any], Any],
92+
) -> Any:
93+
"""
94+
Asynchronously or synchronously resolves a given source to its value.
95+
96+
If the `source` is a coroutine function, it will be awaited.
97+
If the `source` is a regular callable, it will be called.
98+
Otherwise (if it's not a callable), the `source` itself is returned directly.
99+
100+
Args:
101+
source: The value, a callable returning a value, or a callable
102+
returning an awaitable value.
103+
104+
Returns:
105+
The resolved value.
106+
"""
107+
108+
if asyncio.iscoroutinefunction(source):
109+
return await source()
110+
elif callable(source):
111+
return source()
112+
return source

0 commit comments

Comments
 (0)