Skip to content

Commit 9121672

Browse files
superbobrycopybara-github
authored andcommitted
Added an overlay for functools.partial
The goal of the overlay is to delay return type inference until the partial object is called, allowing pytype to use all available arguments instead of just the ones provided to `functools.partial`. The overlay is currently gated by a feature flag to avoid breaking existing (but ill-typed) code. PiperOrigin-RevId: 813149022
1 parent e32392c commit 9121672

File tree

5 files changed

+182
-0
lines changed

5 files changed

+182
-0
lines changed

pytype/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,11 @@ def add_options(o, arglist):
295295
_flag(
296296
"--use-fiddle-overlay", False, "Support the third-party fiddle library."
297297
),
298+
_flag(
299+
"--use-functools-partial-overlay",
300+
False,
301+
"Enable precise checks when calling functools.partial objects.",
302+
),
298303
] + _OPT_IN_FEATURES
299304

300305

pytype/overlays/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ py_library(
161161
DEPS
162162
.overlay
163163
.special_builtins
164+
pytype.abstract.abstract
165+
pytype.typegraph.cfg
164166
)
165167

166168
py_library(

pytype/overlays/functools_overlay.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,21 @@
11
"""Overlay for functools."""
22

3+
from __future__ import annotations
4+
5+
from collections.abc import Mapping, Sequence
6+
import threading
7+
from typing import Any, Self, TYPE_CHECKING
8+
9+
from pytype.abstract import abstract
10+
from pytype.abstract import function
11+
from pytype.abstract import mixin
312
from pytype.overlays import overlay
413
from pytype.overlays import special_builtins
14+
from pytype.typegraph import cfg
15+
16+
if TYPE_CHECKING:
17+
from pytype import context # pylint: disable=g-import-not-at-top
18+
519

620
_MODULE_NAME = "functools"
721

@@ -15,5 +29,131 @@ def __init__(self, ctx):
1529
"cached_property", special_builtins.Property.make_alias
1630
),
1731
}
32+
if ctx.options.use_functools_partial_overlay:
33+
member_map["partial"] = Partial
1834
ast = ctx.loader.import_name(_MODULE_NAME)
1935
super().__init__(ctx, _MODULE_NAME, member_map, ast)
36+
37+
38+
class Partial(abstract.PyTDClass, mixin.HasSlots):
39+
"""Implementation of functools.partial."""
40+
41+
def __init__(self, ctx: "context.Context", module: str):
42+
pytd_cls = ctx.loader.lookup_pytd(module, "partial")
43+
super().__init__("partial", pytd_cls, ctx)
44+
mixin.HasSlots.init_mixin(self)
45+
46+
self._pytd_new = self.pytd_cls.Lookup("__new__")
47+
48+
def new_slot(
49+
self, node, cls, *args, **kwargs
50+
) -> tuple[cfg.CFGNode, cfg.Variable]:
51+
# Make sure the call is well typed before binding the partial
52+
new = self.ctx.convert.convert_pytd_function(self._pytd_new)
53+
_, specialized_obj = function.call_function(
54+
self.ctx,
55+
node,
56+
new.to_variable(node),
57+
function.Args((cls, *args), kwargs),
58+
fallback_to_unsolvable=False,
59+
)
60+
[specialized_obj] = specialized_obj.data
61+
type_arg = specialized_obj.get_formal_type_parameter("_T")
62+
[cls] = cls.data
63+
cls = abstract.ParameterizedClass(cls, {"_T": type_arg}, self.ctx)
64+
obj = bind_partial(node, cls, args, kwargs, self.ctx)
65+
return node, obj.to_variable(node)
66+
67+
def get_own_new(self, node, value) -> tuple[cfg.CFGNode, cfg.Variable]:
68+
new = abstract.NativeFunction("__new__", self.new_slot, self.ctx)
69+
return node, new.to_variable(node)
70+
71+
72+
def bind_partial(node, cls, args, kwargs, ctx) -> BoundPartial:
73+
del node # Unused.
74+
obj = BoundPartial(ctx, cls)
75+
obj.underlying = args[0]
76+
obj.args = args[1:]
77+
obj.kwargs = kwargs
78+
return obj
79+
80+
81+
class CallContext(threading.local):
82+
"""A thread-local context for ``NativeFunction.call``."""
83+
84+
starargs: cfg.Variable | None = None
85+
starstarargs: cfg.Variable | None = None
86+
87+
def forward(
88+
self, starargs: cfg.Variable | None, starstarargs: cfg.Variable | None
89+
) -> Self:
90+
self.starargs = starargs
91+
self.starstarargs = starstarargs
92+
return self
93+
94+
def __enter__(self) -> Self:
95+
return self
96+
97+
def __exit__(self, *exc_info) -> None:
98+
self.starargs = None
99+
self.starstarargs = None
100+
101+
102+
call_context = CallContext()
103+
104+
105+
class NativeFunction(abstract.NativeFunction):
106+
"""A native function that forwards *args and **kwargs to the underlying function."""
107+
108+
def call(
109+
self,
110+
node: cfg.CFGNode,
111+
func: cfg.Binding,
112+
args: function.Args,
113+
alias_map: Any | None = None,
114+
) -> tuple[cfg.CFGNode, cfg.Variable]:
115+
# ``NativeFunction.call`` does not forward *args and **kwargs to the
116+
# underlying function, so we do it here to avoid changing core pytype APIs.
117+
starargs = args.starargs
118+
starstarargs = args.starstarargs
119+
if starargs is not None:
120+
starargs = starargs.AssignToNewVariable(node)
121+
if starstarargs is not None:
122+
starstarargs = starstarargs.AssignToNewVariable(node)
123+
with call_context.forward(starargs, starstarargs):
124+
return super().call(node, func, args, alias_map)
125+
126+
127+
class BoundPartial(abstract.Instance, mixin.HasSlots):
128+
"""An instance of functools.partial."""
129+
130+
underlying: cfg.Variable
131+
args: Sequence[cfg.Variable]
132+
kwargs: Mapping[str, cfg.Variable]
133+
134+
def __init__(self, ctx, cls, container=None):
135+
super().__init__(cls, ctx, container)
136+
mixin.HasSlots.init_mixin(self)
137+
self.set_slot(
138+
"__call__", NativeFunction("__call__", self.call_slot, self.ctx)
139+
)
140+
141+
@property
142+
def func(self) -> cfg.Variable:
143+
# The ``func`` attribute marks this class as a wrapper for
144+
# ``maybe_unwrap_decorated_function``.
145+
return self.underlying
146+
147+
def call_slot(self, node: cfg.CFGNode, *args, **kwargs):
148+
return function.call_function(
149+
self.ctx,
150+
node,
151+
self.underlying,
152+
function.Args(
153+
(*self.args, *args),
154+
{**self.kwargs, **kwargs},
155+
call_context.starargs,
156+
call_context.starstarargs,
157+
),
158+
fallback_to_unsolvable=False,
159+
)

