Skip to content
This repository was archived by the owner on Nov 10, 2025. It is now read-only.

Commit 8094f3b

Browse files
authored
fix: handle properly anyOf oneOf allOf schema's props (#472)
1 parent 9bae563 commit 8094f3b

File tree

2 files changed

+409
-145
lines changed

2 files changed

+409
-145
lines changed

crewai_tools/tools/crewai_platform_tools/crewai_platform_action_tool.py

Lines changed: 201 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,83 @@
44
import re
55
import json
66
import requests
7-
from typing import Dict, Any, List, Type, Optional, Union, get_origin, cast, Literal
7+
from typing import Dict, Any, List, Type, Optional, Union, get_origin, cast
88
from pydantic import Field, create_model
99
from crewai.tools import BaseTool
1010
from crewai_tools.tools.crewai_platform_tools.misc import get_platform_api_base_url, get_platform_integration_token
1111

1212

13+
class AllOfSchemaAnalyzer:
14+
"""Helper class to analyze and merge allOf schemas."""
15+
16+
def __init__(self, schemas: List[Dict[str, Any]]):
17+
self.schemas = schemas
18+
self._explicit_types = []
19+
self._merged_properties = {}
20+
self._merged_required = []
21+
self._analyze_schemas()
22+
23+
def _analyze_schemas(self) -> None:
24+
"""Analyze all schemas and extract relevant information."""
25+
for schema in self.schemas:
26+
if "type" in schema:
27+
self._explicit_types.append(schema["type"])
28+
29+
# Merge object properties
30+
if schema.get("type") == "object" and "properties" in schema:
31+
self._merged_properties.update(schema["properties"])
32+
if "required" in schema:
33+
self._merged_required.extend(schema["required"])
34+
35+
def has_consistent_type(self) -> bool:
36+
"""Check if all schemas have the same explicit type."""
37+
return len(set(self._explicit_types)) == 1 if self._explicit_types else False
38+
39+
def get_consistent_type(self) -> Type[Any]:
40+
"""Get the consistent type if all schemas agree."""
41+
if not self.has_consistent_type():
42+
raise ValueError("No consistent type found")
43+
44+
type_mapping = {
45+
"string": str,
46+
"integer": int,
47+
"number": float,
48+
"boolean": bool,
49+
"array": list,
50+
"object": dict,
51+
"null": type(None),
52+
}
53+
return type_mapping.get(self._explicit_types[0], str)
54+
55+
def has_object_schemas(self) -> bool:
56+
"""Check if any schemas are object types with properties."""
57+
return bool(self._merged_properties)
58+
59+
def get_merged_properties(self) -> Dict[str, Any]:
60+
"""Get merged properties from all object schemas."""
61+
return self._merged_properties
62+
63+
def get_merged_required_fields(self) -> List[str]:
64+
"""Get merged required fields from all object schemas."""
65+
return list(set(self._merged_required)) # Remove duplicates
66+
67+
def get_fallback_type(self) -> Type[Any]:
68+
"""Get a fallback type when merging fails."""
69+
if self._explicit_types:
70+
# Use the first explicit type
71+
type_mapping = {
72+
"string": str,
73+
"integer": int,
74+
"number": float,
75+
"boolean": bool,
76+
"array": list,
77+
"object": dict,
78+
"null": type(None),
79+
}
80+
return type_mapping.get(self._explicit_types[0], str)
81+
return str
82+
83+
1384
class CrewAIPlatformActionTool(BaseTool):
1485
action_name: str = Field(default="", description="The name of the action")
1586
action_schema: Dict[str, Any] = Field(
@@ -84,40 +155,150 @@ def _extract_schema_info(
84155
return schema_props, required
85156

86157
def _process_schema_type(self, schema: Dict[str, Any], type_name: str) -> Type[Any]:
87-
if "anyOf" in schema:
88-
any_of_types = schema["anyOf"]
89-
is_nullable = any(t.get("type") == "null" for t in any_of_types)
90-
non_null_types = [t for t in any_of_types if t.get("type") != "null"]
158+
"""
159+
Process a JSON Schema type definition into a Python type.
160+
161+
Handles complex schema constructs like anyOf, oneOf, allOf, enums, arrays, and objects.
162+
"""
163+
# Handle composite schema types (anyOf, oneOf, allOf)
164+
if composite_type := self._process_composite_schema(schema, type_name):
165+
return composite_type
91166

92-
if non_null_types:
93-
base_type = self._process_schema_type(non_null_types[0], type_name)
94-
return Optional[base_type] if is_nullable else base_type
95-
return cast(Type[Any], Optional[str])
167+
# Handle primitive types and simple constructs
168+
return self._process_primitive_schema(schema, type_name)
96169

97-
if "oneOf" in schema:
98-
return self._process_schema_type(schema["oneOf"][0], type_name)
170+
def _process_composite_schema(self, schema: Dict[str, Any], type_name: str) -> Optional[Type[Any]]:
171+
"""Process composite schema types: anyOf, oneOf, allOf."""
172+
if "anyOf" in schema:
173+
return self._process_any_of_schema(schema["anyOf"], type_name)
174+
elif "oneOf" in schema:
175+
return self._process_one_of_schema(schema["oneOf"], type_name)
176+
elif "allOf" in schema:
177+
return self._process_all_of_schema(schema["allOf"], type_name)
178+
return None
179+
180+
def _process_any_of_schema(self, any_of_types: List[Dict[str, Any]], type_name: str) -> Type[Any]:
181+
"""Process anyOf schema - creates Union of possible types."""
182+
is_nullable = any(t.get("type") == "null" for t in any_of_types)
183+
non_null_types = [t for t in any_of_types if t.get("type") != "null"]
184+
185+
if not non_null_types:
186+
return cast(Type[Any], Optional[str]) # fallback for only-null case
187+
188+
base_type = (
189+
self._process_schema_type(non_null_types[0], type_name)
190+
if len(non_null_types) == 1
191+
else self._create_union_type(non_null_types, type_name, "AnyOf")
192+
)
193+
return Optional[base_type] if is_nullable else base_type
194+
195+
def _process_one_of_schema(self, one_of_types: List[Dict[str, Any]], type_name: str) -> Type[Any]:
196+
"""Process oneOf schema - creates Union of mutually exclusive types."""
197+
return (
198+
self._process_schema_type(one_of_types[0], type_name)
199+
if len(one_of_types) == 1
200+
else self._create_union_type(one_of_types, type_name, "OneOf")
201+
)
99202

100-
if "allOf" in schema:
101-
return self._process_schema_type(schema["allOf"][0], type_name)
203+
def _process_all_of_schema(self, all_of_schemas: List[Dict[str, Any]], type_name: str) -> Type[Any]:
204+
"""Process allOf schema - merges schemas that must all be satisfied."""
205+
if len(all_of_schemas) == 1:
206+
return self._process_schema_type(all_of_schemas[0], type_name)
207+
return self._merge_all_of_schemas(all_of_schemas, type_name)
208+
209+
def _create_union_type(self, schemas: List[Dict[str, Any]], type_name: str, prefix: str) -> Type[Any]:
210+
"""Create a Union type from multiple schemas."""
211+
return Union[
212+
tuple(
213+
self._process_schema_type(schema, f"{type_name}{prefix}{i}")
214+
for i, schema in enumerate(schemas)
215+
)
216+
]
102217

218+
def _process_primitive_schema(self, schema: Dict[str, Any], type_name: str) -> Type[Any]:
219+
"""Process primitive schema types: string, number, array, object, etc."""
103220
json_type = schema.get("type", "string")
104221

105222
if "enum" in schema:
106-
enum_values = schema["enum"]
107-
if not enum_values:
108-
return self._map_json_type_to_python(json_type)
109-
return Literal[tuple(enum_values)]
223+
return self._process_enum_schema(schema, json_type)
110224

111225
if json_type == "array":
112-
items_schema = schema.get("items", {"type": "string"})
113-
item_type = self._process_schema_type(items_schema, f"{type_name}Item")
114-
return List[item_type]
226+
return self._process_array_schema(schema, type_name)
115227

116228
if json_type == "object":
117229
return self._create_nested_model(schema, type_name)
118230

119231
return self._map_json_type_to_python(json_type)
120232

233+
def _process_enum_schema(self, schema: Dict[str, Any], json_type: str) -> Type[Any]:
234+
"""Process enum schema - currently falls back to base type."""
235+
enum_values = schema["enum"]
236+
if not enum_values:
237+
return self._map_json_type_to_python(json_type)
238+
239+
# For Literal types, we need to pass the values directly, not as a tuple
240+
# This is a workaround since we can't dynamically create Literal types easily
241+
# Fall back to the base JSON type for now
242+
return self._map_json_type_to_python(json_type)
243+
244+
def _process_array_schema(self, schema: Dict[str, Any], type_name: str) -> Type[Any]:
245+
items_schema = schema.get("items", {"type": "string"})
246+
item_type = self._process_schema_type(items_schema, f"{type_name}Item")
247+
return List[item_type]
248+
249+
def _merge_all_of_schemas(self, schemas: List[Dict[str, Any]], type_name: str) -> Type[Any]:
250+
schema_analyzer = AllOfSchemaAnalyzer(schemas)
251+
252+
if schema_analyzer.has_consistent_type():
253+
return schema_analyzer.get_consistent_type()
254+
255+
if schema_analyzer.has_object_schemas():
256+
return self._create_merged_object_model(
257+
schema_analyzer.get_merged_properties(),
258+
schema_analyzer.get_merged_required_fields(),
259+
type_name
260+
)
261+
262+
return schema_analyzer.get_fallback_type()
263+
264+
def _create_merged_object_model(self, properties: Dict[str, Any], required: List[str], model_name: str) -> Type[Any]:
265+
full_model_name = f"{self._base_name}{model_name}AllOf"
266+
267+
if full_model_name in self._model_registry:
268+
return self._model_registry[full_model_name]
269+
270+
if not properties:
271+
return dict
272+
273+
field_definitions = self._build_field_definitions(properties, required, model_name)
274+
275+
try:
276+
merged_model = create_model(full_model_name, **field_definitions)
277+
self._model_registry[full_model_name] = merged_model
278+
return merged_model
279+
except Exception as e:
280+
return dict
281+
282+
def _build_field_definitions(self, properties: Dict[str, Any], required: List[str], model_name: str) -> Dict[str, Any]:
283+
field_definitions = {}
284+
285+
for prop_name, prop_schema in properties.items():
286+
prop_desc = prop_schema.get("description", "")
287+
is_required = prop_name in required
288+
289+
try:
290+
prop_type = self._process_schema_type(
291+
prop_schema, f"{model_name}{self._sanitize_name(prop_name).title()}"
292+
)
293+
except Exception:
294+
prop_type = str
295+
296+
field_definitions[prop_name] = self._create_field_definition(
297+
prop_type, is_required, prop_desc
298+
)
299+
300+
return field_definitions
301+
121302
def _create_nested_model(self, schema: Dict[str, Any], model_name: str) -> Type[Any]:
122303
full_model_name = f"{self._base_name}{model_name}"
123304

0 commit comments

Comments
 (0)