Skip to content

Commit c7b76da

Browse files
committed
add basic pydantic type checking
1 parent e60ab95 commit c7b76da

File tree

1 file changed

+22
-0
lines changed
  • packages/toolbox-core/src/toolbox_core

1 file changed

+22
-0
lines changed

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@
1717
import copy
1818
import types
1919
from inspect import Signature
20+
from pydantic import BaseModel, Field, create_model
21+
2022
from typing import (
2123
Any,
2224
Callable,
2325
Iterable,
2426
Mapping,
2527
Optional,
2628
Union,
29+
cast,
30+
Type
2731
)
2832

2933
from aiohttp import ClientSession
@@ -176,6 +180,10 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
176180
all_args.apply_defaults() # Include default values if not provided
177181
payload = all_args.arguments
178182

183+
# Perform argument type checks using pydantic
184+
model = _schema_to_model(self.__name__, self.__schema)
185+
model.model_validate(payload)
186+
179187
# apply bounded parameters
180188
for param, value in self.__bound_parameters.items():
181189
if asyncio.iscoroutinefunction(value):
@@ -306,3 +314,17 @@ def identify_required_authn_params(
306314
if required:
307315
required_params[param] = services
308316
return required_params
317+
318+
def _schema_to_model(model_name: str, tool_schema: ToolSchema) -> Type[BaseModel]:
319+
"""Converts the given manifest schema to a Pydantic BaseModel class."""
320+
field_definitions = {}
321+
for field in tool_schema.parameters:
322+
field_definitions[field.name] = cast(
323+
Any,
324+
(
325+
field.to_param().annotation,
326+
Field(description=field.description),
327+
),
328+
)
329+
330+
return create_model(model_name, **field_definitions)

0 commit comments

Comments
 (0)