|
47 | 47 | import os.path |
48 | 48 | import sys |
49 | 49 | import traceback |
50 | | -from typing import Final, Iterable, Iterator |
| 50 | +from typing import Final, Iterable, Iterator, Optional |
51 | 51 |
|
52 | 52 | import mypy.build |
53 | 53 | import mypy.mixedtraverser |
|
73 | 73 | ARG_STAR2, |
74 | 74 | IS_ABSTRACT, |
75 | 75 | NOT_ABSTRACT, |
| 76 | + AssertTypeExpr, |
| 77 | + AssignmentExpr, |
76 | 78 | AssignmentStmt, |
| 79 | + AwaitExpr, |
77 | 80 | Block, |
78 | 81 | BytesExpr, |
79 | 82 | CallExpr, |
| 83 | + CastExpr, |
80 | 84 | ClassDef, |
81 | 85 | ComparisonExpr, |
82 | 86 | ComplexExpr, |
| 87 | + ConditionalExpr, |
83 | 88 | Decorator, |
84 | 89 | DictExpr, |
| 90 | + DictionaryComprehension, |
85 | 91 | EllipsisExpr, |
| 92 | + EnumCallExpr, |
86 | 93 | Expression, |
87 | 94 | ExpressionStmt, |
88 | 95 | FloatExpr, |
89 | 96 | FuncBase, |
90 | 97 | FuncDef, |
| 98 | + GeneratorExpr, |
91 | 99 | IfStmt, |
92 | 100 | Import, |
93 | 101 | ImportAll, |
94 | 102 | ImportFrom, |
95 | 103 | IndexExpr, |
96 | 104 | IntExpr, |
| 105 | + LambdaExpr, |
| 106 | + ListComprehension, |
97 | 107 | ListExpr, |
98 | 108 | MemberExpr, |
99 | 109 | MypyFile, |
| 110 | + NamedTupleExpr, |
100 | 111 | NameExpr, |
| 112 | + NewTypeExpr, |
101 | 113 | OpExpr, |
102 | 114 | OverloadedFuncDef, |
| 115 | + ParamSpecExpr, |
| 116 | + PromoteExpr, |
| 117 | + RevealExpr, |
| 118 | + SetComprehension, |
103 | 119 | SetExpr, |
| 120 | + SliceExpr, |
104 | 121 | StarExpr, |
105 | 122 | Statement, |
106 | 123 | StrExpr, |
| 124 | + SuperExpr, |
107 | 125 | TempNode, |
108 | 126 | TupleExpr, |
| 127 | + TypeAliasExpr, |
109 | 128 | TypeAliasStmt, |
| 129 | + TypeApplication, |
| 130 | + TypedDictExpr, |
110 | 131 | TypeInfo, |
| 132 | + TypeVarExpr, |
| 133 | + TypeVarTupleExpr, |
111 | 134 | UnaryExpr, |
112 | 135 | Var, |
| 136 | + YieldExpr, |
| 137 | + YieldFromExpr, |
113 | 138 | ) |
114 | 139 | from mypy.options import Options as MypyOptions |
115 | 140 | from mypy.sharedparse import MAGIC_METHODS_POS_ARGS_ONLY |
|
132 | 157 | walk_packages, |
133 | 158 | ) |
134 | 159 | from mypy.traverser import ( |
| 160 | + all_return_statements, |
135 | 161 | all_yield_expressions, |
136 | 162 | has_return_statement, |
137 | 163 | has_yield_expression, |
|
149 | 175 | UnboundType, |
150 | 176 | get_proper_type, |
151 | 177 | ) |
152 | | -from mypy.visitor import NodeVisitor |
| 178 | +from mypy.visitor import ExpressionVisitor, NodeVisitor |
153 | 179 |
|
154 | 180 | # Common ways of naming package containing vendored modules. |
155 | 181 | VENDOR_PACKAGES: Final = ["packages", "vendor", "vendored", "_vendor", "_vendored_packages"] |
@@ -455,6 +481,186 @@ def add_ref(self, fullname: str) -> None: |
455 | 481 | self.refs.add(fullname) |
456 | 482 |
|
457 | 483 |
|
| 484 | +class ExpressionTyper(ExpressionVisitor[Optional[str]]): |
| 485 | + containers: set[str | None] |
| 486 | + |
| 487 | + def __init__(self) -> None: |
| 488 | + self.containers = set() |
| 489 | + |
| 490 | + def visit_int_expr(self, o: IntExpr) -> str: |
| 491 | + return "int" |
| 492 | + |
| 493 | + def visit_str_expr(self, o: StrExpr) -> str: |
| 494 | + return "str" |
| 495 | + |
| 496 | + def visit_bytes_expr(self, o: BytesExpr) -> str: |
| 497 | + return "bytes" |
| 498 | + |
| 499 | + def visit_float_expr(self, o: FloatExpr) -> str: |
| 500 | + return "float" |
| 501 | + |
| 502 | + def visit_complex_expr(self, o: ComplexExpr) -> str: |
| 503 | + return "complex" |
| 504 | + |
| 505 | + def visit_comparison_expr(self, o: ComparisonExpr) -> str: |
| 506 | + return "bool" |
| 507 | + |
| 508 | + def visit_name_expr(self, o: NameExpr) -> str | None: |
| 509 | + if o.name == "True": |
| 510 | + return "bool" |
| 511 | + elif o.name == "False": |
| 512 | + return "bool" |
| 513 | + elif o.name == "None": |
| 514 | + return "None" |
| 515 | + return None |
| 516 | + |
| 517 | + def visit_unary_expr(self, o: UnaryExpr) -> str | None: |
| 518 | + if o.op == "not": |
| 519 | + return "bool" |
| 520 | + return None |
| 521 | + |
| 522 | + def visit_assignment_expr(self, o: AssignmentExpr) -> str | None: |
| 523 | + return o.value.accept(self) |
| 524 | + |
| 525 | + def visit_list_expr(self, o: ListExpr) -> str | None: |
| 526 | + items: list[str | None] = [item.accept(self) for item in o.items] |
| 527 | + if not items: |
| 528 | + return None |
| 529 | + element_type = items[0] |
| 530 | + if element_type is not None and all(item == element_type for item in items): |
| 531 | + self.containers.add("List") |
| 532 | + return f"List[{element_type}]" |
| 533 | + return None |
| 534 | + |
| 535 | + def visit_dict_expr(self, o: DictExpr) -> str | None: |
| 536 | + items: list[tuple[str | None, str | None]] = [ |
| 537 | + ((None, None) if key is None else (key.accept(self), value.accept(self))) |
| 538 | + for key, value in o.items |
| 539 | + ] |
| 540 | + if not items: |
| 541 | + return None |
| 542 | + key, value = items[0] |
| 543 | + if ( |
| 544 | + key is not None |
| 545 | + and value is not None |
| 546 | + and all(k == key and v == value for k, v in items) |
| 547 | + ): |
| 548 | + self.containers.add("Dict") |
| 549 | + return f"Dict[{key}, {value}]" |
| 550 | + return None |
| 551 | + |
| 552 | + def visit_tuple_expr(self, o: TupleExpr) -> str | None: |
| 553 | + items: list[str | None] = [item.accept(self) for item in o.items] |
| 554 | + if items and all(item is not None for item in items): |
| 555 | + self.containers.add("Tuple") |
| 556 | + elements = ", ".join([item for item in items if item is not None]) |
| 557 | + return f"Tuple[{elements}]" |
| 558 | + return None |
| 559 | + |
| 560 | + def visit_set_expr(self, o: SetExpr) -> str | None: |
| 561 | + items: list[str | None] = [item.accept(self) for item in o.items] |
| 562 | + if not items: |
| 563 | + return None |
| 564 | + element_type = items[0] |
| 565 | + if element_type is not None and all(item == element_type for item in items): |
| 566 | + self.containers.add("Set") |
| 567 | + return f"Set[{element_type}]" |
| 568 | + return None |
| 569 | + |
| 570 | + def visit_ellipsis(self, o: EllipsisExpr) -> None: |
| 571 | + return None |
| 572 | + |
| 573 | + def visit_star_expr(self, o: StarExpr) -> None: |
| 574 | + return None |
| 575 | + |
| 576 | + def visit_member_expr(self, o: MemberExpr) -> None: |
| 577 | + return None |
| 578 | + |
| 579 | + def visit_yield_from_expr(self, o: YieldFromExpr) -> None: |
| 580 | + return None |
| 581 | + |
| 582 | + def visit_yield_expr(self, o: YieldExpr) -> None: |
| 583 | + return None |
| 584 | + |
| 585 | + def visit_call_expr(self, o: CallExpr) -> None: |
| 586 | + return None |
| 587 | + |
| 588 | + def visit_op_expr(self, o: OpExpr) -> None: |
| 589 | + return None |
| 590 | + |
| 591 | + def visit_cast_expr(self, o: CastExpr) -> None: |
| 592 | + return None |
| 593 | + |
| 594 | + def visit_assert_type_expr(self, o: AssertTypeExpr) -> None: |
| 595 | + return None |
| 596 | + |
| 597 | + def visit_reveal_expr(self, o: RevealExpr) -> None: |
| 598 | + return None |
| 599 | + |
| 600 | + def visit_super_expr(self, o: SuperExpr) -> None: |
| 601 | + return None |
| 602 | + |
| 603 | + def visit_index_expr(self, o: IndexExpr) -> None: |
| 604 | + return None |
| 605 | + |
| 606 | + def visit_type_application(self, o: TypeApplication) -> None: |
| 607 | + return None |
| 608 | + |
| 609 | + def visit_lambda_expr(self, o: LambdaExpr) -> None: |
| 610 | + return None |
| 611 | + |
| 612 | + def visit_list_comprehension(self, o: ListComprehension) -> None: |
| 613 | + return None |
| 614 | + |
| 615 | + def visit_set_comprehension(self, o: SetComprehension) -> None: |
| 616 | + return None |
| 617 | + |
| 618 | + def visit_dictionary_comprehension(self, o: DictionaryComprehension) -> None: |
| 619 | + return None |
| 620 | + |
| 621 | + def visit_generator_expr(self, o: GeneratorExpr) -> None: |
| 622 | + return None |
| 623 | + |
| 624 | + def visit_slice_expr(self, o: SliceExpr) -> None: |
| 625 | + return None |
| 626 | + |
| 627 | + def visit_conditional_expr(self, o: ConditionalExpr) -> None: |
| 628 | + return None |
| 629 | + |
| 630 | + def visit_type_var_expr(self, o: TypeVarExpr) -> None: |
| 631 | + return None |
| 632 | + |
| 633 | + def visit_paramspec_expr(self, o: ParamSpecExpr) -> None: |
| 634 | + return None |
| 635 | + |
| 636 | + def visit_type_var_tuple_expr(self, o: TypeVarTupleExpr) -> None: |
| 637 | + return None |
| 638 | + |
| 639 | + def visit_type_alias_expr(self, o: TypeAliasExpr) -> None: |
| 640 | + return None |
| 641 | + |
| 642 | + def visit_namedtuple_expr(self, o: NamedTupleExpr) -> None: |
| 643 | + return None |
| 644 | + |
| 645 | + def visit_enum_call_expr(self, o: EnumCallExpr) -> None: |
| 646 | + return None |
| 647 | + |
| 648 | + def visit_typeddict_expr(self, o: TypedDictExpr) -> None: |
| 649 | + return None |
| 650 | + |
| 651 | + def visit_newtype_expr(self, o: NewTypeExpr) -> None: |
| 652 | + return None |
| 653 | + |
| 654 | + def visit__promote_expr(self, o: PromoteExpr) -> None: |
| 655 | + return None |
| 656 | + |
| 657 | + def visit_await_expr(self, o: AwaitExpr) -> None: |
| 658 | + return None |
| 659 | + |
| 660 | + def visit_temp_node(self, o: TempNode) -> None: |
| 661 | + return None |
| 662 | + |
| 663 | + |
458 | 664 | class ASTStubGenerator(BaseStubGenerator, mypy.traverser.TraverserVisitor): |
459 | 665 | """Generate stub text from a mypy AST.""" |
460 | 666 |
|
@@ -619,6 +825,26 @@ def _get_func_return(self, o: FuncDef, ctx: FunctionContext) -> str | None: |
619 | 825 | return f"{generator_name}[{yield_name}]" |
620 | 826 | if not has_return_statement(o) and o.abstract_status == NOT_ABSTRACT: |
621 | 827 | return "None" |
| 828 | + if has_return_statement(o) and o.abstract_status == NOT_ABSTRACT: |
| 829 | + return_expressions = [ret.expr for ret in all_return_statements(o)] |
| 830 | + return_type_visitor = ExpressionTyper() |
| 831 | + return_types = [ |
| 832 | + ret.accept(return_type_visitor) if ret is not None else "None" |
| 833 | + for ret in return_expressions |
| 834 | + ] |
| 835 | + if not all(return_types): |
| 836 | + return None |
| 837 | + return_type_set = set(return_types) |
| 838 | + if len(return_type_set) == 2 and "None" in return_type_set: |
| 839 | + for name in return_type_visitor.containers: |
| 840 | + self.add_name(f"typing.{name}") |
| 841 | + return_type_set.remove("None") |
| 842 | + inner_type = next(iter(return_type_set)) |
| 843 | + return f"{inner_type} | None" |
| 844 | + elif len(return_type_set) == 1: |
| 845 | + for name in return_type_visitor.containers: |
| 846 | + self.add_name(f"typing.{name}") |
| 847 | + return next(iter(return_type_set)) |
622 | 848 | return None |
623 | 849 |
|
624 | 850 | def _get_func_docstring(self, node: FuncDef) -> str | None: |
|
0 commit comments