Skip to content

Commit 4d661c4

Browse files
authored
Merge branch 'autogen' into test/test__should_include_class_or_method
2 parents ca8c832 + 1466881 commit 4d661c4

File tree

6 files changed

+235
-65
lines changed

6 files changed

+235
-65
lines changed

scripts/microgenerator/generate.py

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import argparse
2828
import glob
2929
import logging
30+
import re
3031
from collections import defaultdict
3132
from pathlib import Path
3233
from typing import List, Dict, Any
@@ -46,7 +47,7 @@ class CodeAnalyzer(ast.NodeVisitor):
4647
"""
4748

4849
def __init__(self):
49-
self.structure: List[Dict[str, Any]] = []
50+
self.analyzed_classes: List[Dict[str, Any]] = []
5051
self.imports: set[str] = set()
5152
self.types: set[str] = set()
5253
self._current_class_info: Dict[str, Any] | None = None
@@ -105,13 +106,19 @@ def _collect_types_from_node(self, node: ast.AST | None) -> None:
105106
if type_str:
106107
self.types.add(type_str)
107108
elif isinstance(node, ast.Subscript):
108-
self._collect_types_from_node(node.value)
109+
# Add the base type of the subscript (e.g., "List", "Dict")
110+
if isinstance(node.value, ast.Name):
111+
self.types.add(node.value.id)
112+
self._collect_types_from_node(node.value) # Recurse on value just in case
109113
self._collect_types_from_node(node.slice)
110114
elif isinstance(node, (ast.Tuple, ast.List)):
111115
for elt in node.elts:
112116
self._collect_types_from_node(elt)
113-
elif isinstance(node, ast.Constant) and isinstance(node.value, str):
114-
self.types.add(node.value)
117+
elif isinstance(node, ast.Constant):
118+
if isinstance(node.value, str): # Forward references
119+
self.types.add(node.value)
120+
elif node.value is None: # None type
121+
self.types.add("None")
115122
elif isinstance(node, ast.BinOp) and isinstance(
116123
node.op, ast.BitOr
117124
): # For | union type
@@ -163,7 +170,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
163170
type_str = self._get_type_str(item.annotation)
164171
class_info["attributes"].append({"name": attr_name, "type": type_str})
165172

166-
self.structure.append(class_info)
173+
self.analyzed_classes.append(class_info)
167174
self._current_class_info = class_info
168175
self._depth += 1
169176
self.generic_visit(node)
@@ -259,6 +266,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
259266
# directly within the class body, not inside a method.
260267
elif isinstance(target, ast.Name) and not self._is_in_method:
261268
self._add_attribute(target.id, self._get_type_str(node.annotation))
269+
self._collect_types_from_node(node.annotation)
262270
self.generic_visit(node)
263271

264272

@@ -279,7 +287,7 @@ def parse_code(code: str) -> tuple[List[Dict[str, Any]], set[str], set[str]]:
279287
tree = ast.parse(code)
280288
analyzer = CodeAnalyzer()
281289
analyzer.visit(tree)
282-
return analyzer.structure, analyzer.imports, analyzer.types
290+
return analyzer.analyzed_classes, analyzer.imports, analyzer.types
283291

284292

285293
def parse_file(file_path: str) -> tuple[List[Dict[str, Any]], set[str], set[str]]:
@@ -331,10 +339,10 @@ def list_code_objects(
331339
all_class_keys = []
332340

333341
def process_structure(
334-
structure: List[Dict[str, Any]], file_name: str | None = None
342+
analyzed_classes: List[Dict[str, Any]], file_name: str | None = None
335343
):
336344
"""Populates the results dictionary from the parsed AST structure."""
337-
for class_info in structure:
345+
for class_info in analyzed_classes:
338346
key = class_info["class_name"]
339347
if file_name:
340348
key = f"{key} (in {file_name})"
@@ -360,13 +368,13 @@ def process_structure(
360368

361369
# Determine if the path is a file or directory and process accordingly
362370
if os.path.isfile(path) and path.endswith(".py"):
363-
structure, _, _ = parse_file(path)
364-
process_structure(structure)
371+
analyzed_classes, _, _ = parse_file(path)
372+
process_structure(analyzed_classes)
365373
elif os.path.isdir(path):
366374
# This assumes `utils.walk_codebase` is defined elsewhere.
367375
for file_path in utils.walk_codebase(path):
368-
structure, _, _ = parse_file(file_path)
369-
process_structure(structure, file_name=os.path.basename(file_path))
376+
analyzed_classes, _, _ = parse_file(file_path)
377+
process_structure(analyzed_classes, file_name=os.path.basename(file_path))
370378

371379
# Return the data in the desired format based on the flags
372380
if not show_methods and not show_attributes:
@@ -466,11 +474,11 @@ def _build_request_arg_schema(
466474
module_name = os.path.splitext(relative_path)[0].replace(os.path.sep, ".")
467475

468476
try:
469-
structure, _, _ = parse_file(file_path)
470-
if not structure:
477+
analyzed_classes, _, _ = parse_file(file_path)
478+
if not analyzed_classes:
471479
continue
472480

473-
for class_info in structure:
481+
for class_info in analyzed_classes:
474482
class_name = class_info.get("class_name", "Unknown")
475483
if class_name.endswith("Request"):
476484
full_class_name = f"{module_name}.{class_name}"
@@ -498,11 +506,11 @@ def _process_service_clients(
498506
if "/services/" not in file_path:
499507
continue
500508

501-
structure, imports, types = parse_file(file_path)
509+
analyzed_classes, imports, types = parse_file(file_path)
502510
all_imports.update(imports)
503511
all_types.update(types)
504512

505-
for class_info in structure:
513+
for class_info in analyzed_classes:
506514
class_name = class_info["class_name"]
507515
if not _should_include_class(class_name, class_filters):
508516
continue
@@ -546,7 +554,6 @@ def analyze_source_files(
546554
# Make the pattern absolute
547555
absolute_pattern = os.path.join(project_root, pattern)
548556
source_files.extend(glob.glob(absolute_pattern, recursive=True))
549-
550557
# PASS 1: Build the request argument schema from the types files.
551558
request_arg_schema = _build_request_arg_schema(source_files, project_root)
552559

@@ -607,14 +614,14 @@ def generate_code(config: Dict[str, Any], analysis_results: tuple) -> None:
607614
Generates source code files using Jinja2 templates.
608615
"""
609616
data, all_imports, all_types, request_arg_schema = analysis_results
610-
project_root = config["project_root"]
611-
config_dir = config["config_dir"]
612617

