33from abc import ABC , abstractmethod
44import asyncio
55from collections .abc import Awaitable , Callable
6- from inspect import signature
6+ from inspect import Parameter , signature
7+ import json
78from typing import (
89 Any ,
910 Generic ,
1011 ParamSpec ,
1112 TypeVar ,
12- cast ,
13- get_args ,
14- get_origin ,
1513 overload ,
1614)
1715
2725
2826from crewai .tools .structured_tool import CrewStructuredTool
2927from crewai .utilities .printer import Printer
28+ from crewai .utilities .pydantic_schema_utils import generate_model_description
3029
3130
3231_printer = Printer ()
@@ -103,20 +102,40 @@ def _default_args_schema(
103102 if v != cls ._ArgsSchemaPlaceholder :
104103 return v
105104
106- return cast (
107- type [PydanticBaseModel ],
108- type (
109- f"{ cls .__name__ } Schema" ,
110- (PydanticBaseModel ,),
111- {
112- "__annotations__" : {
113- k : v
114- for k , v in cls ._run .__annotations__ .items ()
115- if k != "return"
116- },
117- },
118- ),
119- )
105+ run_sig = signature (cls ._run )
106+ fields : dict [str , Any ] = {}
107+
108+ for param_name , param in run_sig .parameters .items ():
109+ if param_name in ("self" , "return" ):
110+ continue
111+ if param .kind in (Parameter .VAR_POSITIONAL , Parameter .VAR_KEYWORD ):
112+ continue
113+
114+ annotation = param .annotation if param .annotation != param .empty else Any
115+
116+ if param .default is param .empty :
117+ fields [param_name ] = (annotation , ...)
118+ else :
119+ fields [param_name ] = (annotation , param .default )
120+
121+ if not fields :
122+ arun_sig = signature (cls ._arun )
123+ for param_name , param in arun_sig .parameters .items ():
124+ if param_name in ("self" , "return" ):
125+ continue
126+ if param .kind in (Parameter .VAR_POSITIONAL , Parameter .VAR_KEYWORD ):
127+ continue
128+
129+ annotation = (
130+ param .annotation if param .annotation != param .empty else Any
131+ )
132+
133+ if param .default is param .empty :
134+ fields [param_name ] = (annotation , ...)
135+ else :
136+ fields [param_name ] = (annotation , param .default )
137+
138+ return create_model (f"{ cls .__name__ } Schema" , ** fields )
120139
121140 @field_validator ("max_usage_count" , mode = "before" )
122141 @classmethod
@@ -226,24 +245,23 @@ def from_langchain(cls, tool: Any) -> BaseTool:
226245 args_schema = getattr (tool , "args_schema" , None )
227246
228247 if args_schema is None :
229- # Infer args_schema from the function signature if not provided
230248 func_signature = signature (tool .func )
231- annotations = func_signature .parameters
232- args_fields : dict [str , Any ] = {}
233- for name , param in annotations .items ():
234- if name != "self" :
235- param_annotation = (
236- param .annotation if param .annotation != param .empty else Any
237- )
238- field_info = Field (
239- default = ...,
240- description = "" ,
241- )
242- args_fields [name ] = (param_annotation , field_info )
243- if args_fields :
244- args_schema = create_model (f"{ tool .name } Input" , ** args_fields )
249+ fields : dict [str , Any ] = {}
250+ for name , param in func_signature .parameters .items ():
251+ if name == "self" :
252+ continue
253+ if param .kind in (Parameter .VAR_POSITIONAL , Parameter .VAR_KEYWORD ):
254+ continue
255+ param_annotation = (
256+ param .annotation if param .annotation != param .empty else Any
257+ )
258+ if param .default is param .empty :
259+ fields [name ] = (param_annotation , ...)
260+ else :
261+ fields [name ] = (param_annotation , param .default )
262+ if fields :
263+ args_schema = create_model (f"{ tool .name } Input" , ** fields )
245264 else :
246- # Create a default schema with no fields if no parameters are found
247265 args_schema = create_model (
248266 f"{ tool .name } Input" , __base__ = PydanticBaseModel
249267 )
@@ -257,53 +275,37 @@ def from_langchain(cls, tool: Any) -> BaseTool:
257275
258276 def _set_args_schema (self ) -> None :
259277 if self .args_schema is None :
260- class_name = f"{ self .__class__ .__name__ } Schema"
261- self .args_schema = cast (
262- type [PydanticBaseModel ],
263- type (
264- class_name ,
265- (PydanticBaseModel ,),
266- {
267- "__annotations__" : {
268- k : v
269- for k , v in self ._run .__annotations__ .items ()
270- if k != "return"
271- },
272- },
273- ),
274- )
278+ run_sig = signature (self ._run )
279+ fields : dict [str , Any ] = {}
275280
276- def _generate_description (self ) -> None :
277- args_schema = {
278- name : {
279- "description" : field .description ,
280- "type" : BaseTool ._get_arg_annotations (field .annotation ),
281- }
282- for name , field in self .args_schema .model_fields .items ()
283- }
284-
285- self .description = f"Tool Name: { self .name } \n Tool Arguments: { args_schema } \n Tool Description: { self .description } "
286-
287- @staticmethod
288- def _get_arg_annotations (annotation : type [Any ] | None ) -> str :
289- if annotation is None :
290- return "None"
291-
292- origin = get_origin (annotation )
293- args = get_args (annotation )
294-
295- if origin is None :
296- return (
297- annotation .__name__
298- if hasattr (annotation , "__name__" )
299- else str (annotation )
300- )
281+ for param_name , param in run_sig .parameters .items ():
282+ if param_name in ("self" , "return" ):
283+ continue
284+ if param .kind in (Parameter .VAR_POSITIONAL , Parameter .VAR_KEYWORD ):
285+ continue
286+
287+ annotation = (
288+ param .annotation if param .annotation != param .empty else Any
289+ )
301290
302- if args :
303- args_str = ", " .join (BaseTool ._get_arg_annotations (arg ) for arg in args )
304- return str (f"{ origin .__name__ } [{ args_str } ]" )
291+ if param .default is param .empty :
292+ fields [param_name ] = (annotation , ...)
293+ else :
294+ fields [param_name ] = (annotation , param .default )
305295
306- return str (origin .__name__ )
296+ self .args_schema = create_model (
297+ f"{ self .__class__ .__name__ } Schema" , ** fields
298+ )
299+
300+ def _generate_description (self ) -> None :
301+ """Generate the tool description with a JSON schema for arguments."""
302+ schema = generate_model_description (self .args_schema )
303+ args_json = json .dumps (schema ["json_schema" ]["schema" ], indent = 2 )
304+ self .description = (
305+ f"Tool Name: { self .name } \n "
306+ f"Tool Arguments: { args_json } \n "
307+ f"Tool Description: { self .description } "
308+ )
307309
308310
309311class Tool (BaseTool , Generic [P , R ]):
@@ -406,24 +408,23 @@ def from_langchain(cls, tool: Any) -> Tool[..., Any]:
406408 args_schema = getattr (tool , "args_schema" , None )
407409
408410 if args_schema is None :
409- # Infer args_schema from the function signature if not provided
410411 func_signature = signature (tool .func )
411- annotations = func_signature .parameters
412- args_fields : dict [str , Any ] = {}
413- for name , param in annotations .items ():
414- if name != "self" :
415- param_annotation = (
416- param .annotation if param .annotation != param .empty else Any
417- )
418- field_info = Field (
419- default = ...,
420- description = "" ,
421- )
422- args_fields [name ] = (param_annotation , field_info )
423- if args_fields :
424- args_schema = create_model (f"{ tool .name } Input" , ** args_fields )
412+ fields : dict [str , Any ] = {}
413+ for name , param in func_signature .parameters .items ():
414+ if name == "self" :
415+ continue
416+ if param .kind in (Parameter .VAR_POSITIONAL , Parameter .VAR_KEYWORD ):
417+ continue
418+ param_annotation = (
419+ param .annotation if param .annotation != param .empty else Any
420+ )
421+ if param .default is param .empty :
422+ fields [name ] = (param_annotation , ...)
423+ else :
424+ fields [name ] = (param_annotation , param .default )
425+ if fields :
426+ args_schema = create_model (f"{ tool .name } Input" , ** fields )
425427 else :
426- # Create a default schema with no fields if no parameters are found
427428 args_schema = create_model (
428429 f"{ tool .name } Input" , __base__ = PydanticBaseModel
429430 )
@@ -502,32 +503,38 @@ def _make_with_name(tool_name: str) -> Callable[[Callable[P2, R2]], Tool[P2, R2]
502503 def _make_tool (f : Callable [P2 , R2 ]) -> Tool [P2 , R2 ]:
503504 if f .__doc__ is None :
504505 raise ValueError ("Function must have a docstring" )
505-
506- func_annotations = getattr (f , "__annotations__" , None )
507- if func_annotations is None :
506+ if f .__annotations__ is None :
508507 raise ValueError ("Function must have type annotations" )
509508
509+ func_sig = signature (f )
510+ fields : dict [str , Any ] = {}
511+
512+ for param_name , param in func_sig .parameters .items ():
513+ if param_name == "return" :
514+ continue
515+ if param .kind in (Parameter .VAR_POSITIONAL , Parameter .VAR_KEYWORD ):
516+ continue
517+
518+ annotation = (
519+ param .annotation if param .annotation != param .empty else Any
520+ )
521+
522+ if param .default is param .empty :
523+ fields [param_name ] = (annotation , ...)
524+ else :
525+ fields [param_name ] = (annotation , param .default )
526+
510527 class_name = "" .join (tool_name .split ()).title ()
511- tool_args_schema = cast (
512- type [PydanticBaseModel ],
513- type (
514- class_name ,
515- (PydanticBaseModel ,),
516- {
517- "__annotations__" : {
518- k : v for k , v in func_annotations .items () if k != "return"
519- },
520- },
521- ),
522- )
528+ args_schema = create_model (class_name , ** fields )
523529
524530 return Tool (
525531 name = tool_name ,
526532 description = f .__doc__ ,
527533 func = f ,
528- args_schema = tool_args_schema ,
534+ args_schema = args_schema ,
529535 result_as_answer = result_as_answer ,
530536 max_usage_count = max_usage_count ,
537+ current_usage_count = 0 ,
531538 )
532539
533540 return _make_tool
0 commit comments