pytype/tests/test_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def setUp(self):
9393
strict_primitive_comparisons=True,
9494
strict_none_binding=True,
9595
use_fiddle_overlay=True,
96+
use_functools_partial_overlay=True,
9697
use_rewrite=_USE_REWRITE,
9798
validate_version=False,
9899
)

pytype/tests/test_functions1.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,40 @@ def f(a, b=None):
10791079
partial_f(0)
10801080
""")
10811081

1082+
def test_functools_partial_overloaded(self):
1083+
self.Check("""
1084+
import functools
1085+
from typing import overload
1086+
@overload
1087+
def f(x: int) -> int: ...
1088+
@overload
1089+
def f(x: str) -> str: ...
1090+
def f(x):
1091+
return x
1092+
partial_f = functools.partial(f)
1093+
# TODO(slebedev): This should be functools.partial[int | str].
1094+
assert_type(partial_f, functools.partial)
1095+
assert_type(partial_f(1), int)
1096+
assert_type(partial_f("s"), str)
1097+
""")
1098+
1099+
def test_functools_partial_overloaded_with_star(self):
1100+
self.Check("""
1101+
import functools
1102+
from typing import overload
1103+
@overload
1104+
def f(x: int, y: int) -> int: ...
1105+
@overload
1106+
def f(x: str, y: str) -> str: ...
1107+
def f(x, y):
1108+
return x
1109+
partial_f = functools.partial(f, 42)
1110+
def test(*args):
1111+
# TODO(slebedev): This should be functools.partial[int].
1112+
assert_type(partial_f, functools.partial)
1113+
assert_type(partial_f(*args), int)
1114+
""")
1115+
10821116
def test_functools_partial_class(self):
10831117
self.Check("""
10841118
import functools

0 commit comments

Comments
 (0)