Skip to content

Commit 86e8396

Browse files
resolve schema properties returning BaseIOSchema instead of custom types
Add __init_subclass__ hooks to capture generic type parameters at class creation time and simplify schema properties to use class-level attributes. This fixes issue #161 where BaseTool.input_schema and output_schema were incorrectly returning BaseIOSchema instead of the custom schema classes specified in generic type parameters like BaseTool[MyInputSchema, MyOutputSchema]. - Add __init_subclass__ to BaseTool and AtomicAgent classes - Replace broken __orig_class__ logic with reliable class-level attributes - Simplify schema properties to single-line getattr calls
1 parent 71ec731 commit 86e8396

File tree

2 files changed

+38
-24
lines changed

2 files changed

+38
-24
lines changed

atomic-agents/atomic_agents/agents/atomic_agent.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,23 @@ class AtomicAgent[InputSchema: BaseIOSchema, OutputSchema: BaseIOSchema]:
9494
- Use this for parameters like 'temperature', 'max_tokens', etc.
9595
"""
9696

97+
def __init_subclass__(cls, **kwargs):
98+
"""
99+
Hook called when a class is subclassed.
100+
101+
Captures generic type parameters during class creation and stores them as class attributes
102+
to work around the unreliable __orig_class__ attribute in modern Python generic syntax.
103+
"""
104+
super().__init_subclass__(**kwargs)
105+
if hasattr(cls, "__orig_bases__"):
106+
for base in cls.__orig_bases__:
107+
if hasattr(base, "__origin__") and base.__origin__ is AtomicAgent:
108+
args = get_args(base)
109+
if len(args) == 2:
110+
cls._input_schema_cls = args[0]
111+
cls._output_schema_cls = args[1]
112+
break
113+
97114
def __init__(self, config: AgentConfig):
98115
"""
99116
Initializes the AtomicAgent.
@@ -118,21 +135,11 @@ def reset_history(self):
118135

119136
@property
120137
def input_schema(self) -> Type[BaseIOSchema]:
121-
if hasattr(self, "__orig_class__"):
122-
TI, _ = get_args(self.__orig_class__)
123-
else:
124-
TI = BasicChatInputSchema
125-
126-
return TI
138+
return getattr(self.__class__, "_input_schema_cls", BasicChatInputSchema)
127139

128140
@property
129141
def output_schema(self) -> Type[BaseIOSchema]:
130-
if hasattr(self, "__orig_class__"):
131-
_, TO = get_args(self.__orig_class__)
132-
else:
133-
TO = BasicChatOutputSchema
134-
135-
return TO
142+
return getattr(self.__class__, "_output_schema_cls", BasicChatOutputSchema)
136143

137144
def _prepare_messages(self):
138145
if self.system_role is None:

atomic-agents/atomic_agents/base/base_tool.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,23 @@ class BaseTool[InputSchema: BaseIOSchema, OutputSchema: BaseIOSchema](ABC):
3838
tool_description (str): Description of the tool, derived from the input schema's description or overridden by the config.
3939
"""
4040

41+
def __init_subclass__(cls, **kwargs):
42+
"""
43+
Hook called when a class is subclassed.
44+
45+
Captures generic type parameters during class creation and stores them as class attributes
46+
to work around the unreliable __orig_class__ attribute in modern Python generic syntax.
47+
"""
48+
super().__init_subclass__(**kwargs)
49+
if hasattr(cls, "__orig_bases__"):
50+
for base in cls.__orig_bases__:
51+
if hasattr(base, "__origin__") and base.__origin__ is BaseTool:
52+
args = get_args(base)
53+
if len(args) == 2:
54+
cls._input_schema_cls = args[0]
55+
cls._output_schema_cls = args[1]
56+
break
57+
4158
def __init__(self, config: BaseToolConfig = BaseToolConfig()):
4259
"""
4360
Initializes the BaseTool with an optional configuration override.
@@ -55,12 +72,7 @@ def input_schema(self) -> Type[InputSchema]:
5572
Returns:
5673
Type[InputSchema]: The input schema class.
5774
"""
58-
if hasattr(self, "__orig_class__"):
59-
TI, _ = get_args(self.__orig_class__)
60-
else:
61-
TI = BaseIOSchema
62-
63-
return TI
75+
return getattr(self.__class__, "_input_schema_cls", BaseIOSchema)
6476

6577
@property
6678
def output_schema(self) -> Type[OutputSchema]:
@@ -70,12 +82,7 @@ def output_schema(self) -> Type[OutputSchema]:
7082
Returns:
7183
Type[OutputSchema]: The output schema class.
7284
"""
73-
if hasattr(self, "__orig_class__"):
74-
_, TO = get_args(self.__orig_class__)
75-
else:
76-
TO = BaseIOSchema
77-
78-
return TO
85+
return getattr(self.__class__, "_output_schema_cls", BaseIOSchema)
7986

8087
@property
8188
def tool_name(self) -> str:

0 commit comments

Comments
 (0)