Skip to content

Commit a8bdb58

Browse files
revert changes to atomic agent
1 parent 86e8396 commit a8bdb58

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

atomic-agents/atomic_agents/agents/atomic_agent.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -97,19 +97,8 @@ class AtomicAgent[InputSchema: BaseIOSchema, OutputSchema: BaseIOSchema]:
9797
def __init_subclass__(cls, **kwargs):
9898
"""
9999
Hook called when a class is subclassed.
100-
101-
Captures generic type parameters during class creation and stores them as class attributes
102-
to work around the unreliable __orig_class__ attribute in modern Python generic syntax.
103100
"""
104101
super().__init_subclass__(**kwargs)
105-
if hasattr(cls, "__orig_bases__"):
106-
for base in cls.__orig_bases__:
107-
if hasattr(base, "__origin__") and base.__origin__ is AtomicAgent:
108-
args = get_args(base)
109-
if len(args) == 2:
110-
cls._input_schema_cls = args[0]
111-
cls._output_schema_cls = args[1]
112-
break
113102

114103
def __init__(self, config: AgentConfig):
115104
"""
@@ -135,11 +124,33 @@ def reset_history(self):
135124

136125
@property
137126
def input_schema(self) -> Type[BaseIOSchema]:
138-
return getattr(self.__class__, "_input_schema_cls", BasicChatInputSchema)
127+
"""
128+
Returns the input schema class for the agent.
129+
130+
Returns:
131+
Type[BaseIOSchema]: The input schema class.
132+
"""
133+
if hasattr(self, "__orig_class__"):
134+
from typing import get_args
135+
args = get_args(self.__orig_class__)
136+
if len(args) >= 1:
137+
return args[0]
138+
return BasicChatInputSchema
139139

140140
@property
141141
def output_schema(self) -> Type[BaseIOSchema]:
142-
return getattr(self.__class__, "_output_schema_cls", BasicChatOutputSchema)
142+
"""
143+
Returns the output schema class for the agent.
144+
145+
Returns:
146+
Type[BaseIOSchema]: The output schema class.
147+
"""
148+
if hasattr(self, "__orig_class__"):
149+
from typing import get_args
150+
args = get_args(self.__orig_class__)
151+
if len(args) >= 2:
152+
return args[1]
153+
return BasicChatOutputSchema
143154

144155
def _prepare_messages(self):
145156
if self.system_role is None:

0 commit comments

Comments
 (0)