613618
templates_config = config.get("templates", [])
614619
for item in templates_config:
615-
template_path = str(Path(config_dir) / item["template"])
616-
output_path = str(Path(project_root) / item["output"])
620+
template_name = item["template"]
621+
output_name = item["output"]
617622

623+
template_path = str(Path(config["config_dir"]) / template_name)
624+
output_path = str(Path(config["project_root"]) / output_name)
618625
template = utils.load_template(template_path)
619626
methods_context = []
620627
for class_name, methods in data.items():
@@ -710,18 +717,20 @@ def find_project_root(start_path: str, markers: list[str]) -> str | None:
710717
return None
711718
current_path = parent_path
712719

713-
# Load configuration from the YAML file.
714-
config = utils.load_config(config_path)
720+
# Get the absolute path of the config file
721+
abs_config_path = os.path.abspath(config_path)
722+
config = utils.load_config(abs_config_path)
715723

716-
# Determine the project root.
717-
script_dir = os.path.dirname(os.path.abspath(__file__))
718-
project_root = find_project_root(script_dir, ["setup.py", ".git"])
724+
# Determine the project root
725+
# Start searching from the directory of this script file
726+
script_file_dir = os.path.dirname(os.path.abspath(__file__))
727+
project_root = find_project_root(script_file_dir, [".git"])
719728
if not project_root:
720-
project_root = os.getcwd() # Fallback to current directory
729+
# Fallback to the directory from which the script was invoked
730+
project_root = os.getcwd()
721731

722-
# Set paths in the config dictionary.
723732
config["project_root"] = project_root
724-
config["config_dir"] = os.path.dirname(os.path.abspath(config_path))
733+
config["config_dir"] = os.path.dirname(abs_config_path)
725734

726735
return config
727736

