Skip to content

Commit 89e1b9d

Browse files
stroxlerfacebook-github-bot
authored andcommitted
Infer: Allow combining two ModuleAnnotations
Summary: The `__add__` function combines the annotations in two `ModuleAnnotations` instances, preferring annotations from the left when there are collisions. The current purpose is to be able to combine together existing annotations with the results of pyre infer (preferring infer output when both are available) when making full stub files. Reviewed By: pradeep90 Differential Revision: D28935186 fbshipit-source-id: c960ecd4f7976496a3fd96d13aea4aa2384816f1
1 parent 933e135 commit 89e1b9d

File tree

2 files changed

+110
-1
lines changed

2 files changed

+110
-1
lines changed

client/commands/infer_v2.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,45 @@ def from_module(path: str, module: libcst.Module) -> ModuleAnnotations:
372372
methods=collector.methods,
373373
)
374374

375+
def __add__(self, other: ModuleAnnotations) -> ModuleAnnotations:
376+
"""
377+
Combine two sets of annotations for the same module, preferring
378+
those from the right where there are collisions.
379+
380+
One use for this is to add existing annotations to those from infer
381+
when generating full stub files.
382+
"""
383+
if self.path != other.path:
384+
raise ValueError(
385+
"Cannot add ModuleAnnotations from different paths "
386+
+ f"{self.path!r} vs {other.path!r}"
387+
)
388+
other_names: set[str] = {
389+
annotation.name
390+
for annotation in
391+
# pyre-ignore[58] : list[A] + list[B] = list[A | B]
392+
other.globals_ + other.attributes + other.functions + other.methods
393+
}
394+
return ModuleAnnotations(
395+
path=self.path,
396+
globals_=other.globals_
397+
+ [global_ for global_ in self.globals_ if global_.name not in other_names],
398+
attributes=other.attributes
399+
+ [
400+
attribute
401+
for attribute in self.attributes
402+
if attribute.name not in other_names
403+
],
404+
functions=other.functions
405+
+ [
406+
function
407+
for function in self.functions
408+
if function.name not in other_names
409+
],
410+
methods=other.methods
411+
+ [method for method in self.methods if method.name not in other_names],
412+
)
413+
375414
def filter_for_complete(self) -> ModuleAnnotations:
376415
return ModuleAnnotations(
377416
path=self.path,

client/commands/tests/infer_v2_test.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,10 +569,15 @@ def with_params(y=7, x: List[int] = [5]) -> Union[int, str]: ...
569569

570570

571571
class ExistingAnnotationsTest(unittest.TestCase):
572-
def _assert_stubs(self, code: str, expected: str) -> None:
572+
def _assert_stubs(self, code: str, expected: str, data: dict | None = None) -> None:
573573
module_annotations = ModuleAnnotations.from_module(
574574
path=PATH, module=libcst.parse_module(textwrap.dedent(code))
575575
)
576+
if data is not None:
577+
module_annotations += _create_test_module_annotations(
578+
data=data,
579+
complete_only=False,
580+
)
576581
actual = module_annotations.to_stubs()
577582
_assert_stubs_equal(actual, expected)
578583

@@ -642,3 +647,68 @@ class Foo:
642647
x: Any = ...
643648
""",
644649
)
650+
651+
def test_stubs_combining_annotations(self) -> None:
652+
self._assert_stubs(
653+
"""
654+
x: object = 1 + 1
655+
y: int = 1 + 1
656+
def f() -> object: return 10
657+
def g() -> str: return "hello"
658+
class Foo:
659+
x: object = 10
660+
y: int = 10
661+
def f(self) -> object: return 10
662+
def g(self) -> str: return "hello"
663+
""",
664+
data={
665+
"globals": [
666+
{
667+
"name": "test.x",
668+
"annotation": "int",
669+
}
670+
],
671+
"attributes": [
672+
{
673+
"name": "test.Foo.x",
674+
"parent": "test.Foo",
675+
"annotation": "int",
676+
}
677+
],
678+
"defines": [
679+
{
680+
"name": "test.f",
681+
"parent": None,
682+
"return": "int",
683+
"parameters": [],
684+
"decorators": [],
685+
"async": False,
686+
},
687+
{
688+
"name": "test.Foo.f",
689+
"parent": "Foo",
690+
"return": "int",
691+
"parameters": [
692+
{
693+
"name": "self",
694+
"annotation": None,
695+
"value": None,
696+
}
697+
],
698+
"decorators": [],
699+
"async": False,
700+
},
701+
],
702+
},
703+
expected="""\
704+
x: int = ...
705+
y: int = ...
706+
def f() -> int: ...
707+
def g() -> str: ...
708+
class Foo:
709+
x: int = ...
710+
y: int = ...
711+
def f(self) -> int: ...
712+
def g(self) -> str: ...
713+
""",
714+
)

0 commit comments

Comments
 (0)