Skip to content

Commit df8e4f7

Browse files
lorenzejaymplachta
authored andcommitted
refactor: enhance schema handling in EnterpriseActionTool (crewAIInc#355)
* refactor: enhance schema handling in EnterpriseActionTool - Extracted schema property and required field extraction into separate methods for better readability and maintainability. - Introduced methods to analyze field types and create Pydantic field definitions based on nullability and requirement status. - Updated the _run method to handle required nullable fields, ensuring they are set to None if not provided in kwargs. * refactor: streamline nullable field handling in EnterpriseActionTool - Removed commented-out code related to handling required nullable fields for clarity. - Simplified the logic in the _run method to focus on processing parameters without unnecessary comments.
1 parent 08ff4f2 commit df8e4f7

File tree

1 file changed

+93
-21
lines changed

1 file changed

+93
-21
lines changed

crewai_tools/adapters/enterprise_adapter.py

Lines changed: 93 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,34 +37,18 @@ def __init__(
3737
enterprise_action_kit_project_url: str = ENTERPRISE_ACTION_KIT_PROJECT_URL,
3838
enterprise_action_kit_project_id: str = ENTERPRISE_ACTION_KIT_PROJECT_ID,
3939
):
40-
schema_props = (
41-
action_schema.get("function", {})
42-
.get("parameters", {})
43-
.get("properties", {})
44-
)
45-
required = (
46-
action_schema.get("function", {}).get("parameters", {}).get("required", [])
47-
)
40+
schema_props, required = self._extract_schema_info(action_schema)
4841

4942
# Define field definitions for the model
5043
field_definitions = {}
5144
for param_name, param_details in schema_props.items():
52-
param_type = str # Default to string type
5345
param_desc = param_details.get("description", "")
5446
is_required = param_name in required
47+
is_nullable, param_type = self._analyze_field_type(param_details)
5548

56-
# Basic type mapping (can be extended)
57-
if param_details.get("type") == "integer":
58-
param_type = int
59-
elif param_details.get("type") == "number":
60-
param_type = float
61-
elif param_details.get("type") == "boolean":
62-
param_type = bool
63-
64-
# Create field with appropriate type and config
65-
field_definitions[param_name] = (
66-
param_type if is_required else Optional[param_type],
67-
Field(description=param_desc),
49+
# Create field definition based on nullable and required status
50+
field_definitions[param_name] = self._create_field_definition(
51+
param_type, is_required, is_nullable, param_desc
6852
)
6953

7054
# Create the model
@@ -89,9 +73,97 @@ def __init__(
8973
if enterprise_action_kit_project_url is not None:
9074
self.enterprise_action_kit_project_url = enterprise_action_kit_project_url
9175

76+
def _extract_schema_info(
77+
self, action_schema: Dict[str, Any]
78+
) -> tuple[Dict[str, Any], List[str]]:
79+
"""Extract schema properties and required fields from action schema."""
80+
schema_props = (
81+
action_schema.get("function", {})
82+
.get("parameters", {})
83+
.get("properties", {})
84+
)
85+
required = (
86+
action_schema.get("function", {}).get("parameters", {}).get("required", [])
87+
)
88+
return schema_props, required
89+
90+
def _analyze_field_type(self, param_details: Dict[str, Any]) -> tuple[bool, type]:
91+
"""Analyze field type and nullability from parameter details."""
92+
is_nullable = False
93+
param_type = str # Default type
94+
95+
if "anyOf" in param_details:
96+
any_of_types = param_details["anyOf"]
97+
is_nullable = any(t.get("type") == "null" for t in any_of_types)
98+
non_null_types = [t for t in any_of_types if t.get("type") != "null"]
99+
if non_null_types:
100+
first_type = non_null_types[0].get("type", "string")
101+
param_type = self._map_json_type_to_python(
102+
first_type, non_null_types[0]
103+
)
104+
else:
105+
json_type = param_details.get("type", "string")
106+
param_type = self._map_json_type_to_python(json_type, param_details)
107+
is_nullable = json_type == "null"
108+
109+
return is_nullable, param_type
110+
111+
def _create_field_definition(
112+
self, param_type: type, is_required: bool, is_nullable: bool, param_desc: str
113+
) -> tuple:
114+
"""Create Pydantic field definition based on type, requirement, and nullability."""
115+
if is_nullable:
116+
return (
117+
Optional[param_type],
118+
Field(default=None, description=param_desc),
119+
)
120+
elif is_required:
121+
return (
122+
param_type,
123+
Field(description=param_desc),
124+
)
125+
else:
126+
return (
127+
Optional[param_type],
128+
Field(default=None, description=param_desc),
129+
)
130+
131+
def _map_json_type_to_python(
132+
self, json_type: str, param_details: Dict[str, Any]
133+
) -> type:
134+
"""Map JSON schema types to Python types."""
135+
type_mapping = {
136+
"string": str,
137+
"integer": int,
138+
"number": float,
139+
"boolean": bool,
140+
"array": list,
141+
"object": dict,
142+
}
143+
return type_mapping.get(json_type, str)
144+
145+
def _get_required_nullable_fields(self) -> List[str]:
146+
"""Get a list of required nullable fields from the action schema."""
147+
schema_props, required = self._extract_schema_info(self.action_schema)
148+
149+
required_nullable_fields = []
150+
for param_name in required:
151+
param_details = schema_props.get(param_name, {})
152+
is_nullable, _ = self._analyze_field_type(param_details)
153+
if is_nullable:
154+
required_nullable_fields.append(param_name)
155+
156+
return required_nullable_fields
157+
92158
def _run(self, **kwargs) -> str:
93159
"""Execute the specific enterprise action with validated parameters."""
94160
try:
161+
required_nullable_fields = self._get_required_nullable_fields()
162+
163+
for field_name in required_nullable_fields:
164+
if field_name not in kwargs:
165+
kwargs[field_name] = None
166+
95167
params = {k: v for k, v in kwargs.items() if v is not None}
96168

97169
api_url = f"{self.enterprise_action_kit_project_url}/{self.enterprise_action_kit_project_id}/actions"

0 commit comments

Comments
 (0)