@@ -762,9 +771,12 @@ def _execute_post_processing(config: Dict[str, Any]):
762771
all_end_index = i
763772

764773
if all_start_index != -1 and all_end_index != -1:
765-
for i in range(all_start_index + 1, all_end_index):
766-
member = lines[i].strip().replace('"', "").replace(",", "")
767-
if member:
774+
all_content = "".join(lines[all_start_index + 1 : all_end_index])
775+
776+
# Extract quoted strings
777+
found_members = re.findall(r'"([^"]+)"', all_content)
778+
for member in found_members:
779+
if member not in all_list:
768780
all_list.append(member)
769781

770782
# --- Add new items and sort ---
@@ -777,7 +789,9 @@ def _execute_post_processing(config: Dict[str, Any]):
777789
for new_member in job.get("add_to_all", []):
778790
if new_member not in all_list:
779791
all_list.append(new_member)
780-
all_list.sort()
792+
all_list = sorted(list(set(all_list))) # Ensure unique and sorted
793+
# Format for the template
794+
all_list = [f' "{item}",\n' for item in all_list]
781795

782796
# --- Render the new file content ---
783797
template = utils.load_template(template_path)

scripts/microgenerator/noxfile.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from functools import wraps
1818
import pathlib
19-
import os
2019
import nox
2120
import time
2221

@@ -26,7 +25,7 @@
2625
BLACK_VERSION = "black==23.7.0"
2726
BLACK_PATHS = (".",)
2827

29-
DEFAULT_PYTHON_VERSION = "3.9"
28+
DEFAULT_PYTHON_VERSION = "3.13"
3029
UNIT_TEST_PYTHON_VERSIONS = ["3.9", "3.11", "3.12", "3.13"]
3130
CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()
3231

@@ -190,9 +189,8 @@ def lint(session):
190189
session.install("flake8", BLACK_VERSION)
191190
session.install("-e", ".")
192191
session.run("python", "-m", "pip", "freeze")
193-
session.run("flake8", os.path.join("scripts"))
192+
session.run("flake8", ".")
194193
session.run("flake8", "tests")
195-
session.run("flake8", "benchmark")
196194
session.run("black", "--check", *BLACK_PATHS)
197195

198196

scripts/microgenerator/templates/post-processing/init.py.j2

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ __version__ = package_version.__version__
2323
{%- endfor %}
2424

2525
__all__ = (
26-
{%- for item in all_list %}
27-
"{{ item }}",
26+
{% for item in all_list %}
27+
{{ item }}
2828
{%- endfor %}
2929
)

scripts/microgenerator/tests/unit/test_generate_analyzer.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_import_extraction(self, code_snippet, expected_imports):
9595

