Skip to content

Commit de4d7d4

Browse files
committed
add simple return type inference
1 parent 78fb78b commit de4d7d4

File tree

2 files changed

+341
-46
lines changed

2 files changed

+341
-46
lines changed

mypy/stubgen.py

Lines changed: 228 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
import os.path
4848
import sys
4949
import traceback
50-
from typing import Final, Iterable, Iterator
50+
from typing import Final, Iterable, Iterator, Optional
5151

5252
import mypy.build
5353
import mypy.mixedtraverser
@@ -73,43 +73,68 @@
7373
ARG_STAR2,
7474
IS_ABSTRACT,
7575
NOT_ABSTRACT,
76+
AssertTypeExpr,
77+
AssignmentExpr,
7678
AssignmentStmt,
79+
AwaitExpr,
7780
Block,
7881
BytesExpr,
7982
CallExpr,
83+
CastExpr,
8084
ClassDef,
8185
ComparisonExpr,
8286
ComplexExpr,
87+
ConditionalExpr,
8388
Decorator,
8489
DictExpr,
90+
DictionaryComprehension,
8591
EllipsisExpr,
92+
EnumCallExpr,
8693
Expression,
8794
ExpressionStmt,
8895
FloatExpr,
8996
FuncBase,
9097
FuncDef,
98+
GeneratorExpr,
9199
IfStmt,
92100
Import,
93101
ImportAll,
94102
ImportFrom,
95103
IndexExpr,
96104
IntExpr,
105+
LambdaExpr,
106+
ListComprehension,
97107
ListExpr,
98108
MemberExpr,
99109
MypyFile,
110+
NamedTupleExpr,
100111
NameExpr,
112+
NewTypeExpr,
101113
OpExpr,
102114
OverloadedFuncDef,
115+
ParamSpecExpr,
116+
PromoteExpr,
117+
RevealExpr,
118+
SetComprehension,
103119
SetExpr,
120+
SliceExpr,
104121
StarExpr,
105122
Statement,
106123
StrExpr,
124+
SuperExpr,
107125
TempNode,
108126
TupleExpr,
127+
TypeAliasExpr,
109128
TypeAliasStmt,
129+
TypeApplication,
130+
TypedDictExpr,
110131
TypeInfo,
132+
TypeVarExpr,
133+
TypeVarTupleExpr,
111134
UnaryExpr,
112135
Var,
136+
YieldExpr,
137+
YieldFromExpr,
113138
)
114139
from mypy.options import Options as MypyOptions
115140
from mypy.sharedparse import MAGIC_METHODS_POS_ARGS_ONLY
@@ -132,6 +157,7 @@
132157
walk_packages,
133158
)
134159
from mypy.traverser import (
160+
all_return_statements,
135161
all_yield_expressions,
136162
has_return_statement,
137163
has_yield_expression,
@@ -149,7 +175,7 @@
149175
UnboundType,
150176
get_proper_type,
151177
)
152-
from mypy.visitor import NodeVisitor
178+
from mypy.visitor import ExpressionVisitor, NodeVisitor
153179

154180
# Common ways of naming package containing vendored modules.
155181
VENDOR_PACKAGES: Final = ["packages", "vendor", "vendored", "_vendor", "_vendored_packages"]
@@ -455,6 +481,186 @@ def add_ref(self, fullname: str) -> None:
455481
self.refs.add(fullname)
456482

457483

