11"""Overlay for functools."""
22
3+ from __future__ import annotations
4+
5+ from collections .abc import Mapping , Sequence
6+ import threading
7+ from typing import Any , Self , TYPE_CHECKING
8+
9+ from pytype .abstract import abstract
10+ from pytype .abstract import function
11+ from pytype .abstract import mixin
312from pytype .overlays import overlay
413from pytype .overlays import special_builtins
14+ from pytype .typegraph import cfg
15+
16+ if TYPE_CHECKING :
17+ from pytype import context # pylint: disable=g-import-not-at-top
18+
519
620_MODULE_NAME = "functools"
721
@@ -15,5 +29,131 @@ def __init__(self, ctx):
1529 "cached_property" , special_builtins .Property .make_alias
1630 ),
1731 }
32+ if ctx .options .use_functools_partial_overlay :
33+ member_map ["partial" ] = Partial
1834 ast = ctx .loader .import_name (_MODULE_NAME )
1935 super ().__init__ (ctx , _MODULE_NAME , member_map , ast )
36+
37+
38+ class Partial (abstract .PyTDClass , mixin .HasSlots ):
39+ """Implementation of functools.partial."""
40+
41+ def __init__ (self , ctx : "context.Context" , module : str ):
42+ pytd_cls = ctx .loader .lookup_pytd (module , "partial" )
43+ super ().__init__ ("partial" , pytd_cls , ctx )
44+ mixin .HasSlots .init_mixin (self )
45+
46+ self ._pytd_new = self .pytd_cls .Lookup ("__new__" )
47+
48+ def new_slot (
49+ self , node , cls , * args , ** kwargs
50+ ) -> tuple [cfg .CFGNode , cfg .Variable ]:
51+ # Make sure the call is well typed before binding the partial
52+ new = self .ctx .convert .convert_pytd_function (self ._pytd_new )
53+ _ , specialized_obj = function .call_function (
54+ self .ctx ,
55+ node ,
56+ new .to_variable (node ),
57+ function .Args ((cls , * args ), kwargs ),
58+ fallback_to_unsolvable = False ,
59+ )
60+ [specialized_obj ] = specialized_obj .data
61+ type_arg = specialized_obj .get_formal_type_parameter ("_T" )
62+ [cls ] = cls .data
63+ cls = abstract .ParameterizedClass (cls , {"_T" : type_arg }, self .ctx )
64+ obj = bind_partial (node , cls , args , kwargs , self .ctx )
65+ return node , obj .to_variable (node )
66+
67+ def get_own_new (self , node , value ) -> tuple [cfg .CFGNode , cfg .Variable ]:
68+ new = abstract .NativeFunction ("__new__" , self .new_slot , self .ctx )
69+ return node , new .to_variable (node )
70+
71+
72+ def bind_partial (node , cls , args , kwargs , ctx ) -> BoundPartial :
73+ del node # Unused.
74+ obj = BoundPartial (ctx , cls )
75+ obj .underlying = args [0 ]
76+ obj .args = args [1 :]
77+ obj .kwargs = kwargs
78+ return obj
79+
80+
81+ class CallContext (threading .local ):
82+ """A thread-local context for ``NativeFunction.call``."""
83+
84+ starargs : cfg .Variable | None = None
85+ starstarargs : cfg .Variable | None = None
86+
87+ def forward (
88+ self , starargs : cfg .Variable | None , starstarargs : cfg .Variable | None
89+ ) -> Self :
90+ self .starargs = starargs
91+ self .starstarargs = starstarargs
92+ return self
93+
94+ def __enter__ (self ) -> Self :
95+ return self
96+
97+ def __exit__ (self , * exc_info ) -> None :
98+ self .starargs = None
99+ self .starstarargs = None
100+
101+
102+ call_context = CallContext ()
103+
104+
105+ class NativeFunction (abstract .NativeFunction ):
106+ """A native function that forwards *args and **kwargs to the underlying function."""
107+
108+ def call (
109+ self ,
110+ node : cfg .CFGNode ,
111+ func : cfg .Binding ,
112+ args : function .Args ,
113+ alias_map : Any | None = None ,
114+ ) -> tuple [cfg .CFGNode , cfg .Variable ]:
115+ # ``NativeFunction.call`` does not forward *args and **kwargs to the
116+ # underlying function, so we do it here to avoid changing core pytype APIs.
117+ starargs = args .starargs
118+ starstarargs = args .starstarargs
119+ if starargs is not None :
120+ starargs = starargs .AssignToNewVariable (node )
121+ if starstarargs is not None :
122+ starstarargs = starstarargs .AssignToNewVariable (node )
123+ with call_context .forward (starargs , starstarargs ):
124+ return super ().call (node , func , args , alias_map )
125+
126+
127+ class BoundPartial (abstract .Instance , mixin .HasSlots ):
128+ """An instance of functools.partial."""
129+
130+ underlying : cfg .Variable
131+ args : Sequence [cfg .Variable ]
132+ kwargs : Mapping [str , cfg .Variable ]
133+
134+ def __init__ (self , ctx , cls , container = None ):
135+ super ().__init__ (cls , ctx , container )
136+ mixin .HasSlots .init_mixin (self )
137+ self .set_slot (
138+ "__call__" , NativeFunction ("__call__" , self .call_slot , self .ctx )
139+ )
140+
141+ @property
142+ def func (self ) -> cfg .Variable :
143+ # The ``func`` attribute marks this class as a wrapper for
144+ # ``maybe_unwrap_decorated_function``.
145+ return self .underlying
146+
147+ def call_slot (self , node : cfg .CFGNode , * args , ** kwargs ):
148+ return function .call_function (
149+ self .ctx ,
150+ node ,
151+ self .underlying ,
152+ function .Args (
153+ (* self .args , * args ),
154+ {** self .kwargs , ** kwargs },
155+ call_context .starargs ,
156+ call_context .starstarargs ,
157+ ),
158+ fallback_to_unsolvable = False ,
159+ )
0 commit comments