Skip to content

Commit 933e135

Browse files
stroxlerfacebook-github-bot
authored andcommitted
Infer: Add support for existing annotations to infer_v2
Summary: The existing infer has the ability to generate complete stub files, which include all explicitly annotated attributes and globals as well as all functions and methods with explicit return annotations. This commit adds support for extracting existing annotations from a module using `libcst`; the next diff adds the ability to combine them with annotations from infer. Reviewed By: pradeep90 Differential Revision: D28918138 fbshipit-source-id: d38be3550e68a1ec31e1c57453979d68163c469d
1 parent e540b66 commit 933e135

File tree

2 files changed

+199
-1
lines changed

2 files changed

+199
-1
lines changed

client/commands/infer_v2.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from pathlib import Path
1515
from typing import cast, Dict, List, Optional, Sequence
1616

17+
import libcst
1718
from typing_extensions import Final
1819
from typing_extensions import TypeAlias
1920

2021
from .. import command_arguments, log
2122
from ..analysis_directory import AnalysisDirectory, resolve_analysis_directory
22-
from ..annotation_collector import AnnotationCollector
2323
from ..configuration import Configuration
2424
from .check import Check
2525
from .infer import dequalify_and_fix_pathlike, split_imports
@@ -176,6 +176,114 @@ class AttributeAnnotation(FieldAnnotation):
176176
parent: str
177177

178178

179+
class AnnotationCollector(libcst.CSTVisitor):
180+
def __init__(self, file_path: str) -> None:
181+
self.qualifier: str = file_path.replace(".py", "")
182+
self.globals_: list[GlobalAnnotation] = []
183+
self.attributes: list[AttributeAnnotation] = []
184+
self.functions: list[FunctionAnnotation] = []
185+
self.methods: list[MethodAnnotation] = []
186+
self.class_name: list[str] = []
187+
188+
def visit_ClassDef(self, node: libcst.ClassDef) -> None:
189+
self.class_name.append(node.name.value)
190+
191+
def leave_ClassDef(self, original_node: libcst.ClassDef) -> None:
192+
self.class_name.remove(original_node.name.value)
193+
194+
def visit_AnnAssign(self, node: libcst.AnnAssign) -> None:
195+
target_name = self._code_for_node(node.target)
196+
if target_name is None:
197+
return
198+
parent = ".".join(self.class_name) if self.class_name else None
199+
name = self._qualified_name(
200+
parent=parent,
201+
target_name=target_name,
202+
)
203+
annotation = TypeAnnotation(self._code_for_node(node.annotation.annotation))
204+
if parent is None:
205+
self.globals_.append(
206+
GlobalAnnotation(
207+
name=name,
208+
annotation=annotation,
209+
)
210+
)
211+
else:
212+
self.attributes.append(
213+
AttributeAnnotation(
214+
parent=parent,
215+
name=name,
216+
annotation=annotation,
217+
)
218+
)
219+
220+
def visit_FunctionDef(self, node: libcst.FunctionDef) -> bool:
221+
if node.returns is None:
222+
return False
223+
parent = ".".join(self.class_name) if self.class_name else None
224+
name = self._qualified_name(
225+
parent=parent,
226+
target_name=node.name.value,
227+
)
228+
return_annotation = TypeAnnotation(self._code_for_node(node.returns.annotation))
229+
parameters = [
230+
Parameter(
231+
name=parameter.name.value,
232+
annotation=TypeAnnotation(
233+
self._code_for_node(
234+
parameter.annotation.annotation
235+
if parameter.annotation
236+
else None
237+
)
238+
),
239+
value=self._code_for_node(parameter.default),
240+
)
241+
for parameter in node.params.params
242+
]
243+
decorators = [
244+
decorator
245+
for decorator in (
246+
self._code_for_node(decorator) for decorator in node.decorators
247+
)
248+
if decorator is not None
249+
]
250+
is_async = node.asynchronous is not None
251+
if parent is None:
252+
self.functions.append(
253+
FunctionAnnotation(
254+
name=name,
255+
return_annotation=return_annotation,
256+
parameters=parameters,
257+
decorators=decorators,
258+
is_async=is_async,
259+
)
260+
)
261+
else:
262+
self.methods.append(
263+
MethodAnnotation(
264+
parent=parent,
265+
name=name,
266+
return_annotation=return_annotation,
267+
parameters=parameters,
268+
decorators=decorators,
269+
is_async=is_async,
270+
)
271+
)
272+
return False
273+
274+
# utility methods
275+
276+
def _qualified_name(self, parent: str | None, target_name: str) -> str:
277+
prefix = f"{self.qualifier}."
278+
if parent is not None:
279+
prefix += f"{parent}."
280+
return prefix + target_name
281+
282+
@staticmethod
283+
def _code_for_node(node: Optional[libcst.CSTNode]) -> Optional[str]:
284+
return None if node is None else libcst.parse_module("").code_for_node(node)
285+
286+
179287
@dataclass
180288
class ModuleAnnotations:
181289
path: str
@@ -252,6 +360,18 @@ def from_infer_output(
252360
],
253361
)
254362

