Skip to content

Commit 09f94bc

Browse files
authored
Fix: tidy-imports fails to add missing imports when file has only local imports (#482)
* Add test for Quansight/deshaw#656 * Fix missing import not added. This make sure the import get added globally and not to local import block wrappers. * ignore forwardref warnings
1 parent b1a1b2f commit 09f94bc

File tree

4 files changed

+80
-2
lines changed

4 files changed

+80
-2
lines changed

doc/conf.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@
3535
"prompt_toolkit",
3636
]
3737

38+
suppress_warnings = [
39+
"sphinx_autodoc_typehints.forward_reference",
40+
]
41+
3842
html_theme_options = {
3943
'collapse_navigation': False,
4044
'navigation_depth': -1,

lib/python/pyflyby/_imports2s.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -737,14 +737,17 @@ def select_import_block_by_closest_prefix_match(self, imp, max_lineno):
737737
`SourceToSourceImportBlockTransformation`
738738
"""
739739
# Create a data structure that annotates blocks with data by which
740-
# we'll sort.
740+
# we'll sort. Only consider global import blocks, not local ones
741+
# (wrapped in _LocalImportBlockWrapper), since new imports should
742+
# always be added at the module level.
741743
annotated_blocks = [
742744
( (max([0] + [len(imp.prefix_match(oimp))
743745
for oimp in block.importset.imports]),
744746
block.input.endpos.lineno),
745747
block )
746748
for block in self.import_blocks
747-
if block.input.endpos.lineno <= max_lineno+1 ]
749+
if not isinstance(block, _LocalImportBlockWrapper)
750+
and block.input.endpos.lineno <= max_lineno+1 ]
748751
if not annotated_blocks:
749752
raise NoImportBlockError()
750753
annotated_blocks.sort(key=lambda x: x[0])

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ py3 = "pyflyby._py:py_main"
4343

4444
[project.optional-dependencies]
4545
test = [
46+
'numpy',
4647
'beartype',
4748
'build',
4849
'coverage',

tests/test_imports2s_bis.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,76 @@ def fun_2():
227227
assert output == expected
228228

229229

230+
@pytest.mark.parametrize("tidy_local_imports", (False, True))
231+
@pytest.mark.parametrize(
232+
("input_code", "expected_code"),
233+
[
234+
pytest.param(
235+
"""\
236+
def f():
237+
import math
238+
x = math.inf + inf
239+
""",
240+
"""\
241+
from numpy import inf
242+
243+
def f():
244+
import math
245+
x = math.inf + inf
246+
""",
247+
id="no_global_imports_with_local",
248+
),
249+
pytest.param(
250+
"""\
251+
import os
252+
253+
def f():
254+
import math
255+
x = math.inf + inf
256+
""",
257+
"""\
258+
import os
259+
from numpy import inf
260+
261+
def f():
262+
import math
263+
x = math.inf + inf
264+
""",
265+
id="existing_global_imports_with_local",
266+
),
267+
pytest.param(
268+
"""\
269+
def f():
270+
x = inf
271+
""",
272+
"""\
273+
from numpy import inf
274+
275+
def f():
276+
x = inf
277+
""",
278+
id="no_global_no_local_imports",
279+
),
280+
],
281+
)
282+
def test_add_missing_imports_with_local_imports(
283+
input_code, expected_code, tidy_local_imports
284+
):
285+
"""Regression test: missing imports must be added at module level even when
286+
only local imports exist and tidy_local_imports is enabled.
287+
288+
"""
289+
input_block = PythonBlock(dedent(input_code).lstrip())
290+
db = ImportDB("from numpy import inf")
291+
output = fix_unused_and_missing_imports(
292+
input_block, db=db, add_missing=True, remove_unused=False,
293+
add_mandatory=False, tidy_local_imports=True,
294+
)
295+
result = str(output)
296+
expected = dedent(expected_code).lstrip()
297+
assert result == expected
298+
299+
230300
@pytest.mark.parametrize(
231301
"line,import_str,expected",
232302
[

0 commit comments

Comments
 (0)