Skip to content

Commit 3c8360c

Browse files
authored
feat(toolbox-core): Add support for optional parameters (#290)
* feat(toolbox-core): Add support for optional parameters * chore: Delint * chore: Remove unnecessary import * chore: Delint * chore: Delint * fix: Fix unit tests * chore: Add unit tests for optional parameters * chore: Delint * chore: Add E2E tests for optional params * chore: Delint * fix: Fix e2e tests * chore: Delint * chore: Update e2e tests * chore: Improve e2e tests * chore: Delint * chore: Improve e2e tests for optional params around null values * chore: Delint * chore: Make separation of required/optional params more efficient * chore: Delint * chore: optimize chaining required and optional params * chore: Fix integration tests * chore: Add additional integration tests * chore: Delint * chore: Fix integration tests * fix: Ignore null values from Toolbox core to fix server error of type validation * chore: Fix e2e tests * chore: Delint * fix: Fix issue causing attribute error in python 3.9 * chore: Use upgraded toolbox server that supports optional params * chore: Fix e2e tests for other packages
1 parent ac2bcf5 commit 3c8360c

File tree

11 files changed

+282
-20
lines changed

11 files changed

+282
-20
lines changed

packages/toolbox-core/integration.cloudbuild.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,4 @@ options:
4343
logging: CLOUD_LOGGING_ONLY
4444
substitutions:
4545
_VERSION: '3.13'
46-
_TOOLBOX_VERSION: '0.7.0'
46+
_TOOLBOX_VERSION: '0.8.0'

packages/toolbox-core/src/toolbox_core/protocol.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,31 +25,39 @@ class ParameterSchema(BaseModel):
2525

2626
name: str
2727
type: str
28+
required: bool = True
2829
description: str
2930
authSources: Optional[list[str]] = None
3031
items: Optional["ParameterSchema"] = None
3132

3233
def __get_type(self) -> Type:
34+
base_type: Type
3335
if self.type == "string":
34-
return str
36+
base_type = str
3537
elif self.type == "integer":
36-
return int
38+
base_type = int
3739
elif self.type == "float":
38-
return float
40+
base_type = float
3941
elif self.type == "boolean":
40-
return bool
42+
base_type = bool
4143
elif self.type == "array":
4244
if self.items is None:
4345
raise Exception("Unexpected value: type is 'list' but items is None")
44-
return list[self.items.__get_type()] # type: ignore
46+
base_type = list[self.items.__get_type()] # type: ignore
47+
else:
48+
raise ValueError(f"Unsupported schema type: {self.type}")
4549

46-
raise ValueError(f"Unsupported schema type: {self.type}")
50+
if not self.required:
51+
return Optional[base_type] # type: ignore
52+
53+
return base_type
4754

4855
def to_param(self) -> Parameter:
4956
return Parameter(
5057
self.name,
5158
Parameter.POSITIONAL_OR_KEYWORD,
5259
annotation=self.__get_type(),
60+
default=Parameter.empty if self.required else None,
5361
)
5462

5563

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515
import copy
16+
import itertools
17+
from collections import OrderedDict
1618
from inspect import Signature
1719
from types import MappingProxyType
1820
from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union
@@ -89,7 +91,13 @@ def __init__(
8991
self.__params = params
9092
self.__pydantic_model = params_to_pydantic_model(name, self.__params)
9193

92-
inspect_type_params = [param.to_param() for param in self.__params]
94+
# Separate parameters into required (no default) and optional (with
95+
# default) to prevent the "non-default argument follows default
96+
# argument" error when creating the function signature.
97+
required_params = (p for p in self.__params if p.required)
98+
optional_params = (p for p in self.__params if not p.required)
99+
ordered_params = itertools.chain(required_params, optional_params)
100+
inspect_type_params = [param.to_param() for param in ordered_params]
93101

94102
# the following properties are set to help anyone that might inspect it determine usage
95103
self.__name__ = name
@@ -268,7 +276,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
268276

269277
# validate inputs to this call using the signature
270278
all_args = self.__signature__.bind(*args, **kwargs)
271-
all_args.apply_defaults() # Include default values if not provided
279+
280+
# The payload will only contain arguments explicitly provided by the user.
281+
# Optional arguments not provided by the user will not be in the payload.
272282
payload = all_args.arguments
273283

274284
# Perform argument type validations using pydantic
@@ -278,6 +288,11 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
278288
for param, value in self.__bound_parameters.items():
279289
payload[param] = await resolve_value(value)
280290

291+
# Remove None values to prevent server-side type errors. The Toolbox
292+
# server requires specific types for each parameter and will raise an
293+
# error if it receives a None value, which it cannot convert.
294+
payload = OrderedDict({k: v for k, v in payload.items() if v is not None})
295+
281296
# create headers for auth services
282297
headers = {}
283298
for client_header_name, client_header_val in self.__client_headers.items():

packages/toolbox-core/src/toolbox_core/utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,8 @@ def create_func_docstring(description: str, params: Sequence[ParameterSchema]) -
3838
return docstring
3939
docstring += "\n\nArgs:"
4040
for p in params:
41-
docstring += (
42-
f"\n {p.name} ({p.to_param().annotation.__name__}): {p.description}"
43-
)
41+
annotation = p.to_param().annotation
42+
docstring += f"\n {p.name} ({getattr(annotation, '__name__', str(annotation))}): {p.description}"
4443
return docstring
4544

4645

@@ -111,11 +110,20 @@ def params_to_pydantic_model(
111110
"""Converts the given parameters to a Pydantic BaseModel class."""
112111
field_definitions = {}
113112
for field in params:
113+
114+
# Determine the default value based on the 'required' flag.
115+
# '...' (Ellipsis) signifies a required field in Pydantic.
116+
# 'None' makes the field optional with a default value of None.
117+
default_value = ... if field.required else None
118+
114119
field_definitions[field.name] = cast(
115120
Any,
116121
(
117122
field.to_param().annotation,
118-
Field(description=field.description),
123+
Field(
124+
description=field.description,
125+
default=default_value,
126+
),
119127
),
120128
)
121129
return create_model(tool_name, **field_definitions)

packages/toolbox-core/tests/test_e2e.py

Lines changed: 163 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from inspect import Parameter, signature
16+
from typing import Optional
17+
1418
import pytest
1519
import pytest_asyncio
1620
from pydantic import ValidationError
@@ -64,14 +68,15 @@ async def test_load_toolset_specific(
6468
async def test_load_toolset_default(self, toolbox: ToolboxClient):
6569
"""Load the default toolset, i.e. all tools."""
6670
toolset = await toolbox.load_toolset()
67-
assert len(toolset) == 5
71+
assert len(toolset) == 6
6872
tool_names = {tool.__name__ for tool in toolset}
6973
expected_tools = [
7074
"get-row-by-content-auth",
7175
"get-row-by-email-auth",
7276
"get-row-by-id-auth",
7377
"get-row-by-id",
7478
"get-n-rows",
79+
"search-rows",
7580
]
7681
assert tool_names == set(expected_tools)
7782

@@ -217,3 +222,160 @@ async def test_run_tool_param_auth_no_field(
217222
match="no field named row_data in claims",
218223
):
219224
await tool()
225+
226+
227+
@pytest.mark.asyncio
228+
@pytest.mark.usefixtures("toolbox_server")
229+
class TestOptionalParams:
230+
"""
231+
End-to-end tests for tools with optional parameters.
232+
"""
233+
234+
async def test_tool_signature_is_correct(self, toolbox: ToolboxClient):
235+
"""Verify the client correctly constructs the signature for a tool with optional params."""
236+
tool = await toolbox.load_tool("search-rows")
237+
sig = signature(tool)
238+
239+
assert "email" in sig.parameters
240+
assert "data" in sig.parameters
241+
assert "id" in sig.parameters
242+
243+
# The required parameter should have no default
244+
assert sig.parameters["email"].default is Parameter.empty
245+
assert sig.parameters["email"].annotation is str
246+
247+
# The optional parameter should have a default of None
248+
assert sig.parameters["data"].default is None
249+
assert sig.parameters["data"].annotation is Optional[str]
250+
251+
# The optional parameter should have a default of None
252+
assert sig.parameters["id"].default is None
253+
assert sig.parameters["id"].annotation is Optional[int]
254+
255+
async def test_run_tool_with_optional_params_omitted(self, toolbox: ToolboxClient):
256+
"""Invoke a tool providing only the required parameter."""
257+
tool = await toolbox.load_tool("search-rows")
258+
259+
response = await tool(email="[email protected]")
260+
assert isinstance(response, str)
261+
assert '"email":"[email protected]"' in response
262+
assert "row1" not in response
263+
assert "row2" in response
264+
assert "row3" not in response
265+
assert "row4" not in response
266+
assert "row5" not in response
267+
assert "row6" not in response
268+
269+
async def test_run_tool_with_optional_data_provided(self, toolbox: ToolboxClient):
270+
"""Invoke a tool providing both required and optional parameters."""
271+
tool = await toolbox.load_tool("search-rows")
272+
273+
response = await tool(email="[email protected]", data="row3")
274+
assert isinstance(response, str)
275+
assert '"email":"[email protected]"' in response
276+
assert "row1" not in response
277+
assert "row2" not in response
278+
assert "row3" in response
279+
assert "row4" not in response
280+
assert "row5" not in response
281+
assert "row6" not in response
282+
283+
async def test_run_tool_with_optional_data_null(self, toolbox: ToolboxClient):
284+
"""Invoke a tool providing both required and optional parameters."""
285+
tool = await toolbox.load_tool("search-rows")
286+
287+
response = await tool(email="[email protected]", data=None)
288+
assert isinstance(response, str)
289+
assert '"email":"[email protected]"' in response
290+
assert "row1" not in response
291+
assert "row2" in response
292+
assert "row3" not in response
293+
assert "row4" not in response
294+
assert "row5" not in response
295+
assert "row6" not in response
296+
297+
async def test_run_tool_with_optional_id_provided(self, toolbox: ToolboxClient):
298+
"""Invoke a tool providing both required and optional parameters."""
299+
tool = await toolbox.load_tool("search-rows")
300+
301+
response = await tool(email="[email protected]", id=1)
302+
assert isinstance(response, str)
303+
assert response == "null"
304+
305+
async def test_run_tool_with_optional_id_null(self, toolbox: ToolboxClient):
306+
"""Invoke a tool providing both required and optional parameters."""
307+
tool = await toolbox.load_tool("search-rows")
308+
309+
response = await tool(email="[email protected]", id=None)
310+
assert isinstance(response, str)
311+
assert '"email":"[email protected]"' in response
312+
assert "row1" not in response
313+
assert "row2" in response
314+
assert "row3" not in response
315+
assert "row4" not in response
316+
assert "row5" not in response
317+
assert "row6" not in response
318+
319+
async def test_run_tool_with_missing_required_param(self, toolbox: ToolboxClient):
320+
"""Invoke a tool without its required parameter."""
321+
tool = await toolbox.load_tool("search-rows")
322+
with pytest.raises(TypeError, match="missing a required argument: 'email'"):
323+
await tool(id=5, data="row5")
324+
325+
async def test_run_tool_with_required_param_null(self, toolbox: ToolboxClient):
326+
"""Invoke a tool without its required parameter."""
327+
tool = await toolbox.load_tool("search-rows")
328+
with pytest.raises(ValidationError, match="email"):
329+
await tool(email=None, id=5, data="row5")
330+
331+
async def test_run_tool_with_all_default_params(self, toolbox: ToolboxClient):
332+
"""Invoke a tool providing all parameters."""
333+
tool = await toolbox.load_tool("search-rows")
334+
335+
response = await tool(email="[email protected]", id=0, data="row2")
336+
assert isinstance(response, str)
337+
assert '"email":"[email protected]"' in response
338+
assert "row1" not in response
339+
assert "row2" in response
340+
assert "row3" not in response
341+
assert "row4" not in response
342+
assert "row5" not in response
343+
assert "row6" not in response
344+
345+
async def test_run_tool_with_all_valid_params(self, toolbox: ToolboxClient):
346+
"""Invoke a tool providing all parameters."""
347+
tool = await toolbox.load_tool("search-rows")
348+
349+
response = await tool(email="[email protected]", id=3, data="row3")
350+
assert isinstance(response, str)
351+
assert '"email":"[email protected]"' in response
352+
assert "row1" not in response
353+
assert "row2" not in response
354+
assert "row3" in response
355+
assert "row4" not in response
356+
assert "row5" not in response
357+
assert "row6" not in response
358+
359+
async def test_run_tool_with_different_email(self, toolbox: ToolboxClient):
360+
"""Invoke a tool providing all parameters but with a different email."""
361+
tool = await toolbox.load_tool("search-rows")
362+
363+
response = await tool(email="[email protected]", id=3, data="row3")
364+
assert isinstance(response, str)
365+
assert response == "null"
366+
367+
async def test_run_tool_with_different_data(self, toolbox: ToolboxClient):
368+
"""Invoke a tool providing all parameters but with a different data."""
369+
tool = await toolbox.load_tool("search-rows")
370+
371+
response = await tool(email="[email protected]", id=3, data="row4")
372+
assert isinstance(response, str)
373+
assert response == "null"
374+
375+
async def test_run_tool_with_different_id(self, toolbox: ToolboxClient):
376+
"""Invoke a tool providing all parameters but with a different data."""
377+
tool = await toolbox.load_tool("search-rows")
378+
379+
response = await tool(email="[email protected]", id=4, data="row3")
380+
assert isinstance(response, str)
381+
assert response == "null"

packages/toolbox-core/tests/test_protocol.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
from inspect import Parameter
17+
from typing import Optional
1718

1819
import pytest
1920

@@ -106,3 +107,66 @@ def test_parameter_schema_unsupported_type_error():
106107

107108
with pytest.raises(ValueError, match=expected_error_msg):
108109
schema.to_param()
110+
111+
112+
def test_parameter_schema_string_optional():
113+
"""Tests an optional ParameterSchema with type 'string'."""
114+
schema = ParameterSchema(
115+
name="nickname",
116+
type="string",
117+
description="An optional nickname",
118+
required=False,
119+
)
120+
expected_type = Optional[str]
121+
122+
# Test __get_type()
123+
assert schema._ParameterSchema__get_type() == expected_type
124+
125+
# Test to_param()
126+
param = schema.to_param()
127+
assert isinstance(param, Parameter)
128+
assert param.name == "nickname"
129+
assert param.annotation == expected_type
130+
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD
131+
assert param.default is None
132+
133+
134+
def test_parameter_schema_required_by_default():
135+
"""Tests that a parameter is required by default."""
136+
# 'required' is not specified, so it should default to True.
137+
schema = ParameterSchema(name="id", type="integer", description="A required ID")
138+
expected_type = int
139+
140+
# Test __get_type()
141+
assert schema._ParameterSchema__get_type() == expected_type
142+
143+
# Test to_param()
144+
param = schema.to_param()
145+
assert isinstance(param, Parameter)
146+
assert param.name == "id"
147+
assert param.annotation == expected_type
148+
assert param.default == Parameter.empty
149+
150+
151+
def test_parameter_schema_array_optional():
152+
"""Tests an optional ParameterSchema with type 'array'."""
153+
item_schema = ParameterSchema(name="", type="integer", description="")
154+
schema = ParameterSchema(
155+
name="optional_scores",
156+
type="array",
157+
description="An optional list of scores",
158+
items=item_schema,
159+
required=False,
160+
)
161+
expected_type = Optional[list[int]]
162+
163+
# Test __get_type()
164+
assert schema._ParameterSchema__get_type() == expected_type
165+
166+
# Test to_param()
167+
param = schema.to_param()
168+
assert isinstance(param, Parameter)
169+
assert param.name == "optional_scores"
170+
assert param.annotation == expected_type
171+
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD
172+
assert param.default is None

0 commit comments

Comments
 (0)