Skip to content

Commit 037b75d

Browse files
DanielNoordcdce8p
andauthored
Add parameter typing to assigned_stmts (#1249)
Co-authored-by: Marc Mueller <[email protected]>
1 parent 6ac6b90 commit 037b75d

File tree

2 files changed

+155
-12
lines changed

2 files changed

+155
-12
lines changed

astroid/nodes/node_classes.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
import typing
4343
import warnings
4444
from functools import lru_cache
45-
from typing import TYPE_CHECKING, Callable, Generator, Optional
45+
from typing import TYPE_CHECKING, Any, Callable, Generator, Optional, TypeVar, Union
4646

4747
from astroid import decorators, mixins, util
4848
from astroid.bases import Instance, _infer_stmts
@@ -72,6 +72,20 @@ def _is_const(value):
7272
return isinstance(value, tuple(CONST_CLS))
7373

7474

75+
T_Nodes = TypeVar("T_Nodes", bound=NodeNG)
76+
77+
AssignedStmtsPossibleNode = Union["List", "Tuple", "AssignName", "AssignAttr", None]
78+
AssignedStmtsCall = Callable[
79+
[
80+
T_Nodes,
81+
AssignedStmtsPossibleNode,
82+
Optional[InferenceContext],
83+
Optional[typing.List[int]],
84+
],
85+
Any,
86+
]
87+
88+
7589
@decorators.raise_if_nothing_inferred
7690
def unpack_infer(stmt, context=None):
7791
"""recursively generate nodes inferred by the given statement.
@@ -672,6 +686,11 @@ def __init__(
672686
parent=parent,
673687
)
674688

689+
assigned_stmts: AssignedStmtsCall["AssignName"]
690+
"""Returns the assigned statement (non inferred) according to the assignment type.
691+
See astroid/protocols.py for actual implementation.
692+
"""
693+
675694

676695
class DelName(
677696
mixins.NoChildrenMixin, LookupMixIn, mixins.ParentAssignTypeMixin, NodeNG
@@ -993,6 +1012,11 @@ def postinit(
9931012
if type_comment_posonlyargs is not None:
9941013
self.type_comment_posonlyargs = type_comment_posonlyargs
9951014

1015+
assigned_stmts: AssignedStmtsCall["Arguments"]
1016+
"""Returns the assigned statement (non inferred) according to the assignment type.
1017+
See astroid/protocols.py for actual implementation.
1018+
"""
1019+
9961020
def _infer_name(self, frame, name):
9971021
if self.parent is frame:
9981022
return name
@@ -1246,6 +1270,11 @@ def postinit(self, expr: Optional[NodeNG] = None) -> None:
12461270
"""
12471271
self.expr = expr
12481272

1273+
assigned_stmts: AssignedStmtsCall["AssignAttr"]
1274+
"""Returns the assigned statement (non inferred) according to the assignment type.
1275+
See astroid/protocols.py for actual implementation.
1276+
"""
1277+
12491278
def get_children(self):
12501279
yield self.expr
12511280

@@ -1389,6 +1418,11 @@ def postinit(
13891418
self.value = value
13901419
self.type_annotation = type_annotation
13911420

1421+
assigned_stmts: AssignedStmtsCall["Assign"]
1422+
"""Returns the assigned statement (non inferred) according to the assignment type.
1423+
See astroid/protocols.py for actual implementation.
1424+
"""
1425+
13921426
def get_children(self):
13931427
yield from self.targets
13941428

@@ -1481,6 +1515,11 @@ def postinit(
14811515
self.value = value
14821516
self.simple = simple
14831517

1518+
assigned_stmts: AssignedStmtsCall["AnnAssign"]
1519+
"""Returns the assigned statement (non inferred) according to the assignment type.
1520+
See astroid/protocols.py for actual implementation.
1521+
"""
1522+
14841523
def get_children(self):
14851524
yield self.target
14861525
yield self.annotation
@@ -1562,6 +1601,11 @@ def postinit(
15621601
self.target = target
15631602
self.value = value
15641603

1604+
assigned_stmts: AssignedStmtsCall["AugAssign"]
1605+
"""Returns the assigned statement (non inferred) according to the assignment type.
1606+
See astroid/protocols.py for actual implementation.
1607+
"""
1608+
15651609
# This is set by inference.py
15661610
def _infer_augassign(self, context=None):
15671611
raise NotImplementedError
@@ -2028,6 +2072,11 @@ def postinit(
20282072
self.ifs = ifs
20292073
self.is_async = is_async
20302074

2075+
assigned_stmts: AssignedStmtsCall["Comprehension"]
2076+
"""Returns the assigned statement (non inferred) according to the assignment type.
2077+
See astroid/protocols.py for actual implementation.
2078+
"""
2079+
20312080
def assign_type(self):
20322081
"""The type of assignment that this node performs.
20332082
@@ -2715,6 +2764,11 @@ def __init__(
27152764
parent=parent,
27162765
)
27172766

2767+
assigned_stmts: AssignedStmtsCall["ExceptHandler"]
2768+
"""Returns the assigned statement (non inferred) according to the assignment type.
2769+
See astroid/protocols.py for actual implementation.
2770+
"""
2771+
27182772
def get_children(self):
27192773
if self.type is not None:
27202774
yield self.type
@@ -2873,6 +2927,11 @@ def postinit(
28732927
self.orelse = orelse
28742928
self.type_annotation = type_annotation
28752929

2930+
assigned_stmts: AssignedStmtsCall["For"]
2931+
"""Returns the assigned statement (non inferred) according to the assignment type.
2932+
See astroid/protocols.py for actual implementation.
2933+
"""
2934+
28762935
@decorators.cachedproperty
28772936
def blockstart_tolineno(self):
28782937
"""The line on which the beginning of this block ends.
@@ -3573,6 +3632,11 @@ def __init__(
35733632
parent=parent,
35743633
)
35753634

3635+
assigned_stmts: AssignedStmtsCall["List"]
3636+
"""Returns the assigned statement (non inferred) according to the assignment type.
3637+
See astroid/protocols.py for actual implementation.
3638+
"""
3639+
35763640
def pytype(self):
35773641
"""Get the name of the type that this node represents.
35783642
@@ -3999,6 +4063,11 @@ def postinit(self, value: Optional[NodeNG] = None) -> None:
39994063
"""
40004064
self.value = value
40014065

4066+
assigned_stmts: AssignedStmtsCall["Starred"]
4067+
"""Returns the assigned statement (non inferred) according to the assignment type.
4068+
See astroid/protocols.py for actual implementation.
4069+
"""
4070+
40024071
def get_children(self):
40034072
yield self.value
40044073

@@ -4325,6 +4394,11 @@ def __init__(
43254394
parent=parent,
43264395
)
43274396

4397+
assigned_stmts: AssignedStmtsCall["Tuple"]
4398+
"""Returns the assigned statement (non inferred) according to the assignment type.
4399+
See astroid/protocols.py for actual implementation.
4400+
"""
4401+
43284402
def pytype(self):
43294403
"""Get the name of the type that this node represents.
43304404
@@ -4619,6 +4693,11 @@ def postinit(
46194693
self.body = body
46204694
self.type_annotation = type_annotation
46214695

4696+
assigned_stmts: AssignedStmtsCall["With"]
4697+
"""Returns the assigned statement (non inferred) according to the assignment type.
4698+
See astroid/protocols.py for actual implementation.
4699+
"""
4700+
46224701
@decorators.cachedproperty
46234702
def blockstart_tolineno(self):
46244703
"""The line on which the beginning of this block ends.
@@ -4922,6 +5001,11 @@ def postinit(self, target: NodeNG, value: NodeNG) -> None:
49225001
self.target = target
49235002
self.value = value
49245003

5004+
assigned_stmts: AssignedStmtsCall["NamedExpr"]
5005+
"""Returns the assigned statement (non inferred) according to the assignment type.
5006+
See astroid/protocols.py for actual implementation.
5007+
"""
5008+
49255009
def frame(self):
49265010
"""The first parent frame node.
49275011
@@ -5291,6 +5375,9 @@ def postinit(
52915375
],
52925376
Generator[NodeNG, None, None],
52935377
]
5378+
"""Returns the assigned statement (non inferred) according to the assignment type.
5379+
See astroid/protocols.py for actual implementation.
5380+
"""
52945381

52955382

52965383
class MatchClass(Pattern):
@@ -5393,6 +5480,9 @@ def postinit(self, *, name: Optional[AssignName]) -> None:
53935480
],
53945481
Generator[NodeNG, None, None],
53955482
]
5483+
"""Returns the assigned statement (non inferred) according to the assignment type.
5484+
See astroid/protocols.py for actual implementation.
5485+
"""
53965486

53975487

53985488
class MatchAs(mixins.AssignTypeMixin, Pattern):
@@ -5459,6 +5549,9 @@ def postinit(
54595549
],
54605550
Generator[NodeNG, None, None],
54615551
]
5552+
"""Returns the assigned statement (non inferred) according to the assignment type.
5553+
See astroid/protocols.py for actual implementation.
5554+
"""
54625555

54635556

54645557
class MatchOr(Pattern):

astroid/protocols.py

Lines changed: 61 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import itertools
3333
import operator as operator_mod
3434
import sys
35-
from typing import Generator, Optional
35+
from typing import Any, Generator, List, Optional, Union
3636

3737
from astroid import arguments, bases, decorators, helpers, nodes, util
3838
from astroid.const import Context
@@ -274,7 +274,12 @@ def _resolve_looppart(parts, assign_path, context):
274274

275275

276276
@decorators.raise_if_nothing_inferred
277-
def for_assigned_stmts(self, node=None, context=None, assign_path=None):
277+
def for_assigned_stmts(
278+
self: Union[nodes.For, nodes.Comprehension],
279+
node: node_classes.AssignedStmtsPossibleNode = None,
280+
context: Optional[InferenceContext] = None,
281+
assign_path: Optional[List[int]] = None,
282+
) -> Any:
278283
if isinstance(self, nodes.AsyncFor) or getattr(self, "is_async", False):
279284
# Skip inferring of async code for now
280285
return dict(node=self, unknown=node, assign_path=assign_path, context=context)
@@ -291,7 +296,12 @@ def for_assigned_stmts(self, node=None, context=None, assign_path=None):
291296
nodes.Comprehension.assigned_stmts = for_assigned_stmts
292297

293298

294-
def sequence_assigned_stmts(self, node=None, context=None, assign_path=None):
299+
def sequence_assigned_stmts(
300+
self: Union[nodes.Tuple, nodes.List],
301+
node: node_classes.AssignedStmtsPossibleNode = None,
302+
context: Optional[InferenceContext] = None,
303+
assign_path: Optional[List[int]] = None,
304+
) -> Any:
295305
if assign_path is None:
296306
assign_path = []
297307
try:
@@ -314,7 +324,12 @@ def sequence_assigned_stmts(self, node=None, context=None, assign_path=None):
314324
nodes.List.assigned_stmts = sequence_assigned_stmts
315325

316326

317-
def assend_assigned_stmts(self, node=None, context=None, assign_path=None):
327+
def assend_assigned_stmts(
328+
self: Union[nodes.AssignName, nodes.AssignAttr],
329+
node: node_classes.AssignedStmtsPossibleNode = None,
330+
context: Optional[InferenceContext] = None,
331+
assign_path: Optional[List[int]] = None,
332+
) -> Any:
318333
return self.parent.assigned_stmts(node=self, context=context)
319334

320335

@@ -381,7 +396,12 @@ def _arguments_infer_argname(self, name, context):
381396
yield util.Uninferable
382397

383398

384-
def arguments_assigned_stmts(self, node=None, context=None, assign_path=None):
399+
def arguments_assigned_stmts(
400+
self: nodes.Arguments,
401+
node: node_classes.AssignedStmtsPossibleNode = None,
402+
context: Optional[InferenceContext] = None,
403+
assign_path: Optional[List[int]] = None,
404+
) -> Any:
385405
if context.callcontext:
386406
callee = context.callcontext.callee
387407
while hasattr(callee, "_proxied"):
@@ -406,7 +426,12 @@ def arguments_assigned_stmts(self, node=None, context=None, assign_path=None):
406426

407427

408428
@decorators.raise_if_nothing_inferred
409-
def assign_assigned_stmts(self, node=None, context=None, assign_path=None):
429+
def assign_assigned_stmts(
430+
self: Union[nodes.AugAssign, nodes.Assign, nodes.AnnAssign],
431+
node: node_classes.AssignedStmtsPossibleNode = None,
432+
context: Optional[InferenceContext] = None,
433+
assign_path: Optional[List[int]] = None,
434+
) -> Any:
410435
if not assign_path:
411436
yield self.value
412437
return None
@@ -417,7 +442,12 @@ def assign_assigned_stmts(self, node=None, context=None, assign_path=None):
417442
return dict(node=self, unknown=node, assign_path=assign_path, context=context)
418443

419444

420-
def assign_annassigned_stmts(self, node=None, context=None, assign_path=None):
445+
def assign_annassigned_stmts(
446+
self: nodes.AnnAssign,
447+
node: node_classes.AssignedStmtsPossibleNode = None,
448+
context: Optional[InferenceContext] = None,
449+
assign_path: Optional[List[int]] = None,
450+
) -> Any:
421451
for inferred in assign_assigned_stmts(self, node, context, assign_path):
422452
if inferred is None:
423453
yield util.Uninferable
@@ -471,7 +501,12 @@ def _resolve_assignment_parts(parts, assign_path, context):
471501

472502

473503
@decorators.raise_if_nothing_inferred
474-
def excepthandler_assigned_stmts(self, node=None, context=None, assign_path=None):
504+
def excepthandler_assigned_stmts(
505+
self: nodes.ExceptHandler,
506+
node: node_classes.AssignedStmtsPossibleNode = None,
507+
context: Optional[InferenceContext] = None,
508+
assign_path: Optional[List[int]] = None,
509+
) -> Any:
475510
for assigned in node_classes.unpack_infer(self.type):
476511
if isinstance(assigned, nodes.ClassDef):
477512
assigned = objects.ExceptionInstance(assigned)
@@ -522,7 +557,12 @@ def _infer_context_manager(self, mgr, context):
522557

523558

524559
@decorators.raise_if_nothing_inferred
525-
def with_assigned_stmts(self, node=None, context=None, assign_path=None):
560+
def with_assigned_stmts(
561+
self: nodes.With,
562+
node: node_classes.AssignedStmtsPossibleNode = None,
563+
context: Optional[InferenceContext] = None,
564+
assign_path: Optional[List[int]] = None,
565+
) -> Any:
526566
"""Infer names and other nodes from a *with* statement.
527567
528568
This enables only inference for name binding in a *with* statement.
@@ -595,7 +635,12 @@ def __enter__(self):
595635

596636

597637
@decorators.raise_if_nothing_inferred
598-
def named_expr_assigned_stmts(self, node, context=None, assign_path=None):
638+
def named_expr_assigned_stmts(
639+
self: nodes.NamedExpr,
640+
node: node_classes.AssignedStmtsPossibleNode,
641+
context: Optional[InferenceContext] = None,
642+
assign_path: Optional[List[int]] = None,
643+
) -> Any:
599644
"""Infer names and other nodes from an assignment expression"""
600645
if self.target == node:
601646
yield from self.value.infer(context=context)
@@ -612,7 +657,12 @@ def named_expr_assigned_stmts(self, node, context=None, assign_path=None):
612657

613658

614659
@decorators.yes_if_nothing_inferred
615-
def starred_assigned_stmts(self, node=None, context=None, assign_path=None):
660+
def starred_assigned_stmts(
661+
self: nodes.Starred,
662+
node: node_classes.AssignedStmtsPossibleNode = None,
663+
context: Optional[InferenceContext] = None,
664+
assign_path: Optional[List[int]] = None,
665+
) -> Any:
616666
"""
617667
Arguments:
618668
self: nodes.Starred

0 commit comments

Comments
 (0)