Skip to content

Commit 5be86dd

Browse files
test: tests collection of import and from import statements (#2313)
Co-authored-by: gcf-merge-on-green[bot] <60162190+gcf-merge-on-green[bot]@users.noreply.github.com>
1 parent eca20c6 commit 5be86dd

File tree

2 files changed

+122
-24
lines changed

2 files changed

+122
-24
lines changed

scripts/microgenerator/generate.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
import argparse
2828
import glob
2929
import logging
30-
import re
3130
from collections import defaultdict
32-
from typing import List, Dict, Any, Iterator
31+
from pathlib import Path
32+
from typing import List, Dict, Any
3333

3434
from . import name_utils
3535
from . import utils
@@ -51,6 +51,7 @@ def __init__(self):
5151
self.types: set[str] = set()
5252
self._current_class_info: Dict[str, Any] | None = None
5353
self._is_in_method: bool = False
54+
self._depth = 0
5455

5556
def _get_type_str(self, node: ast.AST | None) -> str | None:
5657
"""Recursively reconstructs a type annotation string from an AST node."""
@@ -112,30 +113,32 @@ def _collect_types_from_node(self, node: ast.AST | None) -> None:
112113

113114
def visit_Import(self, node: ast.Import) -> None:
114115
"""Catches 'import X' and 'import X as Y' statements."""
115-
for alias in node.names:
116-
if alias.asname:
117-
self.imports.add(f"import {alias.name} as {alias.asname}")
118-
else:
119-
self.imports.add(f"import {alias.name}")
116+
if self._depth == 0: # Only top-level imports
117+
for alias in node.names:
118+
if alias.asname:
119+
self.imports.add(f"import {alias.name} as {alias.asname}")
120+
else:
121+
self.imports.add(f"import {alias.name}")
120122
self.generic_visit(node)
121123

122124
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
123125
"""Catches 'from X import Y' statements."""
124-
module = node.module or ""
125-
if not module:
126-
module = "." * node.level
127-
else:
128-
module = "." * node.level + module
129-
130-
names = []
131-
for alias in node.names:
132-
if alias.asname:
133-
names.append(f"{alias.name} as {alias.asname}")
126+
if self._depth == 0: # Only top-level imports
127+
module = node.module or ""
128+
if not module:
129+
module = "." * node.level
134130
else:
135-
names.append(alias.name)
131+
module = "." * node.level + module
132+
133+
names = []
134+
for alias in node.names:
135+
if alias.asname:
136+
names.append(f"{alias.name} as {alias.asname}")
137+
else:
138+
names.append(alias.name)
136139

137-
if names:
138-
self.imports.add(f"from {module} import {', '.join(names)}")
140+
if names:
141+
self.imports.add(f"from {module} import {', '.join(names)}")
139142
self.generic_visit(node)
140143

141144
def visit_ClassDef(self, node: ast.ClassDef) -> None:
@@ -155,12 +158,15 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
155158

156159
self.structure.append(class_info)
157160
self._current_class_info = class_info
161+
self._depth += 1
158162
self.generic_visit(node)
163+
self._depth -= 1
159164
self._current_class_info = None
160165

161166
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
162167
"""Visits a function/method definition node."""
163-
if self._current_class_info: # This is a method
168+
is_method = self._current_class_info is not None
169+
if is_method:
164170
args_info = []
165171

166172
# Get default values
@@ -189,10 +195,13 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
189195
"return_type": return_type,
190196
}
191197
self._current_class_info["methods"].append(method_info)
192-
193-
# Visit nodes inside the method to find instance attributes.
194198
self._is_in_method = True
195-
self.generic_visit(node)
199+
200+
self._depth += 1
201+
self.generic_visit(node)
202+
self._depth -= 1
203+
204+
if is_method:
196205
self._is_in_method = False
197206

198207
def _add_attribute(self, attr_name: str, attr_type: str | None = None):
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2025 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
import ast
18+
import pytest
19+
from scripts.microgenerator.generate import CodeAnalyzer
20+
21+
# --- Tests CodeAnalyzer handling of Imports ---
22+
23+
24+
class TestCodeAnalyzerImports:
25+
@pytest.mark.parametrize(
26+
"code_snippet, expected_imports",
27+
[
28+
pytest.param(
29+
"import os\nimport sys",
30+
["import os", "import sys"],
31+
id="simple_imports",
32+
),
33+
pytest.param(
34+
"import numpy as np",
35+
["import numpy as np"],
36+
id="aliased_import",
37+
),
38+
pytest.param(
39+
"from collections import defaultdict, OrderedDict",
40+
["from collections import defaultdict, OrderedDict"],
41+
id="from_import_multiple",
42+
),
43+
pytest.param(
44+
"from typing import List as L",
45+
["from typing import List as L"],
46+
id="from_import_aliased",
47+
),
48+
pytest.param(
49+
"from math import *",
50+
["from math import *"],
51+
id="from_import_wildcard",
52+
),
53+
pytest.param(
54+
"import os.path",
55+
["import os.path"],
56+
id="dotted_import",
57+
),
58+
pytest.param(
59+
"from google.cloud import bigquery",
60+
["from google.cloud import bigquery"],
61+
id="from_dotted_module",
62+
),
63+
pytest.param(
64+
"",
65+
[],
66+
id="no_imports",
67+
),
68+
pytest.param(
69+
"class MyClass:\n import json # Should not be picked up",
70+
[],
71+
id="import_inside_class",
72+
),
73+
pytest.param(
74+
"def my_func():\n from time import sleep # Should not be picked up",
75+
[],
76+
id="import_inside_function",
77+
),
78+
],
79+
)
80+
def test_import_extraction(self, code_snippet, expected_imports):
81+
analyzer = CodeAnalyzer()
82+
tree = ast.parse(code_snippet)
83+
analyzer.visit(tree)
84+
85+
# Normalize for comparison
86+
extracted = sorted(list(analyzer.imports))
87+
expected = sorted(expected_imports)
88+
89+
assert extracted == expected

0 commit comments

Comments
 (0)