Skip to content

Commit 1647bff

Browse files
committed
[stubgen] Improve self annotations
1 parent bac9984 commit 1647bff

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

mypy/stubgen.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -648,11 +648,11 @@ def visit_func_def(self, o: FuncDef) -> None:
648648
self.add("\n")
649649
if not self.is_top_level():
650650
self_inits = find_self_initializers(o)
651-
for init, value in self_inits:
651+
for init, value, annotation in self_inits:
652652
if init in self.method_names:
653653
# Can't have both an attribute and a method/property with the same name.
654654
continue
655-
init_code = self.get_init(init, value)
655+
init_code = self.get_init(init, value, annotation)
656656
if init_code:
657657
self.add(init_code)
658658

@@ -1414,7 +1414,7 @@ def find_method_names(defs: list[Statement]) -> set[str]:
14141414

14151415
class SelfTraverser(mypy.traverser.TraverserVisitor):
14161416
def __init__(self) -> None:
1417-
self.results: list[tuple[str, Expression]] = []
1417+
self.results: list[tuple[str, Expression, Type | None]] = []
14181418

14191419
def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
14201420
lvalue = o.lvalues[0]
@@ -1423,10 +1423,10 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
14231423
and isinstance(lvalue.expr, NameExpr)
14241424
and lvalue.expr.name == "self"
14251425
):
1426-
self.results.append((lvalue.name, o.rvalue))
1426+
self.results.append((lvalue.name, o.rvalue, o.unanalyzed_type))
14271427

14281428

1429-
def find_self_initializers(fdef: FuncBase) -> list[tuple[str, Expression]]:
1429+
def find_self_initializers(fdef: FuncBase) -> list[tuple[str, Expression, Type | None]]:
14301430
"""Find attribute initializers in a method.
14311431
14321432
Return a list of pairs (attribute name, r.h.s. expression).

test-data/unit/stubgen.test

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,13 +238,20 @@ class C:
238238
def __init__(self, x: str) -> None: ...
239239

240240
[case testSelfAssignment]
241+
from typing import Any, Dict, Union
241242
class C:
242243
def __init__(self):
243244
self.x = 1
244245
x.y = 2
246+
self.y: Dict[str, Any] = {}
247+
self.z: Union[int, str, bool, None] = None
245248
[out]
249+
from typing import Any
250+
246251
class C:
247252
x: int
253+
y: dict[str, Any]
254+
z: int | str | bool | None
248255
def __init__(self) -> None: ...
249256

250257
[case testSelfAndClassBodyAssignment]

0 commit comments

Comments
 (0)