Skip to content

Commit 087e656

Browse files
committed
v0.0.64
1 parent 022989c commit 087e656

File tree

2 files changed

+45
-23
lines changed

2 files changed

+45
-23
lines changed

agixtsdk/__init__.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,7 +1678,7 @@ def plan_task(
16781678
def _generate_detailed_schema(self, model: Type[BaseModel], depth: int = 0) -> str:
16791679
"""
16801680
Recursively generates a detailed schema representation of a Pydantic model,
1681-
including nested models.
1681+
including nested models and complex types.
16821682
"""
16831683
fields = model.__annotations__
16841684
field_descriptions = []
@@ -1687,29 +1687,55 @@ def _generate_detailed_schema(self, model: Type[BaseModel], depth: int = 0) -> s
16871687
for field, field_type in fields.items():
16881688
description = f"{indent}{field}: "
16891689

1690-
if get_origin(field_type) == Union:
1691-
field_type = get_args(field_type)[0]
1690+
origin_type = get_origin(field_type)
1691+
if origin_type is None:
1692+
origin_type = field_type
16921693

1693-
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
1694-
description += f"Nested Model:\n{self._generate_detailed_schema(field_type, depth + 1)}"
1695-
elif get_origin(field_type) == List:
1694+
if issubclass(origin_type, BaseModel):
1695+
description += f"Nested Model:\n{self._generate_detailed_schema(origin_type, depth + 1)}"
1696+
elif origin_type == List:
16961697
list_type = get_args(field_type)[0]
16971698
if isinstance(list_type, type) and issubclass(list_type, BaseModel):
16981699
description += f"List of Nested Model:\n{self._generate_detailed_schema(list_type, depth + 1)}"
1700+
elif get_origin(list_type) == Union:
1701+
union_types = get_args(list_type)
1702+
description += f"List of Union:\n"
1703+
for union_type in union_types:
1704+
if issubclass(union_type, BaseModel):
1705+
description += f"{indent} - Nested Model:\n{self._generate_detailed_schema(union_type, depth + 2)}"
1706+
else:
1707+
description += f"{indent} - {union_type.__name__}\n"
16991708
else:
1700-
description += f"List[{list_type.__name__}]"
1701-
elif get_origin(field_type) == Dict:
1709+
description += f"List[{self._get_type_name(list_type)}]"
1710+
elif origin_type == Dict:
17021711
key_type, value_type = get_args(field_type)
1703-
description += f"Dict[{key_type.__name__}, {value_type.__name__}]"
1704-
elif isinstance(field_type, type) and issubclass(field_type, Enum):
1705-
enum_values = ", ".join([f"{e.name} = {e.value}" for e in field_type])
1706-
description += f"{field_type.__name__} (Enum values: {enum_values})"
1712+
description += f"Dict[{self._get_type_name(key_type)}, {self._get_type_name(value_type)}]"
1713+
elif origin_type == Union:
1714+
union_types = get_args(field_type)
1715+
description += "Union of:\n"
1716+
for union_type in union_types:
1717+
if issubclass(union_type, BaseModel):
1718+
description += f"{indent} - Nested Model:\n{self._generate_detailed_schema(union_type, depth + 2)}"
1719+
else:
1720+
description += (
1721+
f"{indent} - {self._get_type_name(union_type)}\n"
1722+
)
1723+
elif issubclass(origin_type, Enum):
1724+
enum_values = ", ".join([f"{e.name} = {e.value}" for e in origin_type])
1725+
description += f"{origin_type.__name__} (Enum values: {enum_values})"
17071726
else:
1708-
description += f"{field_type.__name__}"
1727+
description += self._get_type_name(origin_type)
17091728

17101729
field_descriptions.append(description)
1730+
17111731
return "\n".join(field_descriptions)
17121732

1733+
def _get_type_name(self, type_):
1734+
"""Helper method to get the name of a type, handling some special cases."""
1735+
if hasattr(type_, "__name__"):
1736+
return type_.__name__
1737+
return str(type_).replace("typing.", "")
1738+
17131739
def convert_to_model(
17141740
self,
17151741
input_string: str,
@@ -1732,12 +1758,12 @@ def convert_to_model(
17321758
"""
17331759
input_string = str(input_string)
17341760
schema = self._generate_detailed_schema(model)
1735-
1761+
17361762
if "user_input" in kwargs:
17371763
del kwargs["user_input"]
17381764
if "schema" in kwargs:
17391765
del kwargs["schema"]
1740-
1766+
17411767
response = self.prompt_agent(
17421768
agent_name=agent_name,
17431769
prompt_name="Convert to Model",
@@ -1747,12 +1773,12 @@ def convert_to_model(
17471773
**kwargs,
17481774
},
17491775
)
1750-
1776+
17511777
if "```json" in response:
17521778
response = response.split("```json")[1].split("```")[0].strip()
17531779
elif "```" in response:
17541780
response = response.split("```")[1].strip()
1755-
1781+
17561782
try:
17571783
response = json.loads(response)
17581784
if response_type == "json":
@@ -1766,11 +1792,7 @@ def convert_to_model(
17661792
f"Error: {e} . Failed to convert the response to the model after {max_failures} attempts. Response: {response}"
17671793
)
17681794
self.failures = 0
1769-
return (
1770-
response
1771-
if response
1772-
else "Failed to convert the response to the model."
1773-
)
1795+
return response if response else "Failed to convert the response to the model."
17741796
else:
17751797
self.failures = 1
17761798
print(

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
setup(
1010
name="agixtsdk",
11-
version="0.0.63",
11+
version="0.0.64",
1212
description="The AGiXT SDK for Python.",
1313
long_description=long_description,
1414
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)