Skip to content

Commit d05aa9b

Browse files
committed
test: adds tests to confirm attribute extraction from classes
1 parent 5be86dd commit d05aa9b

File tree

1 file changed

+182
-0
lines changed

1 file changed

+182
-0
lines changed

scripts/microgenerator/tests/unit/test_generate_analyzer.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,185 @@ def test_import_extraction(self, code_snippet, expected_imports):
8787
expected = sorted(expected_imports)
8888

8989
assert extracted == expected
90+
91+
92+
# --- Tests CodeAnalyzer handling of Attributes ---
93+
94+
95+
class TestCodeAnalyzerAttributes:
96+
def assert_structures_equal(self, extracted, expected):
97+
assert len(extracted) == len(expected)
98+
for i in range(len(extracted)):
99+
ext_class = extracted[i]
100+
exp_class = expected[i]
101+
assert ext_class["class_name"] == exp_class["class_name"]
102+
assert ext_class["methods"] == exp_class["methods"]
103+
# Sort attributes by name for order-independent comparison │
104+
assert sorted(ext_class["attributes"], key=lambda x: x["name"]) == sorted(
105+
exp_class["attributes"], key=lambda x: x["name"]
106+
)
107+
108+
@pytest.mark.parametrize(
109+
"code_snippet, expected_structure",
110+
[
111+
pytest.param(
112+
"""
113+
class MyClass:
114+
CLASS_VAR = 123
115+
""",
116+
[
117+
{
118+
"class_name": "MyClass",
119+
"methods": [],
120+
"attributes": [{"name": "CLASS_VAR", "type": None}],
121+
}
122+
],
123+
id="class_var_assign",
124+
),
125+
pytest.param(
126+
"""
127+
class MyClass:
128+
class_var: int = 456
129+
""",
130+
[
131+
{
132+
"class_name": "MyClass",
133+
"methods": [],
134+
"attributes": [{"name": "class_var", "type": "int"}],
135+
}
136+
],
137+
id="class_var_annassign",
138+
),
139+
pytest.param(
140+
"""
141+
class MyClass:
142+
class_var: int
143+
""",
144+
[
145+
{
146+
"class_name": "MyClass",
147+
"methods": [],
148+
"attributes": [{"name": "class_var", "type": "int"}],
149+
}
150+
],
151+
id="class_var_annassign_no_value",
152+
),
153+
pytest.param(
154+
"""
155+
class MyClass:
156+
def __init__(self):
157+
self.instance_var = 789
158+
""",
159+
[
160+
{
161+
"class_name": "MyClass",
162+
"methods": [
163+
{
164+
"method_name": "__init__",
165+
"args": [{"name": "self", "type": None}],
166+
"return_type": None,
167+
}
168+
],
169+
"attributes": [{"name": "instance_var", "type": None}],
170+
}
171+
],
172+
id="instance_var_assign",
173+
),
174+
pytest.param(
175+
"""
176+
class MyClass:
177+
def __init__(self):
178+
self.instance_var: str = 'hello'
179+
""",
180+
[
181+
{
182+
"class_name": "MyClass",
183+
"methods": [
184+
{
185+
"method_name": "__init__",
186+
"args": [{"name": "self", "type": None}],
187+
"return_type": None,
188+
}
189+
],
190+
"attributes": [{"name": "instance_var", "type": "str"}],
191+
}
192+
],
193+
id="instance_var_annassign",
194+
),
195+
pytest.param(
196+
"""
197+
class MyClass:
198+
def __init__(self):
199+
self.instance_var: str
200+
""",
201+
[
202+
{
203+
"class_name": "MyClass",
204+
"methods": [
205+
{
206+
"method_name": "__init__",
207+
"args": [{"name": "self", "type": None}],
208+
"return_type": None,
209+
}
210+
],
211+
"attributes": [{"name": "instance_var", "type": "str"}],
212+
}
213+
],
214+
id="instance_var_annassign_no_value",
215+
),
216+
pytest.param(
217+
"""
218+
class MyClass:
219+
VAR_A = 1
220+
var_b: int = 2
221+
def __init__(self):
222+
self.var_c = 3
223+
self.var_d: float = 4.0
224+
""",
225+
[
226+
{
227+
"class_name": "MyClass",
228+
"methods": [
229+
{
230+
"method_name": "__init__",
231+
"args": [{"name": "self", "type": None}],
232+
"return_type": None,
233+
}
234+
],
235+
"attributes": [
236+
{"name": "VAR_A", "type": None},
237+
{"name": "var_b", "type": "int"},
238+
{"name": "var_c", "type": None},
239+
{"name": "var_d", "type": "float"},
240+
],
241+
}
242+
],
243+
id="mixed_attributes",
244+
),
245+
pytest.param(
246+
"a = 123 # Module level",
247+
[],
248+
id="module_level_assign",
249+
),
250+
pytest.param(
251+
"b: int = 456 # Module level",
252+
[],
253+
id="module_level_annassign",
254+
),
255+
],
256+
)
257+
def test_attribute_extraction(self, code_snippet, expected_structure):
258+
analyzer = CodeAnalyzer()
259+
tree = ast.parse(code_snippet)
260+
analyzer.visit(tree)
261+
262+
extracted = analyzer.structure
263+
# Normalize attributes for order-independent comparison
264+
for item in extracted:
265+
if "attributes" in item:
266+
item["attributes"].sort(key=lambda x: x["name"])
267+
for item in expected_structure:
268+
if "attributes" in item:
269+
item["attributes"].sort(key=lambda x: x["name"])
270+
271+
assert extracted == expected_structure

0 commit comments

Comments
 (0)