11# autoflake: skip_file
2+ import copy
23import inspect
34import os
45import pathlib
@@ -28,13 +29,12 @@ def repr_annotation(field_type: type) -> str:
2829
2930@dataclass (slots = True )
3031class ParamSnapshot :
31- name : str
3232 annotation : type
3333 default : Any = None
3434
3535 @classmethod
3636 def from_inspect (cls , param : inspect .Parameter ):
37- return cls (param .name , param . annotation , param .default )
37+ return cls (param .annotation , param .default )
3838
3939 @classmethod
4040 def from_docstring (cls , param : docstring_parser .common .DocstringParam ):
@@ -57,7 +57,7 @@ def from_docstring(cls, param: docstring_parser.common.DocstringParam):
5757 except (NameError , SyntaxError ):
5858 default = param .default
5959
60- return cls (param . arg_name , annotation , default )
60+ return cls (annotation , default )
6161
6262 @classmethod
6363 def from_dict (cls , d : dict ):
@@ -77,19 +77,17 @@ def to_dict(self):
7777 return d
7878
7979 def assert_equal (self , other : 'ParamSnapshot' ):
80- assert self .name == other .name
8180 assert self .annotation == other .annotation
8281 assert self .default == other .default
8382
8483
8584@dataclass (slots = True )
8685class MethodSnapshot :
87- name : str
8886 parameters : Dict [str , ParamSnapshot ]
8987 return_annotation : type
9088
9189 @classmethod
92- def from_inspect (cls , name : str , method : MethodType ):
90+ def from_inspect (cls , method : MethodType ):
9391 signature = inspect .signature (method )
9492 parameters = {}
9593 for param_name , param in signature .parameters .items ():
@@ -99,10 +97,10 @@ def from_inspect(cls, name: str, method: MethodType):
9997 return_annotation = signature .return_annotation
10098 if isinstance (return_annotation , str ):
10199 return_annotation = eval (return_annotation )
102- return cls (name , parameters , return_annotation )
100+ return cls (parameters , return_annotation )
103101
104102 @classmethod
105- def from_docstring (cls , name : str , method : MethodType ):
103+ def from_docstring (cls , method : MethodType ):
106104 doc = docstring_parser .parse (method .__doc__ )
107105 parameters = {}
108106 for param in doc .params :
@@ -112,7 +110,7 @@ def from_docstring(cls, name: str, method: MethodType):
112110 return_annotation = None
113111 else :
114112 return_annotation = eval (doc .returns .type_name )
115- return cls (name , parameters , return_annotation )
113+ return cls (parameters , return_annotation )
116114
117115 @classmethod
118116 def from_dict (cls , d : dict ):
@@ -132,13 +130,23 @@ def to_dict(self):
132130 d ['return_annotation' ] = repr_annotation (d ['return_annotation' ])
133131 return d
134132
133+ def merge (self , other : 'MethodSnapshot' ):
134+ assert self .parameters .keys ().isdisjoint (other .parameters .keys ())
135+ self .parameters .update (copy .deepcopy (other .parameters ))
136+ assert self .return_annotation == other .return_annotation
137+
135138 def assert_equal (self , other : 'MethodSnapshot' ):
136- assert self .name == other .name
137139 assert self .parameters .keys () == other .parameters .keys ()
138140 for name , param in self .parameters .items ():
139141 param .assert_equal (other .parameters [name ])
140142 assert self .return_annotation == other .return_annotation
141143
144+ def assert_containing (self , other : 'MethodSnapshot' ):
145+ for name , param in other .parameters .items ():
146+ assert name in self .parameters
147+ self .parameters [name ].assert_equal (param )
148+ assert self .return_annotation == other .return_annotation
149+
142150
143151@dataclass (slots = True )
144152class ClassSnapshot :
@@ -153,16 +161,14 @@ def from_inspect(cls, snapshot_cls: type):
153161 inst , predicate = inspect .ismethod ):
154162 if method_name .startswith ("_" ) and method_name != "__init__" :
155163 continue
156- methods [method_name ] = MethodSnapshot .from_inspect (
157- method_name , method )
164+ methods [method_name ] = MethodSnapshot .from_inspect (method )
158165 properties = {}
159166 for prop_name , prop in inspect .getmembers (
160167 snapshot_cls , predicate = lambda x : isinstance (x , property )):
161168 if prop_name .startswith ("_" ):
162169 continue
163170 annotation = inspect .signature (prop .fget ).return_annotation
164- properties [prop_name ] = ParamSnapshot (prop_name , annotation ,
165- inspect ._empty )
171+ properties [prop_name ] = ParamSnapshot (annotation , inspect ._empty )
166172 return cls (methods , properties )
167173
168174 @classmethod
@@ -175,10 +181,9 @@ def from_docstring(cls, snapshot_cls: type):
175181 continue
176182 if method_name == "__init__" :
177183 methods ["__init__" ] = MethodSnapshot .from_docstring (
178- "__init__" , snapshot_cls )
184+ snapshot_cls )
179185 else :
180- methods [method_name ] = MethodSnapshot .from_docstring (
181- method_name , method )
186+ methods [method_name ] = MethodSnapshot .from_docstring (method )
182187 properties = {}
183188 doc = docstring_parser .parse (snapshot_cls .__doc__ )
184189 for param in doc .params :
@@ -210,6 +215,19 @@ def to_dict(self):
210215 }
211216 return d
212217
218+ def merge (self , other : 'ClassSnapshot' ):
219+ for name , method in self .methods .items ():
220+ if name in other .methods :
221+ method .merge (other .methods [name ])
222+ new_methods = {
223+ name : method
224+ for name , method in other .methods .items ()
225+ if name not in self .methods
226+ }
227+ self .methods .update (copy .deepcopy (new_methods ))
228+ assert self .properties .keys ().isdisjoint (other .properties .keys ())
229+ self .properties .update (copy .deepcopy (other .properties ))
230+
213231 def assert_equal (self , other : 'ClassSnapshot' ):
214232 assert self .methods .keys () == other .methods .keys ()
215233 for name , method in self .methods .items ():
@@ -218,30 +236,47 @@ def assert_equal(self, other: 'ClassSnapshot'):
218236 for name , prop in self .properties .items ():
219237 prop .assert_equal (other .properties [name ])
220238
239+ def assert_containing (self , other : 'ClassSnapshot' ):
240+ for name , method in other .methods .items ():
241+ assert name in self .methods
242+ self .methods [name ].assert_containing (method )
243+ for name , prop in other .properties .items ():
244+ assert name in self .properties
245+ self .properties [name ].assert_equal (prop )
246+
221247
222248class ApiStabilityTestHarness :
223249 TEST_CLASS = None
250+ REFERENCE_COMMITTED_DIR = f"{ os .path .dirname (__file__ )} /references_committed"
224251 REFERENCE_DIR = f"{ os .path .dirname (__file__ )} /references"
225252 REFERENCE_FILE = None
226253
227- @classmethod
228- def reference_path (cls ):
229- return f"{ cls .REFERENCE_DIR } /{ cls .REFERENCE_FILE } "
230-
231254 @classmethod
232255 def setup_class (cls ):
233- with open (cls .reference_path () ) as f :
256+ with open (f" { cls .REFERENCE_DIR } / { cls . REFERENCE_FILE } " ) as f :
234257 cls .reference = ClassSnapshot .from_dict (yaml .safe_load (f ))
235- cls .error_msg = (
236- f"API stability validation failed. "
237- f"This is probably because you changed { cls .TEST_CLASS .__name__ } 's APIs, please ask for reviews from the code owners."
238- )
258+ if os .path .exists (
259+ f"{ cls .REFERENCE_COMMITTED_DIR } /{ cls .REFERENCE_FILE } " ):
260+ with open (
261+ f"{ cls .REFERENCE_COMMITTED_DIR } /{ cls .REFERENCE_FILE } " ) as f :
262+ cls .reference_committed = ClassSnapshot .from_dict (
263+ yaml .safe_load (f ))
264+ cls .reference .merge (cls .reference_committed )
265+ else :
266+ cls .reference_committed = None
267+ cls .error_msg = f"API validation failed because you changed { cls .TEST_CLASS .__name__ } 's APIs, please ask for reviews from the code owners."
268+ cls .error_msg_committed = f"API validation failed because you changed { cls .TEST_CLASS .__name__ } 's committed APIs, please ask for approval."
239269
240270 def create_snapshot_from_inspect (self ):
241271 return ClassSnapshot .from_inspect (self .TEST_CLASS )
242272
243273 def test_signature (self ):
244274 snapshot = self .create_snapshot_from_inspect ()
275+ if self .reference_committed is not None :
276+ try :
277+ snapshot .assert_containing (self .reference_committed )
278+ except AssertionError as e :
279+ raise AssertionError (self .error_msg_committed ) from e
245280 try :
246281 snapshot .assert_equal (self .reference )
247282 except AssertionError as e :
@@ -252,6 +287,11 @@ def create_snapshot_from_docstring(self):
252287
253288 def test_docstring (self ):
254289 snapshot = self .create_snapshot_from_docstring ()
290+ if self .reference_committed is not None :
291+ try :
292+ snapshot .assert_containing (self .reference_committed )
293+ except AssertionError as e :
294+ raise AssertionError (self .error_msg_committed ) from e
255295 try :
256296 snapshot .assert_equal (self .reference )
257297 except AssertionError as e :
0 commit comments