363+
@staticmethod
364+
def from_module(path: str, module: libcst.Module) -> ModuleAnnotations:
365+
collector = AnnotationCollector(file_path=path)
366+
module.visit(collector)
367+
return ModuleAnnotations(
368+
path=path,
369+
globals_=collector.globals_,
370+
attributes=collector.attributes,
371+
functions=collector.functions,
372+
methods=collector.methods,
373+
)
374+
255375
def filter_for_complete(self) -> ModuleAnnotations:
256376
return ModuleAnnotations(
257377
path=self.path,

client/commands/tests/infer_v2_test.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from pathlib import Path
1212
from typing import cast
1313

14+
import libcst
15+
1416
from ...commands.infer_v2 import (
1517
_create_module_annotations,
1618
RawInferOutput,
@@ -564,3 +566,79 @@ def test_stubs_no_typing_import(self) -> None:
564566
def with_params(y=7, x: List[int] = [5]) -> Union[int, str]: ...
565567
""",
566568
)
569+
570+
571+
class ExistingAnnotationsTest(unittest.TestCase):
572+
def _assert_stubs(self, code: str, expected: str) -> None:
573+
module_annotations = ModuleAnnotations.from_module(
574+
path=PATH, module=libcst.parse_module(textwrap.dedent(code))
575+
)
576+
actual = module_annotations.to_stubs()
577+
_assert_stubs_equal(actual, expected)
578+
579+
def test_stubs_from_existing_annotations(self) -> None:
580+
self._assert_stubs(
581+
"""
582+
def foo() -> int:
583+
return 1 + 1
584+
""",
585+
"def foo() -> int: ...",
586+
)
587+
588+
# methods
589+
self._assert_stubs(
590+
"""
591+
class Foo:
592+
def bar(self, x: int) -> Union[int, str]:
593+
return ""
594+
""",
595+
"""\
596+
class Foo:
597+
def bar(self, x: int) -> Union[int, str]: ...
598+
""",
599+
)
600+
601+
# with async
602+
self._assert_stubs(
603+
"""
604+
async def foo() -> int:
605+
return 1 + 1
606+
""",
607+
"async def foo() -> int: ...",
608+
)
609+
610+
# with decorators
611+
self._assert_stubs(
612+
"""
613+
@click
614+
def foo() -> int:
615+
return 1 + 1
616+
""",
617+
"@@click\n\ndef foo() -> int: ...",
618+
)
619+
620+
# globals
621+
self._assert_stubs(
622+
"""
623+
x: int = 10
624+
""",
625+
"x: int = ...",
626+
)
627+
628+
# attributes
629+
self._assert_stubs(
630+
# TODO (T92336996)
631+
# libcst does not produce fully-qualified typenames when extracting
632+
# annotations (unlike the pyre parser). As a result, we're producing
633+
# incorrect stubs here.
634+
"""
635+
from typing import Any
636+
637+
class Foo:
638+
x: Any = 10
639+
""",
640+
"""\
641+
class Foo:
642+
x: Any = ...
643+
""",
644+
)

0 commit comments

Comments
 (0)