55
66import os
77
8- from .utils import print , lib_dir , blacken , to_snake_case , to_camel_case , Patcher
9- from .idlparser import get_idl_parser
8+ from codegen .utils import print , lib_dir , blacken , to_snake_case , to_camel_case , Patcher
9+ from codegen .idlparser import get_idl_parser
1010
1111
1212def patch_base_api (code ):
@@ -42,7 +42,11 @@ def patch_backend_api(code):
4242 base_api_code = f .read ().decode ()
4343
4444 # Patch!
45- for patcher in [CommentRemover (), BackendApiPatcher (base_api_code )]:
45+ for patcher in [
46+ CommentRemover (),
47+ BackendApiPatcher (base_api_code ),
48+ StructValidationChecker (),
49+ ]:
4650 patcher .apply (code )
4751 code = patcher .dumps ()
4852 return code
@@ -53,7 +57,7 @@ class CommentRemover(Patcher):
5357 to prevent accumulating comments.
5458 """
5559
56- triggers = "# IDL:" , "# FIXME: unknown api"
60+ triggers = "# IDL:" , "# FIXME: unknown api" , "# FIXME: missing check_struct"
5761
5862 def apply (self , code ):
5963 self ._init (code )
@@ -174,7 +178,7 @@ def patch_properties(self, classname, i1, i2):
174178 self ._apidiffs_from_lines (pre_lines , propname )
175179 if self .prop_is_known (classname , propname ):
176180 if "@apidiff.add" in pre_lines :
177- print (f"Error : apidiff.add for known { classname } .{ propname } " )
181+ print (f"ERROR : apidiff.add for known { classname } .{ propname } " )
178182 elif "@apidiff.hide" in pre_lines :
179183 pass # continue as normal
180184 old_line = self .lines [j1 ]
@@ -207,7 +211,7 @@ def patch_methods(self, classname, i1, i2):
207211 self ._apidiffs_from_lines (pre_lines , methodname )
208212 if self .method_is_known (classname , methodname ):
209213 if "@apidiff.add" in pre_lines :
210- print (f"Error : apidiff.add for known { classname } .{ methodname } " )
214+ print (f"ERROR : apidiff.add for known { classname } .{ methodname } " )
211215 elif "@apidiff.hide" in pre_lines :
212216 pass # continue as normal
213217 elif "@apidiff.change" in pre_lines :
@@ -443,3 +447,64 @@ def get_required_prop_names(self, classname):
443447 def get_required_method_names (self , classname ):
444448 _ , methods = self .classes [classname ]
445449 return list (name for name in methods .keys () if methods [name ][1 ])
450+
451+
452+ class StructValidationChecker (Patcher ):
453+ """Checks that all structs are vaildated in the methods that have incoming structs."""
454+
455+ def apply (self , code ):
456+ self ._init (code )
457+
458+ idl = get_idl_parser ()
459+ all_structs = set ()
460+ ignore_structs = {"Extent3D" }
461+
462+ for classname , i1 , i2 in self .iter_classes ():
463+ if classname not in idl .classes :
464+ continue
465+
466+ # For each method ...
467+ for methodname , j1 , j2 in self .iter_methods (i1 + 1 ):
468+ code = "\n " .join (self .lines [j1 : j2 + 1 ])
469+ # Get signature and cut it up in words
470+ sig_words = code .partition ("(" )[2 ].split ("):" )[0 ]
471+ for c in "][(),\" '" :
472+ sig_words = sig_words .replace (c , " " )
473+ # Collect incoming structs from signature
474+ method_structs = set ()
475+ for word in sig_words .split ():
476+ if word .startswith ("structs." ):
477+ structname = word .partition ("." )[2 ]
478+ method_structs .update (self ._get_sub_structs (idl , structname ))
479+ all_structs .update (method_structs )
480+ # Collect structs being checked
481+ checked = set ()
482+ for line in code .splitlines ():
483+ line = line .lstrip ()
484+ if line .startswith ("check_struct(" ):
485+ name = line .split ("(" )[1 ].split ("," )[0 ].strip ('"' )
486+ checked .add (name )
487+ # Test that a matching check is done
488+ unchecked = method_structs .difference (checked )
489+ unchecked = list (sorted (unchecked .difference (ignore_structs )))
490+ if (
491+ methodname .endswith ("_async" )
492+ and f"return self.{ methodname [:- 7 ]} " in code
493+ ):
494+ pass
495+ elif unchecked :
496+ msg = f"missing check_struct in { methodname } : { unchecked } "
497+ self .insert_line (j1 , f"# FIXME: { msg } " )
498+ print (f"ERROR: { msg } " )
499+
500+ # Test that we did find structs. In case our detection fails for
501+ # some reason, this would probably catch that.
502+ assert len (all_structs ) > 10
503+
504+ def _get_sub_structs (self , idl , structname ):
505+ structnames = {structname }
506+ for structfield in idl .structs [structname ].values ():
507+ structname2 = structfield .typename [3 :] # remove "GPU"
508+ if structname2 in idl .structs :
509+ structnames .update (self ._get_sub_structs (idl , structname2 ))
510+ return structnames
0 commit comments