9696
class TestCodeAnalyzerAttributes:
9797
@pytest.mark.parametrize(
98-
"code_snippet, expected_structure",
98+
"code_snippet, expected_analyzed_classes",
9999
[
100100
pytest.param(
101101
"""
@@ -243,22 +243,24 @@ def __init__(self):
243243
),
244244
],
245245
)
246-
def test_attribute_extraction(self, code_snippet: str, expected_structure: list):
246+
def test_attribute_extraction(
247+
self, code_snippet: str, expected_analyzed_classes: list
248+
):
247249
"""Tests the extraction of class and instance attributes."""
248250
analyzer = CodeAnalyzer()
249251
tree = ast.parse(code_snippet)
250252
analyzer.visit(tree)
251253

252-
extracted = analyzer.structure
254+
extracted = analyzer.analyzed_classes
253255
# Normalize attributes for order-independent comparison
254256
for item in extracted:
255257
if "attributes" in item:
256258
item["attributes"].sort(key=lambda x: x["name"])
257-
for item in expected_structure:
259+
for item in expected_analyzed_classes:
258260
if "attributes" in item:
259261
item["attributes"].sort(key=lambda x: x["name"])
260262

261-
assert extracted == expected_structure
263+
assert extracted == expected_analyzed_classes
262264

263265

264266
# --- Mock Types ---
@@ -284,8 +286,8 @@ class MyClass:
284286
analyzer = CodeAnalyzer()
285287
tree = ast.parse(code)
286288
analyzer.visit(tree)
287-
assert len(analyzer.structure) == 1
288-
assert analyzer.structure[0]["class_name"] == "MyClass"
289+
assert len(analyzer.analyzed_classes) == 1
290+
assert analyzer.analyzed_classes[0]["class_name"] == "MyClass"
289291

290292

291293
def test_codeanalyzer_finds_multiple_classes():
@@ -302,8 +304,8 @@ class ClassB:
302304
analyzer = CodeAnalyzer()
303305
tree = ast.parse(code)
304306
analyzer.visit(tree)
305-
assert len(analyzer.structure) == 2
306-
class_names = sorted([c["class_name"] for c in analyzer.structure])
307+
assert len(analyzer.analyzed_classes) == 2
308+
class_names = sorted([c["class_name"] for c in analyzer.analyzed_classes])
307309
assert class_names == ["ClassA", "ClassB"]
308310

309311

@@ -318,9 +320,9 @@ def my_method(self):
318320
analyzer = CodeAnalyzer()
319321
tree = ast.parse(code)
320322
analyzer.visit(tree)
321-
assert len(analyzer.structure) == 1
322-
assert len(analyzer.structure[0]["methods"]) == 1
323-
assert analyzer.structure[0]["methods"][0]["method_name"] == "my_method"
323+
assert len(analyzer.analyzed_classes) == 1
324+
assert len(analyzer.analyzed_classes[0]["methods"]) == 1
325+
assert analyzer.analyzed_classes[0]["methods"][0]["method_name"] == "my_method"
324326

325327

326328
def test_codeanalyzer_finds_multiple_methods():
@@ -337,8 +339,8 @@ def method_b(self):
337339
analyzer = CodeAnalyzer()
338340
tree = ast.parse(code)
339341
analyzer.visit(tree)
340-
assert len(analyzer.structure) == 1
341-
method_names = sorted([m["method_name"] for m in analyzer.structure[0]["methods"]])
342+
assert len(analyzer.analyzed_classes) == 1
343+
method_names = sorted([m["method_name"] for m in analyzer.analyzed_classes[0]["methods"]])
342344
assert method_names == ["method_a", "method_b"]
343345

344346

@@ -352,7 +354,7 @@ def top_level_function():
352354
analyzer = CodeAnalyzer()
353355
tree = ast.parse(code)
354356
analyzer.visit(tree)
355-
assert len(analyzer.structure) == 0
357+
assert len(analyzer.analyzed_classes) == 0
356358

357359

358360
def test_codeanalyzer_class_with_no_methods():
@@ -365,9 +367,9 @@ class MyClass:
365367
analyzer = CodeAnalyzer()
366368
tree = ast.parse(code)
367369
analyzer.visit(tree)
368-
assert len(analyzer.structure) == 1
369-
assert analyzer.structure[0]["class_name"] == "MyClass"
370-
assert len(analyzer.structure[0]["methods"]) == 0
370+
assert len(analyzer.analyzed_classes) == 1
371+
assert analyzer.analyzed_classes[0]["class_name"] == "MyClass"
372+
assert len(analyzer.analyzed_classes[0]["methods"]) == 0
371373

372374

373375
# --- Test Data for Parameterization ---
@@ -487,10 +489,10 @@ class TestCodeAnalyzerArgsReturns:
487489
"code_snippet, expected_args, expected_return", TYPE_TEST_CASES
488490
)
489491
def test_type_extraction(self, code_snippet, expected_args, expected_return):
490-
structure, imports, types = parse_code(code_snippet)
492+
analyzed_classes, imports, types = parse_code(code_snippet)
491493

492-
assert len(structure) == 1, "Should parse one class"
493-
class_info = structure[0]
494+
assert len(analyzed_classes) == 1, "Should parse one class"
495+
class_info = analyzed_classes[0]
494496
assert class_info["class_name"] == "TestClass"
495497

496498
assert len(class_info["methods"]) == 1, "Should find one method"
@@ -506,3 +508,4 @@ def test_type_extraction(self, code_snippet, expected_args, expected_return):
506508

507509
assert extracted_args == expected_args
508510
assert method_info.get("return_type") == expected_return
511+

0 commit comments

Comments
 (0)