Skip to content

Commit 38b0b12

Browse files
feat: use json schema for tool argument serialization
- Replace Python representation with JsonSchema for tool arguments
- Remove deprecated PydanticSchemaParser in favor of direct schema generation
- Add handling for VAR_POSITIONAL and VAR_KEYWORD parameters
- Improve tool argument schema collection
1 parent 9bd8ad5 commit 38b0b12

File tree

15 files changed

+442
-602
lines changed

15 files changed

+442
-602
lines changed

lib/crewai/src/crewai/agent/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
KnowledgeSearchQueryFailedEvent,
1717
)
1818
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
19-
from crewai.utilities.converter import generate_model_description
19+
from crewai.utilities.pydantic_schema_utils import generate_model_description
2020

2121

2222
if TYPE_CHECKING:

lib/crewai/src/crewai/agents/agent_adapters/base_converter_adapter.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -5,10 +5,9 @@
55
from abc import ABC, abstractmethod
66
import json
77
import re
8-
from typing import TYPE_CHECKING, Final, Literal
9-
10-
from crewai.utilities.converter import generate_model_description
8+
from typing import TYPE_CHECKING, Any, Final, Literal
119

10+
from crewai.utilities.pydantic_schema_utils import generate_model_description
1211

1312

1413
if TYPE_CHECKING:
@@ -42,7 +41,7 @@ def __init__(self, agent_adapter: BaseAgentAdapter) -> None:
4241
"""
4342
self.agent_adapter = agent_adapter
4443
self._output_format: Literal["json", "pydantic"] | None = None
45-
self._schema: str | None = None
44+
self._schema: dict[str, Any] | None = None
4645

4746
@abstractmethod
4847
def configure_structured_output(self, task: Task) -> None:
@@ -129,7 +128,7 @@ def _extract_json_from_text(result: str) -> str:
129128
@staticmethod
130129
def _configure_format_from_task(
131130
task: Task,
132-
) -> tuple[Literal["json", "pydantic"] | None, str | None]:
131+
) -> tuple[Literal["json", "pydantic"] | None, dict[str, Any] | None]:
133132
"""Determine output format and schema from task requirements.
134133
135134
This is a helper method that examines the task's output requirements

