diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 15374b02..3c6b633c 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -23,10 +23,13 @@ Mapping, Optional, Sequence, + Type, Union, + cast, ) from aiohttp import ClientSession +from pydantic import BaseModel, Field, create_model from toolbox_core.protocol import ParameterSchema @@ -78,6 +81,8 @@ def __init__( self.__url = f"{base_url}/api/tool/{name}/invoke" self.__description = description self.__params = params + self.__pydantic_model = params_to_pydantic_model(name, self.__params) + inspect_type_params = [param.to_param() for param in self.__params] # the following properties are set to help anyone that might inspect it determine usage @@ -86,6 +91,7 @@ def __init__( self.__signature__ = Signature( parameters=inspect_type_params, return_annotation=str ) + self.__annotations__ = {p.name: p.annotation for p in inspect_type_params} # TODO: self.__qualname__ ?? @@ -170,6 +176,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: all_args.apply_defaults() # Include default values if not provided payload = all_args.arguments + # Perform argument type validations using pydantic + self.__pydantic_model.model_validate(payload) + # apply bounded parameters for param, value in self.__bound_parameters.items(): if asyncio.iscoroutinefunction(value): @@ -305,3 +314,19 @@ def identify_required_authn_params( if required: required_params[param] = services return required_params + + +def params_to_pydantic_model( + tool_name: str, params: Sequence[ParameterSchema] +) -> Type[BaseModel]: + """Converts the given parameters to a Pydantic BaseModel class.""" + field_definitions = {} + for field in params: + field_definitions[field.name] = cast( + Any, + ( + field.to_param().annotation, + Field(description=field.description), + ), + ) + return create_model(tool_name, **field_definitions) diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 43f4d0f8..5e744ae1 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest import pytest_asyncio +from pydantic import ValidationError from toolbox_core.client import ToolboxClient from toolbox_core.tool import ToolboxTool @@ -77,8 +78,8 @@ async def test_run_tool_missing_params(self, get_n_rows_tool: ToolboxTool): async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool): """Invoke a tool with wrong param type.""" with pytest.raises( - Exception, - match='provided parameters were invalid: unable to parse value for "num_rows": .* not type "string"', + ValidationError, + match=r"num_rows\s+Input should be a valid string\s+\[type=string_type,\s+input_value=2,\s+input_type=int\]", ): await get_n_rows_tool(num_rows=2)