@@ -1678,7 +1678,7 @@ def plan_task(
16781678 def _generate_detailed_schema (self , model : Type [BaseModel ], depth : int = 0 ) -> str :
16791679 """
16801680 Recursively generates a detailed schema representation of a Pydantic model,
1681- including nested models.
1681+ including nested models and complex types .
16821682 """
16831683 fields = model .__annotations__
16841684 field_descriptions = []
@@ -1687,29 +1687,55 @@ def _generate_detailed_schema(self, model: Type[BaseModel], depth: int = 0) -> s
16871687 for field , field_type in fields .items ():
16881688 description = f"{ indent } { field } : "
16891689
1690- if get_origin (field_type ) == Union :
1691- field_type = get_args (field_type )[0 ]
1690+ origin_type = get_origin (field_type )
1691+ if origin_type is None :
1692+ origin_type = field_type
16921693
1693- if isinstance ( field_type , type ) and issubclass (field_type , BaseModel ):
1694- description += f"Nested Model:\n { self ._generate_detailed_schema (field_type , depth + 1 )} "
1695- elif get_origin ( field_type ) == List :
1694+ if issubclass (origin_type , BaseModel ):
1695+ description += f"Nested Model:\n { self ._generate_detailed_schema (origin_type , depth + 1 )} "
1696+ elif origin_type == List :
16961697 list_type = get_args (field_type )[0 ]
16971698 if isinstance (list_type , type ) and issubclass (list_type , BaseModel ):
16981699 description += f"List of Nested Model:\n { self ._generate_detailed_schema (list_type , depth + 1 )} "
1700+ elif get_origin (list_type ) == Union :
1701+ union_types = get_args (list_type )
1702+ description += f"List of Union:\n "
1703+ for union_type in union_types :
1704+ if issubclass (union_type , BaseModel ):
1705+ description += f"{ indent } - Nested Model:\n { self ._generate_detailed_schema (union_type , depth + 2 )} "
1706+ else :
1707+ description += f"{ indent } - { union_type .__name__ } \n "
16991708 else :
1700- description += f"List[{ list_type . __name__ } ]"
1701- elif get_origin ( field_type ) == Dict :
1709+ description += f"List[{ self . _get_type_name ( list_type ) } ]"
1710+ elif origin_type == Dict :
17021711 key_type , value_type = get_args (field_type )
1703- description += f"Dict[{ key_type .__name__ } , { value_type .__name__ } ]"
1704- elif isinstance (field_type , type ) and issubclass (field_type , Enum ):
1705- enum_values = ", " .join ([f"{ e .name } = { e .value } " for e in field_type ])
1706- description += f"{ field_type .__name__ } (Enum values: { enum_values } )"
1712+ description += f"Dict[{ self ._get_type_name (key_type )} , { self ._get_type_name (value_type )} ]"
1713+ elif origin_type == Union :
1714+ union_types = get_args (field_type )
1715+ description += "Union of:\n "
1716+ for union_type in union_types :
1717+ if issubclass (union_type , BaseModel ):
1718+ description += f"{ indent } - Nested Model:\n { self ._generate_detailed_schema (union_type , depth + 2 )} "
1719+ else :
1720+ description += (
1721+ f"{ indent } - { self ._get_type_name (union_type )} \n "
1722+ )
1723+ elif issubclass (origin_type , Enum ):
1724+ enum_values = ", " .join ([f"{ e .name } = { e .value } " for e in origin_type ])
1725+ description += f"{ origin_type .__name__ } (Enum values: { enum_values } )"
17071726 else :
1708- description += f" { field_type . __name__ } "
1727+ description += self . _get_type_name ( origin_type )
17091728
17101729 field_descriptions .append (description )
1730+
17111731 return "\n " .join (field_descriptions )
17121732
1733+ def _get_type_name (self , type_ ):
1734+ """Helper method to get the name of a type, handling some special cases."""
1735+ if hasattr (type_ , "__name__" ):
1736+ return type_ .__name__
1737+ return str (type_ ).replace ("typing." , "" )
1738+
17131739 def convert_to_model (
17141740 self ,
17151741 input_string : str ,
@@ -1732,12 +1758,12 @@ def convert_to_model(
17321758 """
17331759 input_string = str (input_string )
17341760 schema = self ._generate_detailed_schema (model )
1735-
1761+
17361762 if "user_input" in kwargs :
17371763 del kwargs ["user_input" ]
17381764 if "schema" in kwargs :
17391765 del kwargs ["schema" ]
1740-
1766+
17411767 response = self .prompt_agent (
17421768 agent_name = agent_name ,
17431769 prompt_name = "Convert to Model" ,
@@ -1747,12 +1773,12 @@ def convert_to_model(
17471773 ** kwargs ,
17481774 },
17491775 )
1750-
1776+
17511777 if "```json" in response :
17521778 response = response .split ("```json" )[1 ].split ("```" )[0 ].strip ()
17531779 elif "```" in response :
17541780 response = response .split ("```" )[1 ].strip ()
1755-
1781+
17561782 try :
17571783 response = json .loads (response )
17581784 if response_type == "json" :
@@ -1766,11 +1792,7 @@ def convert_to_model(
17661792 f"Error: { e } . Failed to convert the response to the model after { max_failures } attempts. Response: { response } "
17671793 )
17681794 self .failures = 0
1769- return (
1770- response
1771- if response
1772- else "Failed to convert the response to the model."
1773- )
1795+ return response if response else "Failed to convert the response to the model."
17741796 else :
17751797 self .failures = 1
17761798 print (
0 commit comments