484+
class ExpressionTyper(ExpressionVisitor[Optional[str]]):
485+
containers: set[str | None]
486+
487+
def __init__(self) -> None:
488+
self.containers = set()
489+
490+
def visit_int_expr(self, o: IntExpr) -> str:
491+
return "int"
492+
493+
def visit_str_expr(self, o: StrExpr) -> str:
494+
return "str"
495+
496+
def visit_bytes_expr(self, o: BytesExpr) -> str:
497+
return "bytes"
498+
499+
def visit_float_expr(self, o: FloatExpr) -> str:
500+
return "float"
501+
502+
def visit_complex_expr(self, o: ComplexExpr) -> str:
503+
return "complex"
504+
505+
def visit_comparison_expr(self, o: ComparisonExpr) -> str:
506+
return "bool"
507+
508+
def visit_name_expr(self, o: NameExpr) -> str | None:
509+
if o.name == "True":
510+
return "bool"
511+
elif o.name == "False":
512+
return "bool"
513+
elif o.name == "None":
514+
return "None"
515+
return None
516+
517+
def visit_unary_expr(self, o: UnaryExpr) -> str | None:
518+
if o.op == "not":
519+
return "bool"
520+
return None
521+
522+
def visit_assignment_expr(self, o: AssignmentExpr) -> str | None:
523+
return o.value.accept(self)
524+
525+
def visit_list_expr(self, o: ListExpr) -> str | None:
526+
items: list[str | None] = [item.accept(self) for item in o.items]
527+
if not items:
528+
return None
529+
element_type = items[0]
530+
if element_type is not None and all(item == element_type for item in items):
531+
self.containers.add("List")
532+
return f"List[{element_type}]"
533+
return None
534+
535+
def visit_dict_expr(self, o: DictExpr) -> str | None:
536+
items: list[tuple[str | None, str | None]] = [
537+
((None, None) if key is None else (key.accept(self), value.accept(self)))
538+
for key, value in o.items
539+
]
540+
if not items:
541+
return None
542+
key, value = items[0]
543+
if (
544+
key is not None
545+
and value is not None
546+
and all(k == key and v == value for k, v in items)
547+
):
548+
self.containers.add("Dict")
549+
return f"Dict[{key}, {value}]"
550+
return None
551+
552+
def visit_tuple_expr(self, o: TupleExpr) -> str | None:
553+
items: list[str | None] = [item.accept(self) for item in o.items]
554+
if items and all(item is not None for item in items):
555+
self.containers.add("Tuple")
556+
elements = ", ".join([item for item in items if item is not None])
557+
return f"Tuple[{elements}]"
558+
return None
559+
560+
def visit_set_expr(self, o: SetExpr) -> str | None:
561+
items: list[str | None] = [item.accept(self) for item in o.items]
562+
if not items:
563+
return None
564+
element_type = items[0]
565+
if element_type is not None and all(item == element_type for item in items):
566+
self.containers.add("Set")
567+
return f"Set[{element_type}]"
568+
return None
569+
570+
def visit_ellipsis(self, o: EllipsisExpr) -> None:
571+
return None
572+
573+
def visit_star_expr(self, o: StarExpr) -> None:
574+
return None
575+
576+
def visit_member_expr(self, o: MemberExpr) -> None:
577+
return None
578+
579+
def visit_yield_from_expr(self, o: YieldFromExpr) -> None:
580+
return None
581+
582+
def visit_yield_expr(self, o: YieldExpr) -> None:
583+
return None
584+
585+
def visit_call_expr(self, o: CallExpr) -> None:
586+
return None
587+
588+
def visit_op_expr(self, o: OpExpr) -> None:
589+
return None
590+
591+
def visit_cast_expr(self, o: CastExpr) -> None:
592+
return None
593+
594+
def visit_assert_type_expr(self, o: AssertTypeExpr) -> None:
595+
return None
596+
597+
def visit_reveal_expr(self, o: RevealExpr) -> None:
598+
return None
599+
600+
def visit_super_expr(self, o: SuperExpr) -> None:
601+
return None
602+
603+
def visit_index_expr(self, o: IndexExpr) -> None:
604+
return None
605+
606+
def visit_type_application(self, o: TypeApplication) -> None:
607+
return None
608+
609+
def visit_lambda_expr(self, o: LambdaExpr) -> None:
610+
return None
611+
612+
def visit_list_comprehension(self, o: ListComprehension) -> None:
613+
return None
614+
615+
def visit_set_comprehension(self, o: SetComprehension) -> None:
616+
return None
617+
618+
def visit_dictionary_comprehension(self, o: DictionaryComprehension) -> None:
619+
return None
620+
621+
def visit_generator_expr(self, o: GeneratorExpr) -> None:
622+
return None
623+
624+
def visit_slice_expr(self, o: SliceExpr) -> None:
625+
return None
626+
627+
def visit_conditional_expr(self, o: ConditionalExpr) -> None:
628+
return None
629+
630+
def visit_type_var_expr(self, o: TypeVarExpr) -> None:
631+
return None
632+
633+
def visit_paramspec_expr(self, o: ParamSpecExpr) -> None:
634+
return None
635+
636+
def visit_type_var_tuple_expr(self, o: TypeVarTupleExpr) -> None:
637+
return None
638+
639+
def visit_type_alias_expr(self, o: TypeAliasExpr) -> None:
640+
return None
641+
642+
def visit_namedtuple_expr(self, o: NamedTupleExpr) -> None:
643+
return None
644+
645+
def visit_enum_call_expr(self, o: EnumCallExpr) -> None:
646+
return None
647+
648+
def visit_typeddict_expr(self, o: TypedDictExpr) -> None:
649+
return None
650+
651+
def visit_newtype_expr(self, o: NewTypeExpr) -> None:
652+
return None
653+
654+
def visit__promote_expr(self, o: PromoteExpr) -> None:
655+
return None
656+
657+
def visit_await_expr(self, o: AwaitExpr) -> None:
658+
return None
659+
660+
def visit_temp_node(self, o: TempNode) -> None:
661+
return None
662+
663+
458664
class ASTStubGenerator(BaseStubGenerator, mypy.traverser.TraverserVisitor):
459665
"""Generate stub text from a mypy AST."""
460666

@@ -619,6 +825,26 @@ def _get_func_return(self, o: FuncDef, ctx: FunctionContext) -> str | None:
619825
return f"{generator_name}[{yield_name}]"
620826
if not has_return_statement(o) and o.abstract_status == NOT_ABSTRACT:
621827
return "None"
828+
if has_return_statement(o) and o.abstract_status == NOT_ABSTRACT:
829+
return_expressions = [ret.expr for ret in all_return_statements(o)]
830+
return_type_visitor = ExpressionTyper()
831+
return_types = [
832+
ret.accept(return_type_visitor) if ret is not None else "None"
833+
for ret in return_expressions
834+
]
835+
if not all(return_types):
836+
return None
837+
return_type_set = set(return_types)
838+
if len(return_type_set) == 2 and "None" in return_type_set:
839+
for name in return_type_visitor.containers:
840+
self.add_name(f"typing.{name}")
841+
return_type_set.remove("None")
842+
inner_type = next(iter(return_type_set))
843+
return f"{inner_type} | None"
844+
elif len(return_type_set) == 1:
845+
for name in return_type_visitor.containers:
846+
self.add_name(f"typing.{name}")
847+
return next(iter(return_type_set))
622848
return None
623849

624850
def _get_func_docstring(self, node: FuncDef) -> str | None:

0 commit comments

Comments
 (0)