Skip to content

Commit 4b41cd6

Browse files
resolve generic type schema properties in BaseTool
Add __init_subclass__ hook to capture generic type parameters during class definition, fixing issue where input_schema and output_schema properties returned BaseIOSchema instead of the specified custom schema types. - Implement type parameter capture for inheritance pattern - Maintain fallback for dynamic instantiation and edge cases - Add focused test coverage for schema resolution - Resolves GitHub issue #161
1 parent 0b9b564 commit 4b41cd6

File tree

2 files changed

+61
-7
lines changed

2 files changed

+61
-7
lines changed

atomic-agents/atomic_agents/base/base_tool.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Type, get_args
1+
from typing import Optional, Type, get_args, get_origin
22
from abc import ABC, abstractmethod
33
from pydantic import BaseModel
44

@@ -47,6 +47,23 @@ def __init__(self, config: BaseToolConfig = BaseToolConfig()):
4747
"""
4848
self.config = config
4949

50+
def __init_subclass__(cls, **kwargs):
51+
"""
52+
Hook called when a class is subclassed.
53+
54+
Captures generic type parameters during class creation and stores them as class attributes
55+
to work around the unreliable __orig_class__ attribute in modern Python generic syntax.
56+
"""
57+
super().__init_subclass__(**kwargs)
58+
if hasattr(cls, "__orig_bases__"):
59+
for base in cls.__orig_bases__:
60+
if get_origin(base) is BaseTool:
61+
args = get_args(base)
62+
if len(args) == 2:
63+
cls._input_schema_cls = args[0]
64+
cls._output_schema_cls = args[1]
65+
break
66+
5067
@property
5168
def input_schema(self) -> Type[InputSchema]:
5269
"""
@@ -55,12 +72,17 @@ def input_schema(self) -> Type[InputSchema]:
5572
Returns:
5673
Type[InputSchema]: The input schema class.
5774
"""
75+
# Inheritance pattern: MyTool(BaseTool[Schema1, Schema2])
76+
if hasattr(self.__class__, "_input_schema_cls"):
77+
return self.__class__._input_schema_cls
78+
79+
# Dynamic instantiation: MockTool[Schema1, Schema2]()
5880
if hasattr(self, "__orig_class__"):
5981
TI, _ = get_args(self.__orig_class__)
60-
else:
61-
TI = BaseIOSchema
82+
return TI
6283

63-
return TI
84+
# No type info available: MockTool()
85+
return BaseIOSchema
6486

6587
@property
6688
def output_schema(self) -> Type[OutputSchema]:
@@ -70,12 +92,17 @@ def output_schema(self) -> Type[OutputSchema]:
7092
Returns:
7193
Type[OutputSchema]: The output schema class.
7294
"""
95+
# Inheritance pattern: MyTool(BaseTool[Schema1, Schema2])
96+
if hasattr(self.__class__, "_output_schema_cls"):
97+
return self.__class__._output_schema_cls
98+
99+
# Dynamic instantiation: MockTool[Schema1, Schema2]()
73100
if hasattr(self, "__orig_class__"):
74101
_, TO = get_args(self.__orig_class__)
75-
else:
76-
TO = BaseIOSchema
102+
return TO
77103

78-
return TO
104+
# No type info available: MockTool()
105+
return BaseIOSchema
79106

80107
@property
81108
def tool_name(self) -> str:

atomic-agents/tests/base/test_base_tool.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,30 @@ def test_base_tool_config_optional_fields():
9494
config = BaseToolConfig()
9595
assert hasattr(config, "title")
9696
assert hasattr(config, "description")
97+
98+
99+
# Test for GitHub issue #161 fix: proper schema resolution
100+
def test_base_tool_schema_resolution():
101+
"""Test that input_schema and output_schema return correct types (not BaseIOSchema)"""
102+
103+
class CustomInput(BaseIOSchema):
104+
"""Custom input schema for testing"""
105+
106+
name: str
107+
108+
class CustomOutput(BaseIOSchema):
109+
"""Custom output schema for testing"""
110+
111+
result: str
112+
113+
class TestTool(BaseTool[CustomInput, CustomOutput]):
114+
def run(self, params: CustomInput) -> CustomOutput:
115+
return CustomOutput(result=f"processed_{params.name}")
116+
117+
tool = TestTool()
118+
119+
# These should return the specific types, not BaseIOSchema
120+
assert tool.input_schema == CustomInput
121+
assert tool.output_schema == CustomOutput
122+
assert tool.input_schema != BaseIOSchema
123+
assert tool.output_schema != BaseIOSchema

0 commit comments

Comments
 (0)