Skip to content
This repository was archived by the owner on Nov 10, 2025. It is now read-only.

Commit 9ad5991

Browse files
authored
refactor: renaming init_params and run_params to reflect their schema. (#332) (#333)
We’re currently using the JSON Schema standard for these fields, so their keys now carry a `_schema` suffix (`run_params_schema`, `init_params_schema`).
1 parent 6cb02cf commit 9ad5991

File tree

3 files changed

+5872
-845
lines changed

3 files changed

+5872
-845
lines changed

generate_tool_specs.py

Lines changed: 29 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,17 @@
55
from pathlib import Path
66
from typing import Any, Dict, List, Optional, Type
77

8+
from pydantic import BaseModel
9+
810
from crewai_tools import tools
9-
from crewai.tools.base_tool import EnvVar
11+
from crewai.tools.base_tool import BaseTool, EnvVar
12+
13+
from pydantic.json_schema import GenerateJsonSchema
14+
from pydantic_core import PydanticOmit
1015

16+
class SchemaGenerator(GenerateJsonSchema):
17+
def handle_invalid_for_json_schema(self, schema, error_info):
18+
raise PydanticOmit
1119

1220
class ToolSpecExtractor:
1321
def __init__(self) -> None:
@@ -22,20 +30,21 @@ def extract_all_tools(self) -> List[Dict[str, Any]]:
2230
self.extract_tool_info(obj)
2331
self.processed_tools.add(name)
2432
return self.tools_spec
25-
26-
def extract_tool_info(self, tool_class: Type) -> None:
33+
def extract_tool_info(self, tool_class: BaseTool) -> None:
2734
try:
2835
core_schema = tool_class.__pydantic_core_schema__
2936
if not core_schema:
3037
return
3138

3239
schema = self._unwrap_schema(core_schema)
3340
fields = schema.get("schema", {}).get("fields", {})
41+
3442
tool_info = {
3543
"name": tool_class.__name__,
3644
"humanized_name": self._extract_field_default(fields.get("name"), fallback=tool_class.__name__),
3745
"description": self._extract_field_default(fields.get("description")).strip(),
38-
"run_params": self._extract_params(fields.get("args_schema")),
46+
"run_params_schema": self._extract_params(fields.get("args_schema")),
47+
"init_params_schema": self._extract_init_params(tool_class),
3948
"env_vars": self._extract_env_vars(fields.get("env_vars")),
4049
"package_dependencies": self._extract_field_default(fields.get("package_dependencies"), fallback=[]),
4150
}
@@ -60,35 +69,17 @@ def _extract_field_default(self, field: Optional[Dict], fallback: str = "") -> s
6069

6170
def _extract_params(self, args_schema_field: Optional[Dict]) -> List[Dict[str, str]]:
6271
if not args_schema_field:
63-
return []
72+
return {}
6473

6574
args_schema_class = args_schema_field.get("schema", {}).get("default")
6675
if not (inspect.isclass(args_schema_class) and hasattr(args_schema_class, "__pydantic_core_schema__")):
67-
return []
76+
return {}
6877

6978
try:
70-
core_schema = args_schema_class.__pydantic_core_schema__
71-
schema = self._unwrap_schema(core_schema)
72-
fields = schema.get("schema", {}).get("fields", {})
73-
74-
params = []
75-
for name, info in fields.items():
76-
_type = self._extract_param_type(info)
77-
if _type == "union":
78-
breakpoint()
79-
param = {
80-
"name": name,
81-
"description": self._extract_field_description_from_metadata(info),
82-
"type": _type,
83-
"default": self._extract_field_default(info),
84-
}
85-
params.append(param)
86-
87-
return params
88-
79+
return args_schema_class.model_json_schema(schema_generator=SchemaGenerator, mode='validation')
8980
except Exception as e:
9081
print(f"Error extracting params from {args_schema_class}: {e}")
91-
return []
82+
return {}
9283

9384
def _extract_env_vars(self, env_vars_field: Optional[Dict]) -> List[Dict[str, str]]:
9485
if not env_vars_field:
@@ -105,47 +96,18 @@ def _extract_env_vars(self, env_vars_field: Optional[Dict]) -> List[Dict[str, st
10596
})
10697
return env_vars
10798

108-
def _extract_field_description_from_metadata(self, field: Dict) -> str:
109-
if metadata := field.get("metadata"):
110-
return metadata.get("pydantic_js_updates", {}).get("description", "")
111-
return ""
112-
113-
def _extract_param_type(self, info: Dict) -> Optional[str]:
114-
schema = info.get("schema", {})
115-
schema = self._unwrap_schema(schema)
116-
117-
if schema.get("type") == "nullable":
118-
inner = schema.get("schema", {})
119-
return self._schema_type_to_str(inner)
120-
121-
return self._schema_type_to_str(schema)
122-
123-
def _schema_type_to_str(self, schema: Dict) -> str:
124-
schema_type = schema.get("type", "")
125-
126-
if schema_type == "list" and "items_schema" in schema:
127-
item_type = self._schema_type_to_str(schema["items_schema"])
128-
return f"list[{item_type}]"
129-
130-
if schema_type == "union" and "choices" in schema:
131-
choices = schema["choices"]
132-
item_types = [self._schema_type_to_str(choice) for choice in choices]
133-
return f"union[{', '.join(item_types)}]"
134-
135-
if schema_type == "dict" and "keys_schema" in schema and "values_schema" in schema:
136-
key_type = self._schema_type_to_str(schema["keys_schema"])
137-
value_type = self._schema_type_to_str(schema["values_schema"])
138-
return f"dict[{key_type}, {value_type}]"
139-
140-
return {
141-
"str": "str",
142-
"int": "int",
143-
"float": "float",
144-
"bool": "bool",
145-
"list": "list",
146-
"dict": "dict",
147-
"any": "any",
148-
}.get(schema_type, schema_type or "unknown")
99+
def _extract_init_params(self, tool_class: BaseTool) -> dict:
100+
ignored_init_params = ['name', 'description', 'env_vars', 'args_schema', 'description_updated', 'cache_function', 'result_as_answer', 'max_usage_count', 'current_usage_count', 'package_dependencies']
101+
102+
json_schema = tool_class.model_json_schema(schema_generator=SchemaGenerator, mode='serialization')
103+
104+
properties = {}
105+
for key, value in json_schema['properties'].items():
106+
if key not in ignored_init_params:
107+
properties[key] = value
108+
109+
json_schema['properties'] = properties
110+
return json_schema
149111

150112
def save_to_json(self, output_path: str) -> None:
151113
with open(output_path, "w", encoding="utf-8") as f:

tests/test_generate_tool_specs.py

Lines changed: 82 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -44,91 +44,106 @@ def test_unwrap_schema(extractor):
4444
assert result["value"] == "test"
4545

4646

47-
@pytest.mark.parametrize(
48-
"schema, expected",
49-
[
50-
({"type": "str"}, "str"),
51-
({"type": "list", "items_schema": {"type": "str"}}, "list[str]"),
52-
({"type": "dict", "keys_schema": {"type": "str"}, "values_schema": {"type": "int"}}, "dict[str, int]"),
53-
({"type": "union", "choices": [{"type": "str"}, {"type": "int"}]}, "union[str, int]"),
54-
({"type": "custom_type"}, "custom_type"),
55-
({}, "unknown"),
56-
]
57-
)
58-
def test_schema_type_to_str(extractor, schema, expected):
59-
assert extractor._schema_type_to_str(schema) == expected
60-
61-
62-
@pytest.mark.parametrize(
63-
"info, expected_type",
64-
[
65-
({"schema": {"type": "str"}}, "str"),
66-
({"schema": {"type": "nullable", "schema": {"type": "int"}}}, "int"),
67-
({"schema": {"type": "default", "schema": {"type": "list", "items_schema": {"type": "str"}}}}, "list[str]"),
68-
]
69-
)
70-
def test_extract_param_type(extractor, info, expected_type):
71-
assert extractor._extract_param_type(info) == expected_type
72-
73-
74-
def test_extract_all_tools(extractor):
47+
@pytest.fixture
48+
def mock_tool_extractor(extractor):
7549
with mock.patch("generate_tool_specs.dir", return_value=["MockTool"]), \
7650
mock.patch("generate_tool_specs.getattr", return_value=MockTool):
7751
extractor.extract_all_tools()
78-
7952
assert len(extractor.tools_spec) == 1
80-
tool_info = extractor.tools_spec[0]
81-
82-
assert tool_info.keys() == {
83-
"name",
84-
"humanized_name",
85-
"description",
86-
"run_params",
87-
"env_vars",
88-
"init_params",
89-
"package_dependencies",
90-
}
91-
92-
assert tool_info["name"] == "MockTool"
93-
assert tool_info["humanized_name"] == "Mock Search Tool"
94-
assert tool_info["description"] == "A tool that mocks search functionality"
53+
return extractor.tools_spec[0]
54+
55+
def test_extract_basic_tool_info(mock_tool_extractor):
56+
tool_info = mock_tool_extractor
57+
58+
assert tool_info.keys() == {
59+
"name",
60+
"humanized_name",
61+
"description",
62+
"run_params_schema",
63+
"env_vars",
64+
"init_params_schema",
65+
"package_dependencies",
66+
}
9567

96-
assert len(tool_info["env_vars"]) == 2
97-
api_key_var, rate_limit_var = tool_info["env_vars"]
68+
assert tool_info["name"] == "MockTool"
69+
assert tool_info["humanized_name"] == "Mock Search Tool"
70+
assert tool_info["description"] == "A tool that mocks search functionality"
9871

99-
assert api_key_var["name"] == "SERPER_API_KEY"
100-
assert api_key_var["description"] == "API key for Serper"
101-
assert api_key_var["required"] == True
102-
assert api_key_var["default"] == None
72+
def test_extract_init_params_schema(mock_tool_extractor):
73+
tool_info = mock_tool_extractor
74+
init_params_schema = tool_info["init_params_schema"]
10375

104-
assert rate_limit_var["name"] == "API_RATE_LIMIT"
105-
assert rate_limit_var["description"] == "API rate limit"
106-
assert rate_limit_var["required"] == False
107-
assert rate_limit_var["default"] == "100"
76+
assert init_params_schema.keys() == {
77+
"$defs",
78+
"properties",
79+
"title",
80+
"type",
81+
}
10882

109-
assert len(tool_info["run_params"]) == 3
83+
another_parameter = init_params_schema['properties']['another_parameter']
84+
assert another_parameter["description"] == ""
85+
assert another_parameter["default"] == "Another way to define a default value"
86+
assert another_parameter["type"] == "string"
87+
88+
my_parameter = init_params_schema['properties']['my_parameter']
89+
assert my_parameter["description"] == "What a description"
90+
assert my_parameter["default"] == "This is default value"
91+
assert my_parameter["type"] == "string"
92+
93+
my_parameter_bool = init_params_schema['properties']['my_parameter_bool']
94+
assert my_parameter_bool["default"] == False
95+
assert my_parameter_bool["type"] == "boolean"
96+
97+
def test_extract_env_vars(mock_tool_extractor):
98+
tool_info = mock_tool_extractor
99+
100+
assert len(tool_info["env_vars"]) == 2
101+
api_key_var, rate_limit_var = tool_info["env_vars"]
102+
assert api_key_var["name"] == "SERPER_API_KEY"
103+
assert api_key_var["description"] == "API key for Serper"
104+
assert api_key_var["required"] == True
105+
assert api_key_var["default"] == None
106+
107+
assert rate_limit_var["name"] == "API_RATE_LIMIT"
108+
assert rate_limit_var["description"] == "API rate limit"
109+
assert rate_limit_var["required"] == False
110+
assert rate_limit_var["default"] == "100"
111+
112+
def test_extract_run_params_schema(mock_tool_extractor):
113+
tool_info = mock_tool_extractor
114+
115+
run_params_schema = tool_info["run_params_schema"]
116+
assert run_params_schema.keys() == {
117+
"properties",
118+
"required",
119+
"title",
120+
"type",
121+
}
110122

111-
params = {p["name"]: p for p in tool_info["run_params"]}
112-
assert params["query"]["description"] == "The query parameter"
113-
assert params["query"]["type"] == "str"
114-
assert params["query"]["default"] == ""
123+
query_param = run_params_schema["properties"]["query"]
124+
assert query_param["description"] == "The query parameter"
125+
assert query_param["type"] == "string"
115126

116-
assert params["count"]["type"] == "int"
117-
assert params["count"]["default"] == 5
127+
count_param = run_params_schema["properties"]["count"]
128+
assert count_param["type"] == "integer"
129+
assert count_param["default"] == 5
118130

119-
assert params["filters"]["description"] == "Optional filters to apply"
120-
assert params["filters"]["type"] == "list[str]"
121-
assert params["filters"]["default"] == ""
131+
filters_param = run_params_schema["properties"]["filters"]
132+
assert filters_param["description"] == "Optional filters to apply"
133+
assert filters_param["default"] == None
134+
assert filters_param['anyOf'] == [{'items': {'type': 'string'}, 'type': 'array'}, {'type': 'null'}]
122135

123-
assert tool_info["package_dependencies"] == ["this-is-a-required-package", "another-required-package"]
136+
def test_extract_package_dependencies(mock_tool_extractor):
137+
tool_info = mock_tool_extractor
138+
assert tool_info["package_dependencies"] == ["this-is-a-required-package", "another-required-package"]
124139

125140

126141
def test_save_to_json(extractor, tmp_path):
127142
extractor.tools_spec = [{
128143
"name": "TestTool",
129144
"humanized_name": "Test Tool",
130145
"description": "A test tool",
131-
"run_params": [
146+
"run_params_schema": [
132147
{"name": "param1", "description": "Test parameter", "type": "str"}
133148
]
134149
}]
@@ -144,20 +159,4 @@ def test_save_to_json(extractor, tmp_path):
144159
assert "tools" in data
145160
assert len(data["tools"]) == 1
146161
assert data["tools"][0]["humanized_name"] == "Test Tool"
147-
assert data["tools"][0]["run_params"][0]["name"] == "param1"
148-
149-
150-
@pytest.mark.integration
151-
def test_full_extraction_process():
152-
extractor = ToolSpecExtractor()
153-
specs = extractor.extract_all_tools()
154-
155-
assert len(specs) > 0
156-
157-
for tool in specs:
158-
assert "name" in tool
159-
assert "humanized_name" in tool and tool["humanized_name"]
160-
assert "description" in tool
161-
assert isinstance(tool["run_params"], list)
162-
for param in tool["run_params"]:
163-
assert "name" in param and param["name"]
162+
assert data["tools"][0]["run_params_schema"][0]["name"] == "param1"

0 commit comments

Comments
 (0)