Skip to content

Commit 7a7b331

Browse files
committed
feat: adds type extraction tests
1 parent 44eb401 commit 7a7b331

File tree

2 files changed

+125
-3
lines changed

2 files changed

+125
-3
lines changed

scripts/microgenerator/generate.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@
2727
import argparse
2828
import glob
2929
import logging
30-
import re
3130
from collections import defaultdict
32-
from typing import List, Dict, Any, Iterator
31+
from typing import List, Dict, Any
3332

3433
from . import name_utils
3534
from . import utils
@@ -83,6 +82,11 @@ def _get_type_str(self, node: ast.AST | None) -> str | None:
8382
# Handles forward references as strings, e.g., '"Dataset"'
8483
if isinstance(node, ast.Constant):
8584
return repr(node.value)
85+
# Handles | union types, e.g., int | float
86+
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
87+
left_str = self._get_type_str(node.left)
88+
right_str = self._get_type_str(node.right)
89+
return f"{left_str} | {right_str}"
8690
return None # Fallback for unhandled types
8791

8892
def _collect_types_from_node(self, node: ast.AST | None) -> None:

scripts/microgenerator/tests/unit/test_generate_analyzer.py

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,24 @@
1515
#
1616

1717
import ast
18-
from scripts.microgenerator.generate import CodeAnalyzer
18+
import pytest
19+
from scripts.microgenerator.generate import parse_code, CodeAnalyzer
1920

2021

22+
# --- Mock Types ---
23+
class MyClass:
24+
pass
25+
26+
27+
class AnotherClass:
28+
pass
29+
30+
31+
class YetAnotherClass:
32+
pass
33+
34+
35+
# --- Existing Tests ---
2136
def test_codeanalyzer_finds_class():
2237
code = """
2338
class MyClass:
@@ -100,3 +115,106 @@ class MyClass:
100115
assert len(analyzer.structure) == 1
101116
assert analyzer.structure[0]["class_name"] == "MyClass"
102117
assert len(analyzer.structure[0]["methods"]) == 0
118+
119+
120+
# --- Test Data for Parameterization ---
121+
TYPE_TEST_CASES = [
122+
pytest.param(
123+
"""class TestClass:
124+
def func(self, a: int, b: str) -> bool: return True""",
125+
[("a", "int"), ("b", "str")],
126+
"bool",
127+
id="simple_types",
128+
),
129+
pytest.param(
130+
"""from typing import Optional
131+
class TestClass:
132+
def func(self, a: Optional[int]) -> str | None: return 'hello'""",
133+
[("a", "Optional[int]")],
134+
"str | None",
135+
id="optional_union_none",
136+
),
137+
pytest.param(
138+
"""from typing import Union
139+
class TestClass:
140+
def func(self, a: int | float, b: Union[str, bytes]) -> None: pass""",
141+
[("a", "int | float"), ("b", "Union[str, bytes]")],
142+
"None",
143+
id="union_types",
144+
),
145+
pytest.param(
146+
"""from typing import List, Dict, Tuple
147+
class TestClass:
148+
def func(self, a: List[int], b: Dict[str, float]) -> Tuple[int, str]: return (1, 'a')""",
149+
[("a", "List[int]"), ("b", "Dict[str, float]")],
150+
"Tuple[int, str]",
151+
id="generic_types",
152+
),
153+
pytest.param(
154+
"""import datetime
155+
from scripts.microgenerator.tests.unit.test_generate_analyzer import MyClass
156+
class TestClass:
157+
def func(self, a: datetime.date, b: MyClass) -> MyClass: return b""",
158+
[("a", "datetime.date"), ("b", "MyClass")],
159+
"MyClass",
160+
id="imported_types",
161+
),
162+
pytest.param(
163+
"""from scripts.microgenerator.tests.unit.test_generate_analyzer import AnotherClass, YetAnotherClass
164+
class TestClass:
165+
def func(self, a: 'AnotherClass') -> 'YetAnotherClass': return AnotherClass()""",
166+
[("a", "'AnotherClass'")],
167+
"'YetAnotherClass'",
168+
id="forward_refs",
169+
),
170+
pytest.param(
171+
"""class TestClass:
172+
def func(self, a, b): return a + b""",
173+
[("a", None), ("b", None)], # No annotations means type is None
174+
None,
175+
id="no_annotations",
176+
),
177+
pytest.param(
178+
"""from typing import List, Optional, Dict, Union, Any
179+
class TestClass:
180+
def func(self, a: List[Optional[Dict[str, Union[int, str]]]]) -> Dict[str, Any]: return {}""",
181+
[("a", "List[Optional[Dict[str, Union[int, str]]]]")],
182+
"Dict[str, Any]",
183+
id="complex_nested",
184+
),
185+
pytest.param(
186+
"""from typing import Literal
187+
class TestClass:
188+
def func(self, a: Literal['one', 'two']) -> Literal[True]: return True""",
189+
[("a", "Literal['one', 'two']")],
190+
"Literal[True]",
191+
id="literal_type",
192+
),
193+
]
194+
195+
196+
# --- Tests ---
197+
class TestCodeAnalyzerArgsReturns:
198+
@pytest.mark.parametrize(
199+
"code_snippet, expected_args, expected_return", TYPE_TEST_CASES
200+
)
201+
def test_type_extraction(self, code_snippet, expected_args, expected_return):
202+
structure, imports, types = parse_code(code_snippet)
203+
204+
assert len(structure) == 1, "Should parse one class"
205+
class_info = structure[0]
206+
assert class_info["class_name"] == "TestClass"
207+
208+
assert len(class_info["methods"]) == 1, "Should find one method"
209+
method_info = class_info["methods"][0]
210+
assert method_info["method_name"] == "func"
211+
212+
# Extract args, skipping 'self'
213+
extracted_args = []
214+
for arg in method_info.get("args", []):
215+
if arg["name"] == "self":
216+
continue
217+
extracted_args.append((arg["name"], arg["type"]))
218+
219+
assert extracted_args == expected_args
220+
assert method_info.get("return_type") == expected_return

0 commit comments

Comments
 (0)