Skip to content

Commit 4e28c0f

Browse files
authored
Fix inference from nested unpacking of iterables (#1311)
1 parent 3ff65be commit 4e28c0f

File tree

4 files changed

+90
-0
lines changed

4 files changed

+90
-0
lines changed

ChangeLog

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ Release date: TBA
2020
* Rename ``ModuleSpec`` -> ``module_type`` constructor parameter to match attribute
2121
name and improve typing. Use ``type`` instead.
2222

23+
* Fixed pylint ``not-callable`` false positive with nested-tuple assignment in a for-loop.
24+
25+
Refs PyCQA/pylint#5113
26+
2327
* Add a bound to the inference tips cache.
2428

2529
Closes #1150

astroid/protocols.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,11 @@ def _resolve_looppart(parts, assign_path, context):
240240
itered = part.itered()
241241
except TypeError:
242242
continue
243+
try:
244+
if isinstance(itered[index], (nodes.Const, nodes.Name)):
245+
itered = [part]
246+
except IndexError:
247+
pass
243248
for stmt in itered:
244249
index_node = nodes.Const(index)
245250
try:

tests/unittest_inference.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,45 @@ def test_simple_for_genexpr(self) -> None:
798798
[i.value for i in test_utils.get_name_node(ast, "e", -1).infer()], [1, 3]
799799
)
800800

801+
def test_for_dict(self) -> None:
802+
code = """
803+
for a, b in {1: 2, 3: 4}.items():
804+
print(a)
805+
print(b)
806+
807+
for c, (d, e) in {1: (2, 3), 4: (5, 6)}.items():
808+
print(c)
809+
print(d)
810+
print(e)
811+
812+
print([(f, g, h) for f, (g, h) in {1: (2, 3), 4: (5, 6)}.items()])
813+
"""
814+
ast = parse(code, __name__)
815+
self.assertEqual(
816+
[i.value for i in test_utils.get_name_node(ast, "a", -1).infer()], [1, 3]
817+
)
818+
self.assertEqual(
819+
[i.value for i in test_utils.get_name_node(ast, "b", -1).infer()], [2, 4]
820+
)
821+
self.assertEqual(
822+
[i.value for i in test_utils.get_name_node(ast, "c", -1).infer()], [1, 4]
823+
)
824+
self.assertEqual(
825+
[i.value for i in test_utils.get_name_node(ast, "d", -1).infer()], [2, 5]
826+
)
827+
self.assertEqual(
828+
[i.value for i in test_utils.get_name_node(ast, "e", -1).infer()], [3, 6]
829+
)
830+
self.assertEqual(
831+
[i.value for i in test_utils.get_name_node(ast, "f", -1).infer()], [1, 4]
832+
)
833+
self.assertEqual(
834+
[i.value for i in test_utils.get_name_node(ast, "g", -1).infer()], [2, 5]
835+
)
836+
self.assertEqual(
837+
[i.value for i in test_utils.get_name_node(ast, "h", -1).infer()], [3, 6]
838+
)
839+
801840
def test_builtin_help(self) -> None:
802841
code = """
803842
help()

tests/unittest_protocols.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,48 @@ def test_assigned_stmts_simple_for(self) -> None:
7070
for2_assnode = next(assign_stmts[1].nodes_of_class(nodes.AssignName))
7171
self.assertRaises(InferenceError, list, for2_assnode.assigned_stmts())
7272

73+
def test_assigned_stmts_nested_for_tuple(self) -> None:
74+
assign_stmts = extract_node(
75+
"""
76+
for a, (b, c) in [(1, (2, 3))]: #@
77+
pass
78+
"""
79+
)
80+
81+
assign_nodes = assign_stmts.nodes_of_class(nodes.AssignName)
82+
83+
for1_assnode = next(assign_nodes)
84+
assigned = list(for1_assnode.assigned_stmts())
85+
self.assertConstNodesEqual([1], assigned)
86+
87+
for2_assnode = next(assign_nodes)
88+
assigned2 = list(for2_assnode.assigned_stmts())
89+
self.assertConstNodesEqual([2], assigned2)
90+
91+
def test_assigned_stmts_nested_for_dict(self) -> None:
92+
assign_stmts = extract_node(
93+
"""
94+
for a, (b, c) in {1: ("a", str), 2: ("b", bytes)}.items(): #@
95+
pass
96+
"""
97+
)
98+
assign_nodes = assign_stmts.nodes_of_class(nodes.AssignName)
99+
100+
# assigned: [1, 2]
101+
for1_assnode = next(assign_nodes)
102+
assigned = list(for1_assnode.assigned_stmts())
103+
self.assertConstNodesEqual([1, 2], assigned)
104+
105+
# assigned2: ["a", "b"]
106+
for2_assnode = next(assign_nodes)
107+
assigned2 = list(for2_assnode.assigned_stmts())
108+
self.assertConstNodesEqual(["a", "b"], assigned2)
109+
110+
# assigned3: [str, bytes]
111+
for3_assnode = next(assign_nodes)
112+
assigned3 = list(for3_assnode.assigned_stmts())
113+
self.assertNameNodesEqual(["str", "bytes"], assigned3)
114+
73115
def test_assigned_stmts_starred_for(self) -> None:
74116
assign_stmts = extract_node(
75117
"""

0 commit comments

Comments
 (0)