55import re
66from typing import Optional , Callable
77
8- from parser import Stmt , SimpleStmt , BlockStmt , IfStmt , WhileStmt
8+ from parser import Stmt , SimpleStmt , BlockStmt , IfStmt , WhileStmt , ForStmt , MacroIfStmt
99
1010@dataclass
1111class EscapingCall :
@@ -723,53 +723,57 @@ def visit(stmt: Stmt) -> None:
723723 if error is not None :
724724 raise analysis_error (f"Escaping call '{ error .text } in condition" , error )
725725
726+ def escaping_call_in_simple_stmt (stmt : SimpleStmt , result : dict [SimpleStmt , EscapingCall ]) -> None :
727+ tokens = stmt .contents
728+ for idx , tkn in enumerate (tokens ):
729+ try :
730+ next_tkn = tokens [idx + 1 ]
731+ except IndexError :
732+ break
733+ if next_tkn .kind != lexer .LPAREN :
734+ continue
735+ if tkn .kind == lexer .IDENTIFIER :
736+ if tkn .text .upper () == tkn .text :
737+ # simple macro
738+ continue
739+ #if not tkn.text.startswith(("Py", "_Py", "monitor")):
740+ # continue
741+ if tkn .text .startswith (("sym_" , "optimize_" , "PyJitRef" )):
742+ # Optimize functions
743+ continue
744+ if tkn .text .endswith ("Check" ):
745+ continue
746+ if tkn .text .startswith ("Py_Is" ):
747+ continue
748+ if tkn .text .endswith ("CheckExact" ):
749+ continue
750+ if tkn .text in NON_ESCAPING_FUNCTIONS :
751+ continue
752+ elif tkn .kind == "RPAREN" :
753+ prev = tokens [idx - 1 ]
754+ if prev .text .endswith ("_t" ) or prev .text == "*" or prev .text == "int" :
755+ #cast
756+ continue
757+ elif tkn .kind != "RBRACKET" :
758+ continue
759+ if tkn .text in ("PyStackRef_CLOSE" , "PyStackRef_XCLOSE" ):
760+ if len (tokens ) <= idx + 2 :
761+ raise analysis_error ("Unexpected end of file" , next_tkn )
762+ kills = tokens [idx + 2 ]
763+ if kills .kind != "IDENTIFIER" :
764+ raise analysis_error (f"Expected identifier, got '{ kills .text } '" , kills )
765+ else :
766+ kills = None
767+ result [stmt ] = EscapingCall (stmt , tkn , kills )
768+
769+
726770def find_escaping_api_calls (instr : parser .CodeDef ) -> dict [SimpleStmt , EscapingCall ]:
727771 result : dict [SimpleStmt , EscapingCall ] = {}
728772
729773 def visit (stmt : Stmt ) -> None :
730774 if not isinstance (stmt , SimpleStmt ):
731775 return
732- tokens = stmt .contents
733- for idx , tkn in enumerate (tokens ):
734- try :
735- next_tkn = tokens [idx + 1 ]
736- except IndexError :
737- break
738- if next_tkn .kind != lexer .LPAREN :
739- continue
740- if tkn .kind == lexer .IDENTIFIER :
741- if tkn .text .upper () == tkn .text :
742- # simple macro
743- continue
744- #if not tkn.text.startswith(("Py", "_Py", "monitor")):
745- # continue
746- if tkn .text .startswith (("sym_" , "optimize_" , "PyJitRef" )):
747- # Optimize functions
748- continue
749- if tkn .text .endswith ("Check" ):
750- continue
751- if tkn .text .startswith ("Py_Is" ):
752- continue
753- if tkn .text .endswith ("CheckExact" ):
754- continue
755- if tkn .text in NON_ESCAPING_FUNCTIONS :
756- continue
757- elif tkn .kind == "RPAREN" :
758- prev = tokens [idx - 1 ]
759- if prev .text .endswith ("_t" ) or prev .text == "*" or prev .text == "int" :
760- #cast
761- continue
762- elif tkn .kind != "RBRACKET" :
763- continue
764- if tkn .text in ("PyStackRef_CLOSE" , "PyStackRef_XCLOSE" ):
765- if len (tokens ) <= idx + 2 :
766- raise analysis_error ("Unexpected end of file" , next_tkn )
767- kills = tokens [idx + 2 ]
768- if kills .kind != "IDENTIFIER" :
769- raise analysis_error (f"Expected identifier, got '{ kills .text } '" , kills )
770- else :
771- kills = None
772- result [stmt ] = EscapingCall (stmt , tkn , kills )
776+ escaping_call_in_simple_stmt (stmt , result )
773777
774778 instr .block .accept (visit )
775779 check_escaping_calls (instr , result )
@@ -822,6 +826,60 @@ def stack_effect_only_peeks(instr: parser.InstDef) -> bool:
822826 )
823827
824828
829+ def stmt_is_simple_exit (stmt : Stmt ) -> bool :
830+ if not isinstance (stmt , SimpleStmt ):
831+ return False
832+ tokens = stmt .contents
833+ if len (tokens ) < 4 :
834+ return False
835+ return (
836+ tokens [0 ].text in ("ERROR_IF" , "DEOPT_IF" , "EXIT_IF" )
837+ and
838+ tokens [1 ].text == "("
839+ and
840+ tokens [2 ].text in ("true" , "1" )
841+ and
842+ tokens [3 ].text == ")"
843+ )
844+
845+
846+ def stmt_list_escapes (stmts : list [Stmt ]) -> bool :
847+ if not stmts :
848+ return False
849+ if stmt_is_simple_exit (stmts [- 1 ]):
850+ return False
851+ for stmt in stmts :
852+ if stmt_escapes (stmt ):
853+ return True
854+ return False
855+
856+
857+ def stmt_escapes (stmt : Stmt ) -> bool :
858+ if isinstance (stmt , BlockStmt ):
859+ return stmt_list_escapes (stmt .body )
860+ elif isinstance (stmt , SimpleStmt ):
861+ for tkn in stmt .contents :
862+ if tkn .text == "DECREF_INPUTS" :
863+ return True
864+ d : dict [SimpleStmt , EscapingCall ] = {}
865+ escaping_call_in_simple_stmt (stmt , d )
866+ return bool (d )
867+ elif isinstance (stmt , IfStmt ):
868+ if stmt .else_body and stmt_escapes (stmt .else_body ):
869+ return True
870+ return stmt_escapes (stmt .body )
871+ elif isinstance (stmt , MacroIfStmt ):
872+ if stmt .else_body and stmt_list_escapes (stmt .else_body ):
873+ return True
874+ return stmt_list_escapes (stmt .body )
875+ elif isinstance (stmt , ForStmt ):
876+ return stmt_escapes (stmt .body )
877+ elif isinstance (stmt , WhileStmt ):
878+ return stmt_escapes (stmt .body )
879+ else :
880+ assert False , "Unexpected statement type"
881+
882+
825883def compute_properties (op : parser .CodeDef ) -> Properties :
826884 escaping_calls = find_escaping_api_calls (op )
827885 has_free = (
@@ -843,7 +901,7 @@ def compute_properties(op: parser.CodeDef) -> Properties:
843901 )
844902 error_with_pop = has_error_with_pop (op )
845903 error_without_pop = has_error_without_pop (op )
846- escapes = bool ( escaping_calls ) or variable_used ( op , "DECREF_INPUTS" )
904+ escapes = stmt_escapes ( op . block )
847905 pure = False if isinstance (op , parser .LabelDef ) else "pure" in op .annotations
848906 no_save_ip = False if isinstance (op , parser .LabelDef ) else "no_save_ip" in op .annotations
849907 return Properties (
0 commit comments