Skip to content

Commit 8f5e83b

Browse files
authored
Merge branch 'main' into anubhav-qualname
2 parents 710ad6a + 5d62138 commit 8f5e83b

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,13 @@
2323
Mapping,
2424
Optional,
2525
Sequence,
26+
Type,
2627
Union,
28+
cast,
2729
)
2830

2931
from aiohttp import ClientSession
32+
from pydantic import BaseModel, Field, create_model
3033

3134
from toolbox_core.protocol import ParameterSchema
3235

@@ -78,6 +81,8 @@ def __init__(
7881
self.__url = f"{base_url}/api/tool/{name}/invoke"
7982
self.__description = description
8083
self.__params = params
84+
self.__pydantic_model = params_to_pydantic_model(name, self.__params)
85+
8186
inspect_type_params = [param.to_param() for param in self.__params]
8287

8388
# the following properties are set to help anyone that might inspect it determine usage
@@ -86,6 +91,7 @@ def __init__(
8691
self.__signature__ = Signature(
8792
parameters=inspect_type_params, return_annotation=str
8893
)
94+
8995
self.__annotations__ = {p.name: p.annotation for p in inspect_type_params}
9096
self.__qualname__ = f"{self.__class__.__qualname__}.{self.__name__}"
9197

@@ -170,6 +176,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
170176
all_args.apply_defaults() # Include default values if not provided
171177
payload = all_args.arguments
172178

179+
# Perform argument type validations using pydantic
180+
self.__pydantic_model.model_validate(payload)
181+
173182
# apply bounded parameters
174183
for param, value in self.__bound_parameters.items():
175184
if asyncio.iscoroutinefunction(value):
@@ -305,3 +314,19 @@ def identify_required_authn_params(
305314
if required:
306315
required_params[param] = services
307316
return required_params
317+
318+
319+
def params_to_pydantic_model(
320+
tool_name: str, params: Sequence[ParameterSchema]
321+
) -> Type[BaseModel]:
322+
"""Converts the given parameters to a Pydantic BaseModel class."""
323+
field_definitions = {}
324+
for field in params:
325+
field_definitions[field.name] = cast(
326+
Any,
327+
(
328+
field.to_param().annotation,
329+
Field(description=field.description),
330+
),
331+
)
332+
return create_model(tool_name, **field_definitions)

packages/toolbox-core/tests/test_e2e.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import pytest
1515
import pytest_asyncio
16+
from pydantic import ValidationError
1617

1718
from toolbox_core.client import ToolboxClient
1819
from toolbox_core.tool import ToolboxTool
@@ -77,8 +78,8 @@ async def test_run_tool_missing_params(self, get_n_rows_tool: ToolboxTool):
7778
async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool):
7879
"""Invoke a tool with wrong param type."""
7980
with pytest.raises(
80-
Exception,
81-
match='provided parameters were invalid: unable to parse value for "num_rows": .* not type "string"',
81+
ValidationError,
82+
match=r"num_rows\s+Input should be a valid string\s+\[type=string_type,\s+input_value=2,\s+input_type=int\]",
8283
):
8384
await get_n_rows_tool(num_rows=2)
8485

0 commit comments

Comments
 (0)