lib/crewai/src/crewai/agents/agent_adapters/openai_agents/structured_output_converter.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
output conversion for OpenAI agents, supporting JSON and Pydantic model formats.
55
"""
66

7+
import json
78
from typing import Any
89

910
from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
@@ -61,7 +62,7 @@ def enhance_system_prompt(self, base_prompt: str) -> str:
6162
output_schema: str = (
6263
get_i18n()
6364
.slice("formatted_task_instructions")
64-
.format(output_format=self._schema)
65+
.format(output_format=json.dumps(self._schema, indent=2))
6566
)
6667

6768
return f"{base_prompt}\n\n{output_schema}"

lib/crewai/src/crewai/llms/providers/azure/completion.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,10 +9,10 @@
99
from typing_extensions import Self
1010

1111
from crewai.utilities.agent_utils import is_context_length_exceeded
12-
from crewai.utilities.converter import generate_model_description
1312
from crewai.utilities.exceptions.context_window_exceeding_exception import (
1413
LLMContextLengthExceededError,
1514
)
15+
from crewai.utilities.pydantic_schema_utils import generate_model_description
1616
from crewai.utilities.types import LLMMessage
1717

1818

lib/crewai/src/crewai/llms/providers/openai/completion.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,10 +18,10 @@
1818
from crewai.llms.base_llm import BaseLLM
1919
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
2020
from crewai.utilities.agent_utils import is_context_length_exceeded
21-
from crewai.utilities.converter import generate_model_description
2221
from crewai.utilities.exceptions.context_window_exceeding_exception import (
2322
LLMContextLengthExceededError,
2423
)
24+
from crewai.utilities.pydantic_schema_utils import generate_model_description
2525
from crewai.utilities.types import LLMMessage
2626

2727

lib/crewai/src/crewai/tools/base_tool.py

Lines changed: 117 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,13 @@
33
from abc import ABC, abstractmethod
44
import asyncio
55
from collections.abc import Awaitable, Callable
6-
from inspect import signature
6+
from inspect import Parameter, signature
7+
import json
78
from typing import (
89
Any,
910
Generic,
1011
ParamSpec,
1112
TypeVar,
12-
cast,
13-
get_args,
14-
get_origin,
1513
overload,
1614
)
1715

@@ -27,6 +25,7 @@
2725

2826
from crewai.tools.structured_tool import CrewStructuredTool
2927
from crewai.utilities.printer import Printer
28+
from crewai.utilities.pydantic_schema_utils import generate_model_description
3029

3130

3231
_printer = Printer()
@@ -103,20 +102,40 @@ def _default_args_schema(
103102
if v != cls._ArgsSchemaPlaceholder:
104103
return v
105104

106-
return cast(
107-
type[PydanticBaseModel],
108-
type(
109-
f"{cls.__name__}Schema",
110-
(PydanticBaseModel,),
111-
{
112-
"__annotations__": {
113-
k: v
114-
for k, v in cls._run.__annotations__.items()
115-
if k != "return"
116-
},
117-
},
118-
),
119-
)
105+
run_sig = signature(cls._run)
106+
fields: dict[str, Any] = {}
107+
108+
for param_name, param in run_sig.parameters.items():
109+
if param_name in ("self", "return"):
110+
continue
111+
if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
112+
continue
113+
114+
annotation = param.annotation if param.annotation != param.empty else Any
115+
116+
if param.default is param.empty:
117+
fields[param_name] = (annotation, ...)
118+
else:
119+
fields[param_name] = (annotation, param.default)
120+
121+
if not fields:
122+
arun_sig = signature(cls._arun)
123+
for param_name, param in arun_sig.parameters.items():
124+
if param_name in ("self", "return"):
125+
continue
126+
if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
127+
continue
128+
129+
annotation = (
130+
param.annotation if param.annotation != param.empty else Any
131+
)
132+
133+
if param.default is param.empty:
134+
fields[param_name] = (annotation, ...)
135+
else:
136+
fields[param_name] = (annotation, param.default)
137+
138+
return create_model(f"{cls.__name__}Schema", **fields)
120139

121140
@field_validator("max_usage_count", mode="before")
122141
@classmethod
@@ -226,24 +245,23 @@ def from_langchain(cls, tool: Any) -> BaseTool:
226245
args_schema = getattr(tool, "args_schema", None)
227246

228247
if args_schema is None:
229-
# Infer args_schema from the function signature if not provided
230248
func_signature = signature(tool.func)
231-
annotations = func_signature.parameters
232-
args_fields: dict[str, Any] = {}
233-
for name, param in annotations.items():
234-
if name != "self":
235-
param_annotation = (
236-
param.annotation if param.annotation != param.empty else Any
237-
)
238-
field_info = Field(
239-
default=...,
240-
description="",
241-
)
242-
args_fields[name] = (param_annotation, field_info)
243-
if args_fields:
244-
args_schema = create_model(f"{tool.name}Input", **args_fields)
249+
fields: dict[str, Any] = {}
250+
for name, param in func_signature.parameters.items():
251+
if name == "self":
252+
continue
253+
if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
254+
continue
255+
param_annotation = (
256+
param.annotation if param.annotation != param.empty else Any
257+
)
258+
if param.default is param.empty:
259+
fields[name] = (param_annotation, ...)
260+
else:
261+
fields[name] = (param_annotation, param.default)
262+
if fields:
263+
args_schema = create_model(f"{tool.name}Input", **fields)
245264
else:
246-
# Create a default schema with no fields if no parameters are found
247265
args_schema = create_model(
248266
f"{tool.name}Input", __base__=PydanticBaseModel
249267
)
@@ -257,53 +275,37 @@ def from_langchain(cls, tool: Any) -> BaseTool:
257275

258276
def _set_args_schema(self) -> None:
259277
if self.args_schema is None:
260-
class_name = f"{self.__class__.__name__}Schema"
261-
self.args_schema = cast(
262-
type[PydanticBaseModel],
263-
type(
264-
class_name,
265-
(PydanticBaseModel,),
266-
{
267-
"__annotations__": {
268-
k: v
269-
for k, v in self._run.__annotations__.items()
270-
if k != "return"
271-
},
272-
},
273-
),
274-
)
278+
run_sig = signature(self._run)
279+
fields: dict[str, Any] = {}
275280

276-
def _generate_description(self) -> None:
277-
args_schema = {
278-
name: {
279-
"description": field.description,
280-
"type": BaseTool._get_arg_annotations(field.annotation),
281-
}
282-
for name, field in self.args_schema.model_fields.items()
283-
}
284-
285-
self.description = f"Tool Name: {self.name}\nTool Arguments: {args_schema}\nTool Description: {self.description}"
286-
287-
@staticmethod
288-
def _get_arg_annotations(annotation: type[Any] | None) -> str:
289-
if annotation is None:
290-
return "None"
291-
292-
origin = get_origin(annotation)
293-
args = get_args(annotation)
294-
295-
if origin is None:
296-
return (
297-
annotation.__name__
298-
if hasattr(annotation, "__name__")
299-
else str(annotation)
300-
)
281+
for param_name, param in run_sig.parameters.items():
282+
if param_name in ("self", "return"):
283+
continue
284+
if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
285+
continue
286+
287+
annotation = (
288+
param.annotation if param.annotation != param.empty else Any
289+
)
301290

302-
if args:
303-
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
304-
return str(f"{origin.__name__}[{args_str}]")
291+
if param.default is param.empty:
292+
fields[param_name] = (annotation, ...)
293+
else:
294+
fields[param_name] = (annotation, param.default)
305295

306-
return str(origin.__name__)
296+
self.args_schema = create_model(
297+
f"{self.__class__.__name__}Schema", **fields
298+
)
299+
300+
def _generate_description(self) -> None:
301+
"""Generate the tool description with a JSON schema for arguments."""
302+
schema = generate_model_description(self.args_schema)
303+
args_json = json.dumps(schema["json_schema"]["schema"], indent=2)
304+
self.description = (
305+
f"Tool Name: {self.name}\n"
306+
f"Tool Arguments: {args_json}\n"
307+
f"Tool Description: {self.description}"
308+
)
307309

308310

309311
class Tool(BaseTool, Generic[P, R]):
@@ -406,24 +408,23 @@ def from_langchain(cls, tool: Any) -> Tool[..., Any]:
406408
args_schema = getattr(tool, "args_schema", None)
407409

408410
if args_schema is None:
409-
# Infer args_schema from the function signature if not provided
410411
func_signature = signature(tool.func)
411-
annotations = func_signature.parameters
412-
args_fields: dict[str, Any] = {}
413-
for name, param in annotations.items():
414-
if name != "self":
415-
param_annotation = (
416-
param.annotation if param.annotation != param.empty else Any
417-
)
418-
field_info = Field(
419-
default=...,
420-
description="",
421-
)
422-
args_fields[name] = (param_annotation, field_info)
423-
if args_fields:
424-
args_schema = create_model(f"{tool.name}Input", **args_fields)
412+
fields: dict[str, Any] = {}
413+
for name, param in func_signature.parameters.items():
414+
if name == "self":
415+
continue
416+
if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
417+
continue
418+
param_annotation = (
419+
param.annotation if param.annotation != param.empty else Any
420+
)
421+
if param.default is param.empty:
422+
fields[name] = (param_annotation, ...)
423+
else:
424+
fields[name] = (param_annotation, param.default)
425+
if fields:
426+
args_schema = create_model(f"{tool.name}Input", **fields)
425427
else:
426-
# Create a default schema with no fields if no parameters are found
427428
args_schema = create_model(
428429
f"{tool.name}Input", __base__=PydanticBaseModel
429430
)
@@ -502,32 +503,38 @@ def _make_with_name(tool_name: str) -> Callable[[Callable[P2, R2]], Tool[P2, R2]
502503
def _make_tool(f: Callable[P2, R2]) -> Tool[P2, R2]:
503504
if f.__doc__ is None:
504505
raise ValueError("Function must have a docstring")
505-
506-
func_annotations = getattr(f, "__annotations__", None)
507-
if func_annotations is None:
506+
if f.__annotations__ is None:
508507
raise ValueError("Function must have type annotations")
509508

509+
func_sig = signature(f)
510+
fields: dict[str, Any] = {}
511+
512+
for param_name, param in func_sig.parameters.items():
513+
if param_name == "return":
514+
continue
515+
if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
516+
continue
517+
518+
annotation = (
519+
param.annotation if param.annotation != param.empty else Any
520+
)
521+
522+
if param.default is param.empty:
523+
fields[param_name] = (annotation, ...)
524+
else:
525+
fields[param_name] = (annotation, param.default)
526+
510527
class_name = "".join(tool_name.split()).title()
511-
tool_args_schema = cast(
512-
type[PydanticBaseModel],
513-
type(
514-
class_name,
515-
(PydanticBaseModel,),
516-
{
517-
"__annotations__": {
518-
k: v for k, v in func_annotations.items() if k != "return"
519-
},
520-
},
521-
),
522-
)
528+
args_schema = create_model(class_name, **fields)
523529

524530
return Tool(
525531
name=tool_name,
526532
description=f.__doc__,
527533
func=f,
528-
args_schema=tool_args_schema,
534+
args_schema=args_schema,
529535
result_as_answer=result_as_answer,
530536
max_usage_count=max_usage_count,
537+
current_usage_count=0,
531538
)
532539

533540
return _make_tool

0 commit comments

Comments
 (0)