Skip to content

Commit 2c06529

Browse files
authored
Merge branch 'autogen' into test/adds-generate-typing-checks-2
2 parents cb385ab + 5be86dd commit 2c06529

File tree

3 files changed

+48
-38
lines changed

3 files changed

+48
-38
lines changed

scripts/microgenerator/generate.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
import argparse
2828
import glob
2929
import logging
30-
import re
3130
from collections import defaultdict
32-
from typing import List, Dict, Any, Iterator
31+
from pathlib import Path
32+
from typing import List, Dict, Any
3333

3434
from . import name_utils
3535
from . import utils
@@ -51,6 +51,7 @@ def __init__(self):
5151
self.types: set[str] = set()
5252
self._current_class_info: Dict[str, Any] | None = None
5353
self._is_in_method: bool = False
54+
self._depth = 0
5455

5556
def _get_type_str(self, node: ast.AST | None) -> str | None:
5657
"""Recursively reconstructs a type annotation string from an AST node."""
@@ -119,30 +120,32 @@ def _collect_types_from_node(self, node: ast.AST | None) -> None:
119120

120121
def visit_Import(self, node: ast.Import) -> None:
121122
"""Catches 'import X' and 'import X as Y' statements."""
122-
for alias in node.names:
123-
if alias.asname:
124-
self.imports.add(f"import {alias.name} as {alias.asname}")
125-
else:
126-
self.imports.add(f"import {alias.name}")
123+
if self._depth == 0: # Only top-level imports
124+
for alias in node.names:
125+
if alias.asname:
126+
self.imports.add(f"import {alias.name} as {alias.asname}")
127+
else:
128+
self.imports.add(f"import {alias.name}")
127129
self.generic_visit(node)
128130

129131
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
130132
"""Catches 'from X import Y' statements."""
131-
module = node.module or ""
132-
if not module:
133-
module = "." * node.level
134-
else:
135-
module = "." * node.level + module
136-
137-
names = []
138-
for alias in node.names:
139-
if alias.asname:
140-
names.append(f"{alias.name} as {alias.asname}")
133+
if self._depth == 0: # Only top-level imports
134+
module = node.module or ""
135+
if not module:
136+
module = "." * node.level
141137
else:
142-
names.append(alias.name)
138+
module = "." * node.level + module
139+
140+
names = []
141+
for alias in node.names:
142+
if alias.asname:
143+
names.append(f"{alias.name} as {alias.asname}")
144+
else:
145+
names.append(alias.name)
143146

144-
if names:
145-
self.imports.add(f"from {module} import {', '.join(names)}")
147+
if names:
148+
self.imports.add(f"from {module} import {', '.join(names)}")
146149
self.generic_visit(node)
147150

148151
def visit_ClassDef(self, node: ast.ClassDef) -> None:
@@ -162,12 +165,15 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
162165

163166
self.structure.append(class_info)
164167
self._current_class_info = class_info
168+
self._depth += 1
165169
self.generic_visit(node)
170+
self._depth -= 1
166171
self._current_class_info = None
167172

168173
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
169174
"""Visits a function/method definition node."""
170-
if self._current_class_info: # This is a method
175+
is_method = self._current_class_info is not None
176+
if is_method:
171177
args_info = []
172178

173179
# Get default values
@@ -196,10 +202,13 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
196202
"return_type": return_type,
197203
}
198204
self._current_class_info["methods"].append(method_info)
199-
200-
# Visit nodes inside the method to find instance attributes.
201205
self._is_in_method = True
202-
self.generic_visit(node)
206+
207+
self._depth += 1
208+
self.generic_visit(node)
209+
self._depth -= 1
210+
211+
if is_method:
203212
self._is_in_method = False
204213

205214
def _add_attribute(self, attr_name: str, attr_type: str | None = None):

scripts/microgenerator/templates/client.py.j2

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
1-
# TODO: Add a header if needed.
2-
3-
# ======== 🦕 HERE THERE BE DINOSAURS 🦖 =========
4-
# This content is subject to significant change. Not for review yet.
5-
# Included as a proof of concept for context or testing ONLY.
6-
# ================================================
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2025 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
716

817
# Imports
918
import os
@@ -111,9 +120,3 @@ class BigQueryClient:
111120

112121
{#- Helper Section: methods included from partial template -#}
113122
{#- include "partials/_client_helpers.j2" #}
114-
115-
116-
# ======== 🦕 HERE THERE WERE DINOSAURS 🦖 =========
117-
# The above content is subject to significant change. Not for review yet.
118-
# Included as a proof of concept for context or testing ONLY.
119-
# ================================================

scripts/microgenerator/tests/unit/test_generate_analyzer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ class YetAnotherClass:
3232
pass
3333

3434

35-
# --- Existing Tests ---
3635
def test_codeanalyzer_finds_class():
3736
code = """
3837
class MyClass:
@@ -193,7 +192,6 @@ def func(self, a: Literal['one', 'two']) -> Literal[True]: return True""",
193192
]
194193

195194

196-
# --- Tests ---
197195
class TestCodeAnalyzerArgsReturns:
198196
@pytest.mark.parametrize(
199197
"code_snippet, expected_args, expected_return", TYPE_TEST_CASES

0 commit comments

Comments
 (0)