|
15 | 15 | # |
16 | 16 |
|
17 | 17 | import ast |
18 | | -from scripts.microgenerator.generate import CodeAnalyzer |
| 18 | +import pytest |
| 19 | +from scripts.microgenerator.generate import parse_code, CodeAnalyzer |
19 | 20 |
|
20 | 21 |
|
| 22 | +# --- Mock Types --- |
| 23 | +class MyClass: |
| 24 | + pass |
| 25 | + |
| 26 | + |
| 27 | +class AnotherClass: |
| 28 | + pass |
| 29 | + |
| 30 | + |
| 31 | +class YetAnotherClass: |
| 32 | + pass |
| 33 | + |
| 34 | + |
| 35 | +# --- Existing Tests --- |
21 | 36 | def test_codeanalyzer_finds_class(): |
22 | 37 | code = """ |
23 | 38 | class MyClass: |
@@ -100,3 +115,106 @@ class MyClass: |
100 | 115 | assert len(analyzer.structure) == 1 |
101 | 116 | assert analyzer.structure[0]["class_name"] == "MyClass" |
102 | 117 | assert len(analyzer.structure[0]["methods"]) == 0 |
| 118 | + |
| 119 | + |
| 120 | +# --- Test Data for Parameterization --- |
| 121 | +TYPE_TEST_CASES = [ |
| 122 | + pytest.param( |
| 123 | + """class TestClass: |
| 124 | + def func(self, a: int, b: str) -> bool: return True""", |
| 125 | + [("a", "int"), ("b", "str")], |
| 126 | + "bool", |
| 127 | + id="simple_types", |
| 128 | + ), |
| 129 | + pytest.param( |
| 130 | + """from typing import Optional |
| 131 | +class TestClass: |
| 132 | + def func(self, a: Optional[int]) -> str | None: return 'hello'""", |
| 133 | + [("a", "Optional[int]")], |
| 134 | + "str | None", |
| 135 | + id="optional_union_none", |
| 136 | + ), |
| 137 | + pytest.param( |
| 138 | + """from typing import Union |
| 139 | +class TestClass: |
| 140 | + def func(self, a: int | float, b: Union[str, bytes]) -> None: pass""", |
| 141 | + [("a", "int | float"), ("b", "Union[str, bytes]")], |
| 142 | + "None", |
| 143 | + id="union_types", |
| 144 | + ), |
| 145 | + pytest.param( |
| 146 | + """from typing import List, Dict, Tuple |
| 147 | +class TestClass: |
| 148 | + def func(self, a: List[int], b: Dict[str, float]) -> Tuple[int, str]: return (1, 'a')""", |
| 149 | + [("a", "List[int]"), ("b", "Dict[str, float]")], |
| 150 | + "Tuple[int, str]", |
| 151 | + id="generic_types", |
| 152 | + ), |
| 153 | + pytest.param( |
| 154 | + """import datetime |
| 155 | +from scripts.microgenerator.tests.unit.test_generate_analyzer import MyClass |
| 156 | +class TestClass: |
| 157 | + def func(self, a: datetime.date, b: MyClass) -> MyClass: return b""", |
| 158 | + [("a", "datetime.date"), ("b", "MyClass")], |
| 159 | + "MyClass", |
| 160 | + id="imported_types", |
| 161 | + ), |
| 162 | + pytest.param( |
| 163 | + """from scripts.microgenerator.tests.unit.test_generate_analyzer import AnotherClass, YetAnotherClass |
| 164 | +class TestClass: |
| 165 | + def func(self, a: 'AnotherClass') -> 'YetAnotherClass': return AnotherClass()""", |
| 166 | + [("a", "'AnotherClass'")], |
| 167 | + "'YetAnotherClass'", |
| 168 | + id="forward_refs", |
| 169 | + ), |
| 170 | + pytest.param( |
| 171 | + """class TestClass: |
| 172 | + def func(self, a, b): return a + b""", |
| 173 | + [("a", None), ("b", None)], # No annotations means type is None |
| 174 | + None, |
| 175 | + id="no_annotations", |
| 176 | + ), |
| 177 | + pytest.param( |
| 178 | + """from typing import List, Optional, Dict, Union, Any |
| 179 | +class TestClass: |
| 180 | + def func(self, a: List[Optional[Dict[str, Union[int, str]]]]) -> Dict[str, Any]: return {}""", |
| 181 | + [("a", "List[Optional[Dict[str, Union[int, str]]]]")], |
| 182 | + "Dict[str, Any]", |
| 183 | + id="complex_nested", |
| 184 | + ), |
| 185 | + pytest.param( |
| 186 | + """from typing import Literal |
| 187 | +class TestClass: |
| 188 | + def func(self, a: Literal['one', 'two']) -> Literal[True]: return True""", |
| 189 | + [("a", "Literal['one', 'two']")], |
| 190 | + "Literal[True]", |
| 191 | + id="literal_type", |
| 192 | + ), |
| 193 | +] |
| 194 | + |
| 195 | + |
| 196 | +# --- Tests --- |
| 197 | +class TestCodeAnalyzerArgsReturns: |
| 198 | + @pytest.mark.parametrize( |
| 199 | + "code_snippet, expected_args, expected_return", TYPE_TEST_CASES |
| 200 | + ) |
| 201 | + def test_type_extraction(self, code_snippet, expected_args, expected_return): |
| 202 | + structure, imports, types = parse_code(code_snippet) |
| 203 | + |
| 204 | + assert len(structure) == 1, "Should parse one class" |
| 205 | + class_info = structure[0] |
| 206 | + assert class_info["class_name"] == "TestClass" |
| 207 | + |
| 208 | + assert len(class_info["methods"]) == 1, "Should find one method" |
| 209 | + method_info = class_info["methods"][0] |
| 210 | + assert method_info["method_name"] == "func" |
| 211 | + |
| 212 | + # Extract args, skipping 'self' |
| 213 | + extracted_args = [] |
| 214 | + for arg in method_info.get("args", []): |
| 215 | + if arg["name"] == "self": |
| 216 | + continue |
| 217 | + extracted_args.append((arg["name"], arg["type"])) |
| 218 | + |
| 219 | + assert extracted_args == expected_args |
| 220 | + assert method_info.get("return_type") == expected_return |
0 commit comments