|
17 | 17 | import copy
|
18 | 18 | import types
|
19 | 19 | from inspect import Signature
|
| 20 | +from pydantic import BaseModel, Field, create_model |
| 21 | + |
20 | 22 | from typing import (
|
21 | 23 | Any,
|
22 | 24 | Callable,
|
23 | 25 | Iterable,
|
24 | 26 | Mapping,
|
25 | 27 | Optional,
|
26 | 28 | Union,
|
| 29 | + cast, |
| 30 | + Type |
27 | 31 | )
|
28 | 32 |
|
29 | 33 | from aiohttp import ClientSession
|
@@ -176,6 +180,10 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
|
176 | 180 | all_args.apply_defaults() # Include default values if not provided
|
177 | 181 | payload = all_args.arguments
|
178 | 182 |
|
| 183 | + # Perform argument type checks using pydantic |
| 184 | + model = _schema_to_model(self.__name__, self.__schema) |
| 185 | + model.model_validate(payload) |
| 186 | + |
179 | 187 | # apply bounded parameters
|
180 | 188 | for param, value in self.__bound_parameters.items():
|
181 | 189 | if asyncio.iscoroutinefunction(value):
|
@@ -306,3 +314,17 @@ def identify_required_authn_params(
|
306 | 314 | if required:
|
307 | 315 | required_params[param] = services
|
308 | 316 | return required_params
|
| 317 | + |
| 318 | +def _schema_to_model(model_name: str, tool_schema: ToolSchema) -> Type[BaseModel]: |
| 319 | + """Converts the given manifest schema to a Pydantic BaseModel class.""" |
| 320 | + field_definitions = {} |
| 321 | + for field in tool_schema.parameters: |
| 322 | + field_definitions[field.name] = cast( |
| 323 | + Any, |
| 324 | + ( |
| 325 | + field.to_param().annotation, |
| 326 | + Field(description=field.description), |
| 327 | + ), |
| 328 | + ) |
| 329 | + |
| 330 | + return create_model(model_name, **field_definitions) |
0 commit comments