Skip to content

Commit 595e59f

Browse files
committed
feat: Add AST analysis utilities
Introduces the CodeAnalyzer class and helper functions for parsing Python code using the ast module. This provides the foundation for understanding service client structures.
1 parent c457754 commit 595e59f

File tree

1 file changed

+354
-0
lines changed

1 file changed

+354
-0
lines changed

scripts/microgenerator/generate.py

Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2025 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
"""
18+
A dual-purpose module for Python code analysis and BigQuery client generation.
19+
20+
When run as a script, it generates the BigQueryClient source code.
21+
When imported, it provides utility functions for parsing and exploring
22+
any Python codebase using the `ast` module.
23+
"""
24+
25+
import ast
26+
import os
27+
from collections import defaultdict
28+
from typing import List, Dict, Any, Iterator
29+
30+
from . import utils
31+
32+
# =============================================================================
33+
# Section 1: Generic AST Analysis Utilities
34+
# =============================================================================
35+
36+
37+
class CodeAnalyzer(ast.NodeVisitor):
38+
"""
39+
A node visitor to traverse an AST and extract structured information
40+
about classes, methods, and their arguments.
41+
"""
42+
43+
def __init__(self):
44+
self.structure: List[Dict[str, Any]] = []
45+
self.imports: set[str] = set()
46+
self.types: set[str] = set()
47+
self._current_class_info: Dict[str, Any] | None = None
48+
self._is_in_method: bool = False
49+
50+
def _get_type_str(self, node: ast.AST | None) -> str | None:
51+
"""Recursively reconstructs a type annotation string from an AST node."""
52+
if node is None:
53+
return None
54+
# Handles simple names like 'str', 'int', 'HttpRequest'
55+
if isinstance(node, ast.Name):
56+
return node.id
57+
# Handles dotted names like 'service.GetDatasetRequest'
58+
if isinstance(node, ast.Attribute):
59+
# Attempt to reconstruct the full dotted path
60+
parts = []
61+
curr = node
62+
while isinstance(curr, ast.Attribute):
63+
parts.append(curr.attr)
64+
curr = curr.value
65+
if isinstance(curr, ast.Name):
66+
parts.append(curr.id)
67+
return ".".join(reversed(parts))
68+
# Handles subscripted types like 'list[str]', 'Optional[...]'
69+
if isinstance(node, ast.Subscript):
70+
value_str = self._get_type_str(node.value)
71+
slice_str = self._get_type_str(node.slice)
72+
return f"{value_str}[{slice_str}]"
73+
# Handles tuples inside subscripts, e.g., 'dict[str, int]'
74+
if isinstance(node, ast.Tuple):
75+
return ", ".join(
76+
[s for s in (self._get_type_str(e) for e in node.elts) if s]
77+
)
78+
# Handles forward references as strings, e.g., '"Dataset"'
79+
if isinstance(node, ast.Constant):
80+
return repr(node.value)
81+
return None # Fallback for unhandled types
82+
83+
def _collect_types_from_node(self, node: ast.AST | None) -> None:
84+
"""Recursively traverses an annotation node to find and collect all type names."""
85+
if node is None:
86+
return
87+
88+
if isinstance(node, ast.Name):
89+
self.types.add(node.id)
90+
elif isinstance(node, ast.Attribute):
91+
type_str = self._get_type_str(node)
92+
if type_str:
93+
self.types.add(type_str)
94+
elif isinstance(node, ast.Subscript):
95+
self._collect_types_from_node(node.value)
96+
self._collect_types_from_node(node.slice)
97+
elif isinstance(node, (ast.Tuple, ast.List)):
98+
for elt in node.elts:
99+
self._collect_types_from_node(elt)
100+
elif isinstance(node, ast.Constant) and isinstance(node.value, str):
101+
self.types.add(node.value)
102+
elif isinstance(node, ast.BinOp) and isinstance(
103+
node.op, ast.BitOr
104+
): # For | union type
105+
self._collect_types_from_node(node.left)
106+
self._collect_types_from_node(node.right)
107+
108+
def visit_Import(self, node: ast.Import) -> None:
109+
"""Catches 'import X' and 'import X as Y' statements."""
110+
for alias in node.names:
111+
if alias.asname:
112+
self.imports.add(f"import {alias.name} as {alias.asname}")
113+
else:
114+
self.imports.add(f"import {alias.name}")
115+
self.generic_visit(node)
116+
117+
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
118+
"""Catches 'from X import Y' statements."""
119+
module = node.module or ""
120+
if not module:
121+
module = "." * node.level
122+
else:
123+
module = "." * node.level + module
124+
125+
names = []
126+
for alias in node.names:
127+
if alias.asname:
128+
names.append(f"{alias.name} as {alias.asname}")
129+
else:
130+
names.append(alias.name)
131+
132+
if names:
133+
self.imports.add(f"from {module} import {', '.join(names)}")
134+
self.generic_visit(node)
135+
136+
def visit_ClassDef(self, node: ast.ClassDef) -> None:
137+
"""Visits a class definition node."""
138+
class_info = {
139+
"class_name": node.name,
140+
"methods": [],
141+
"attributes": [],
142+
}
143+
144+
# Extract class-level attributes (for proto.Message classes)
145+
for item in node.body:
146+
if isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name):
147+
attr_name = item.target.id
148+
type_str = self._get_type_str(item.annotation)
149+
class_info["attributes"].append({"name": attr_name, "type": type_str})
150+
151+
self.structure.append(class_info)
152+
self._current_class_info = class_info
153+
self.generic_visit(node)
154+
self._current_class_info = None
155+
156+
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
157+
"""Visits a function/method definition node."""
158+
if self._current_class_info: # This is a method
159+
args_info = []
160+
161+
# Get default values
162+
defaults = [self._get_type_str(d) for d in node.args.defaults]
163+
num_defaults = len(defaults)
164+
num_args = len(node.args.args)
165+
166+
for i, arg in enumerate(node.args.args):
167+
arg_data = {"name": arg.arg, "type": self._get_type_str(arg.annotation)}
168+
169+
# Match defaults to arguments from the end
170+
default_index = i - (num_args - num_defaults)
171+
if default_index >= 0:
172+
arg_data["default"] = defaults[default_index]
173+
174+
args_info.append(arg_data)
175+
self._collect_types_from_node(arg.annotation)
176+
177+
# Collect return type
178+
return_type = self._get_type_str(node.returns)
179+
self._collect_types_from_node(node.returns)
180+
181+
method_info = {
182+
"method_name": node.name,
183+
"args": args_info,
184+
"return_type": return_type,
185+
}
186+
self._current_class_info["methods"].append(method_info)
187+
188+
# Visit nodes inside the method to find instance attributes.
189+
self._is_in_method = True
190+
self.generic_visit(node)
191+
self._is_in_method = False
192+
193+
def _add_attribute(self, attr_name: str, attr_type: str | None = None):
194+
"""Adds a unique attribute to the current class context."""
195+
if self._current_class_info:
196+
# Create a list of attribute names for easy lookup
197+
attr_names = [
198+
attr.get("name") for attr in self._current_class_info["attributes"]
199+
]
200+
if attr_name not in attr_names:
201+
self._current_class_info["attributes"].append(
202+
{"name": attr_name, "type": attr_type}
203+
)
204+
205+
def visit_Assign(self, node: ast.Assign) -> None:
206+
"""Handles attribute assignments: `x = ...` and `self.x = ...`."""
207+
if self._current_class_info:
208+
for target in node.targets:
209+
# Instance attribute: self.x = ...
210+
if (
211+
isinstance(target, ast.Attribute)
212+
and isinstance(target.value, ast.Name)
213+
and target.value.id == "self"
214+
):
215+
self._add_attribute(target.attr)
216+
# Class attribute: x = ... (only if not inside a method)
217+
elif isinstance(target, ast.Name) and not self._is_in_method:
218+
self._add_attribute(target.id)
219+
self.generic_visit(node)
220+
221+
def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
222+
"""Handles annotated assignments: `x: int = ...` and `self.x: int = ...`."""
223+
if self._current_class_info:
224+
target = node.target
225+
# Instance attribute: self.x: int = ...
226+
if (
227+
isinstance(target, ast.Attribute)
228+
and isinstance(target.value, ast.Name)
229+
and target.value.id == "self"
230+
):
231+
self._add_attribute(target.attr, self._get_type_str(node.annotation))
232+
# Class attribute: x: int = ...
233+
# We identify it as a class attribute if the assignment happens
234+
# directly within the class body, not inside a method.
235+
elif isinstance(target, ast.Name) and not self._is_in_method:
236+
self._add_attribute(target.id, self._get_type_str(node.annotation))
237+
self.generic_visit(node)
238+
239+
240+
def parse_code(code: str) -> tuple[List[Dict[str, Any]], set[str], set[str]]:
241+
"""
242+
Parses a string of Python code into a structured list of classes, a set of imports,
243+
and a set of all type annotations found.
244+
245+
Args:
246+
code: A string containing Python code.
247+
248+
Returns:
249+
A tuple containing:
250+
- A list of dictionaries, where each dictionary represents a class.
251+
- A set of strings, where each string is an import statement.
252+
- A set of strings, where each string is a type annotation.
253+
"""
254+
tree = ast.parse(code)
255+
analyzer = CodeAnalyzer()
256+
analyzer.visit(tree)
257+
return analyzer.structure, analyzer.imports, analyzer.types
258+
259+
260+
def parse_file(file_path: str) -> tuple[List[Dict[str, Any]], set[str], set[str]]:
261+
"""
262+
Parses a Python file into a structured list of classes, a set of imports,
263+
and a set of all type annotations found.
264+
265+
Args:
266+
file_path: The absolute path to the Python file.
267+
268+
Returns:
269+
A tuple containing the class structure, a set of import statements,
270+
and a set of type annotations.
271+
"""
272+
with open(file_path, "r", encoding="utf-8") as source:
273+
code = source.read()
274+
return parse_code(code)
275+
276+
277+
def list_code_objects(
278+
path: str,
279+
show_methods: bool = False,
280+
show_attributes: bool = False,
281+
show_arguments: bool = False,
282+
) -> Any:
283+
"""
284+
Lists classes and optionally their methods, attributes, and arguments
285+
from a given Python file or directory.
286+
287+
This function consolidates the functionality of the various `list_*` functions.
288+
289+
Args:
290+
path (str): The absolute path to a Python file or directory.
291+
show_methods (bool): Whether to include methods in the output.
292+
show_attributes (bool): Whether to include attributes in the output.
293+
show_arguments (bool): If True, includes method arguments. Implies show_methods.
294+
295+
Returns:
296+
- If `show_methods` and `show_attributes` are both False, returns a
297+
sorted `List[str]` of class names (mimicking `list_classes`).
298+
- Otherwise, returns a `Dict[str, Dict[str, Any]]` containing the
299+
requested details about each class.
300+
"""
301+
# If show_arguments is True, we must show methods.
302+
if show_arguments:
303+
show_methods = True
304+
305+
results = defaultdict(dict)
306+
all_class_keys = []
307+
308+
def process_structure(
309+
structure: List[Dict[str, Any]], file_name: str | None = None
310+
):
311+
"""Populates the results dictionary from the parsed AST structure."""
312+
for class_info in structure:
313+
key = class_info["class_name"]
314+
if file_name:
315+
key = f"{key} (in {file_name})"
316+
317+
all_class_keys.append(key)
318+
319+
# Skip filling details if not needed for the dictionary.
320+
if not show_methods and not show_attributes:
321+
continue
322+
323+
if show_attributes:
324+
results[key]["attributes"] = sorted(class_info["attributes"])
325+
326+
if show_methods:
327+
if show_arguments:
328+
method_details = {}
329+
# Sort methods by name for consistent output
330+
for method in sorted(
331+
class_info["methods"], key=lambda m: m["method_name"]
332+
):
333+
method_details[method["method_name"]] = method["args"]
334+
results[key]["methods"] = method_details
335+
else:
336+
results[key]["methods"] = sorted(
337+
[m["method_name"] for m in class_info["methods"]]
338+
)
339+
340+
# Determine if the path is a file or directory and process accordingly
341+
if os.path.isfile(path) and path.endswith(".py"):
342+
structure, _, _ = parse_file(path)
343+
process_structure(structure)
344+
elif os.path.isdir(path):
345+
# This assumes `utils.walk_codebase` is defined elsewhere.
346+
for file_path in utils.walk_codebase(path):
347+
structure, _, _ = parse_file(file_path)
348+
process_structure(structure, file_name=os.path.basename(file_path))
349+
350+
# Return the data in the desired format based on the flags
351+
if not show_methods and not show_attributes:
352+
return sorted(all_class_keys)
353+
else:
354+
return dict(results)

0 commit comments

Comments
 (0)