Skip to content

Commit 510ee8d

Browse files
committed
feat(toolbox-core): Add support for optional parameters
1 parent ac2bcf5 commit 510ee8d

File tree

3 files changed

+37
-10
lines changed

3 files changed

+37
-10
lines changed

packages/toolbox-core/src/toolbox_core/protocol.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,31 +25,43 @@ class ParameterSchema(BaseModel):
2525

2626
name: str
2727
type: str
28+
required: bool = True
2829
description: str
2930
authSources: Optional[list[str]] = None
3031
items: Optional["ParameterSchema"] = None
3132

3233
def __get_type(self) -> Type:
34+
base_type: Type
3335
if self.type == "string":
34-
return str
36+
base_type = str
3537
elif self.type == "integer":
36-
return int
38+
base_type = int
3739
elif self.type == "float":
38-
return float
40+
base_type = float
3941
elif self.type == "boolean":
40-
return bool
42+
base_type = bool
4143
elif self.type == "array":
4244
if self.items is None:
4345
raise Exception("Unexpected value: type is 'list' but items is None")
44-
return list[self.items.__get_type()] # type: ignore
46+
base_type = list[self.items.__get_type()]
47+
else:
48+
raise ValueError(f"Unsupported schema type: {self.type}")
4549

46-
raise ValueError(f"Unsupported schema type: {self.type}")
50+
if not self.required:
51+
return Optional[base_type]
52+
53+
return base_type
4754

4855
def to_param(self) -> Parameter:
56+
default = Parameter.empty
57+
if not self.required:
58+
default = None
59+
4960
return Parameter(
5061
self.name,
5162
Parameter.POSITIONAL_OR_KEYWORD,
5263
annotation=self.__get_type(),
64+
default=default,
5365
)
5466

5567

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import copy
16-
from inspect import Signature
16+
from inspect import Signature, Parameter
1717
from types import MappingProxyType
1818
from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union
1919
from warnings import warn
@@ -89,7 +89,11 @@ def __init__(
8989
self.__params = params
9090
self.__pydantic_model = params_to_pydantic_model(name, self.__params)
9191

92-
inspect_type_params = [param.to_param() for param in self.__params]
92+
# Sort parameters to ensure required ones (required=True) come before
93+
# optional ones (required=False). This prevents the "non-default argument
94+
# follows default argument" error when creating the signature.
95+
sorted_params = sorted(self.__params, key=lambda p: p.required, reverse=True)
96+
inspect_type_params = [param.to_param() for param in sorted_params]
9397

9498
# the following properties are set to help anyone that might inspect it determine usage
9599
self.__name__ = name
@@ -268,7 +272,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
268272

269273
# validate inputs to this call using the signature
270274
all_args = self.__signature__.bind(*args, **kwargs)
271-
all_args.apply_defaults() # Include default values if not provided
275+
276+
# The payload will only contain arguments explicitly provided by the user.
277+
# Optional arguments not provided by the user will not be in the payload.
272278
payload = all_args.arguments
273279

274280
# Perform argument type validations using pydantic

packages/toolbox-core/src/toolbox_core/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,20 @@ def params_to_pydantic_model(
111111
"""Converts the given parameters to a Pydantic BaseModel class."""
112112
field_definitions = {}
113113
for field in params:
114+
115+
# Determine the default value based on the 'required' flag.
116+
# '...' (Ellipsis) signifies a required field in Pydantic.
117+
# 'None' makes the field optional with a default value of None.
118+
default_value = ... if field.required else None
119+
114120
field_definitions[field.name] = cast(
115121
Any,
116122
(
117123
field.to_param().annotation,
118-
Field(description=field.description),
124+
Field(
125+
description=field.description,
126+
default=default_value,
127+
),
119128
),
120129
)
121130
return create_model(tool_name, **field_definitions)

0 commit comments

Comments
 (0)