@@ -70,6 +70,48 @@ def match_args(self, node, args, alias_map=None, match_all_views=False):
7070 name = self ._get_cell_variable_name (a )
7171 assert name is not None , "Closure variable lookup failed."
7272 raise function .UndefinedParameterError (name )
73+ # The implementation of match_args is currently rather convoluted because we
74+ # have two different implementations:
75+ # * Old implementation: `_match_views` matches 'args' against 'self' one
76+ # view at a time, where a view is a mapping of every variable in args to a
77+ # particular binding. This handles generics but scales poorly with the
78+ # number of bindings per variable.
79+ # * New implementation: `_match_args_sequentially` matches 'args' one at a
80+ # time. This scales better but cannot yet handle generics.
81+ # Subclasses should implement the following:
82+ # * _match_view(node, args, view, alias_map): this will be called repeatedly
83+ # by _match_views.
84+ # * _match_args_sequentially(node, args, alias_map, match_all_views): A
85+ # sequential matching implementation.
86+ # TODO(b/228241343): Get rid of _match_views and simplify match_args once
87+ # _match_args_sequentially can handle generics.
88+ if self ._is_generic_call (args ):
89+ return self ._match_views (node , args , alias_map , match_all_views )
90+ return self ._match_args_sequentially (node , args , alias_map , match_all_views )
91+
92+ def _is_generic_call (self , args ):
93+ for sig in function .get_signatures (self ):
94+ for t in sig .annotations .values ():
95+ stack = [t ]
96+ seen = set ()
97+ while stack :
98+ cur = stack .pop ()
99+ if cur in seen :
100+ continue
101+ seen .add (cur )
102+ if cur .formal or cur .template :
103+ return True
104+ if _isinstance (cur , "Union" ):
105+ stack .extend (cur .options )
106+ if self .is_attribute_of_class and args .posargs :
107+ for self_val in args .posargs [0 ].data :
108+ for cls in self_val .cls .mro :
109+ if cls .template :
110+ return True
111+ return False
112+
113+ def _match_views (self , node , args , alias_map , match_all_views ):
114+ """Matches all views of the given args against this function."""
73115 error = None
74116 matched = []
75117 arg_variables = args .get_variables ()
@@ -107,6 +149,9 @@ def match_args(self, node, args, alias_map=None, match_all_views=False):
107149 def _match_view (self , node , args , view , alias_map ):
108150 raise NotImplementedError (self .__class__ .__name__ )
109151
152+ def _match_args_sequentially (self , node , args , alias_map , match_all_views ):
153+ raise NotImplementedError (self .__class__ .__name__ )
154+
110155 def __repr__ (self ):
111156 return self .full_name + "(...)"
112157
0 commit comments