55import  re 
66from  typing  import  Optional , Callable 
77
8- from  parser  import  Stmt , SimpleStmt , BlockStmt , IfStmt , WhileStmt 
8+ from  parser  import  Stmt , SimpleStmt , BlockStmt , IfStmt , WhileStmt ,  ForStmt ,  MacroIfStmt 
99
1010@dataclass  
1111class  EscapingCall :
@@ -723,53 +723,57 @@ def visit(stmt: Stmt) -> None:
723723    if  error  is  not   None :
724724        raise  analysis_error (f"Escaping call '{ error .text }   in condition" , error )
725725
726+ def  escaping_call_in_simple_stmt (stmt : SimpleStmt , result : dict [SimpleStmt , EscapingCall ]) ->  None :
727+     tokens  =  stmt .contents 
728+     for  idx , tkn  in  enumerate (tokens ):
729+         try :
730+             next_tkn  =  tokens [idx + 1 ]
731+         except  IndexError :
732+             break 
733+         if  next_tkn .kind  !=  lexer .LPAREN :
734+             continue 
735+         if  tkn .kind  ==  lexer .IDENTIFIER :
736+             if  tkn .text .upper () ==  tkn .text :
737+                 # simple macro 
738+                 continue 
739+             #if not tkn.text.startswith(("Py", "_Py", "monitor")): 
740+             #    continue 
741+             if  tkn .text .startswith (("sym_" , "optimize_" , "PyJitRef" )):
742+                 # Optimize functions 
743+                 continue 
744+             if  tkn .text .endswith ("Check" ):
745+                 continue 
746+             if  tkn .text .startswith ("Py_Is" ):
747+                 continue 
748+             if  tkn .text .endswith ("CheckExact" ):
749+                 continue 
750+             if  tkn .text  in  NON_ESCAPING_FUNCTIONS :
751+                 continue 
752+         elif  tkn .kind  ==  "RPAREN" :
753+             prev  =  tokens [idx - 1 ]
754+             if  prev .text .endswith ("_t" ) or  prev .text  ==  "*"  or  prev .text  ==  "int" :
755+                 #cast 
756+                 continue 
757+         elif  tkn .kind  !=  "RBRACKET" :
758+             continue 
759+         if  tkn .text  in  ("PyStackRef_CLOSE" , "PyStackRef_XCLOSE" ):
760+             if  len (tokens ) <=  idx + 2 :
761+                 raise  analysis_error ("Unexpected end of file" , next_tkn )
762+             kills  =  tokens [idx + 2 ]
763+             if  kills .kind  !=  "IDENTIFIER" :
764+                 raise  analysis_error (f"Expected identifier, got '{ kills .text }  '" , kills )
765+         else :
766+             kills  =  None 
767+         result [stmt ] =  EscapingCall (stmt , tkn , kills )
768+ 
769+ 
726770def  find_escaping_api_calls (instr : parser .CodeDef ) ->  dict [SimpleStmt , EscapingCall ]:
727771    result : dict [SimpleStmt , EscapingCall ] =  {}
728772
729773    def  visit (stmt : Stmt ) ->  None :
730774        if  not  isinstance (stmt , SimpleStmt ):
731775            return 
732-         tokens  =  stmt .contents 
733-         for  idx , tkn  in  enumerate (tokens ):
734-             try :
735-                 next_tkn  =  tokens [idx + 1 ]
736-             except  IndexError :
737-                 break 
738-             if  next_tkn .kind  !=  lexer .LPAREN :
739-                 continue 
740-             if  tkn .kind  ==  lexer .IDENTIFIER :
741-                 if  tkn .text .upper () ==  tkn .text :
742-                     # simple macro 
743-                     continue 
744-                 #if not tkn.text.startswith(("Py", "_Py", "monitor")): 
745-                 #    continue 
746-                 if  tkn .text .startswith (("sym_" , "optimize_" , "PyJitRef" )):
747-                     # Optimize functions 
748-                     continue 
749-                 if  tkn .text .endswith ("Check" ):
750-                     continue 
751-                 if  tkn .text .startswith ("Py_Is" ):
752-                     continue 
753-                 if  tkn .text .endswith ("CheckExact" ):
754-                     continue 
755-                 if  tkn .text  in  NON_ESCAPING_FUNCTIONS :
756-                     continue 
757-             elif  tkn .kind  ==  "RPAREN" :
758-                 prev  =  tokens [idx - 1 ]
759-                 if  prev .text .endswith ("_t" ) or  prev .text  ==  "*"  or  prev .text  ==  "int" :
760-                     #cast 
761-                     continue 
762-             elif  tkn .kind  !=  "RBRACKET" :
763-                 continue 
764-             if  tkn .text  in  ("PyStackRef_CLOSE" , "PyStackRef_XCLOSE" ):
765-                 if  len (tokens ) <=  idx + 2 :
766-                     raise  analysis_error ("Unexpected end of file" , next_tkn )
767-                 kills  =  tokens [idx + 2 ]
768-                 if  kills .kind  !=  "IDENTIFIER" :
769-                     raise  analysis_error (f"Expected identifier, got '{ kills .text }  '" , kills )
770-             else :
771-                 kills  =  None 
772-             result [stmt ] =  EscapingCall (stmt , tkn , kills )
776+         escaping_call_in_simple_stmt (stmt , result )
773777
774778    instr .block .accept (visit )
775779    check_escaping_calls (instr , result )
@@ -822,6 +826,60 @@ def stack_effect_only_peeks(instr: parser.InstDef) -> bool:
822826    )
823827
824828
829+ def  stmt_is_simple_exit (stmt : Stmt ) ->  bool :
830+     if  not  isinstance (stmt , SimpleStmt ):
831+         return  False 
832+     tokens  =  stmt .contents 
833+     if  len (tokens ) <  4 :
834+         return  False 
835+     return  (
836+         tokens [0 ].text  in  ("ERROR_IF" , "DEOPT_IF" , "EXIT_IF" )
837+         and 
838+         tokens [1 ].text  ==  "(" 
839+         and 
840+         tokens [2 ].text  in  ("true" , "1" )
841+         and 
842+         tokens [3 ].text  ==  ")" 
843+     )
844+ 
845+ 
846+ def  stmt_list_escapes (stmts : list [Stmt ]) ->  bool :
847+     if  not  stmts :
848+         return  False 
849+     if  stmt_is_simple_exit (stmts [- 1 ]):
850+         return  False 
851+     for  stmt  in  stmts :
852+         if  stmt_escapes (stmt ):
853+             return  True 
854+     return  False 
855+ 
856+ 
857+ def  stmt_escapes (stmt : Stmt ) ->  bool :
858+     if  isinstance (stmt , BlockStmt ):
859+         return  stmt_list_escapes (stmt .body )
860+     elif  isinstance (stmt , SimpleStmt ):
861+         for  tkn  in  stmt .contents :
862+             if  tkn .text  ==  "DECREF_INPUTS" :
863+                 return  True 
864+         d : dict [SimpleStmt , EscapingCall ] =  {}
865+         escaping_call_in_simple_stmt (stmt , d )
866+         return  bool (d )
867+     elif  isinstance (stmt , IfStmt ):
868+         if  stmt .else_body  and  stmt_escapes (stmt .else_body ):
869+             return  True 
870+         return  stmt_escapes (stmt .body )
871+     elif  isinstance (stmt , MacroIfStmt ):
872+         if  stmt .else_body  and  stmt_list_escapes (stmt .else_body ):
873+             return  True 
874+         return  stmt_list_escapes (stmt .body )
875+     elif  isinstance (stmt , ForStmt ):
876+         return  stmt_escapes (stmt .body )
877+     elif  isinstance (stmt , WhileStmt ):
878+         return  stmt_escapes (stmt .body )
879+     else :
880+         assert  False , "Unexpected statement type" 
881+ 
882+ 
825883def  compute_properties (op : parser .CodeDef ) ->  Properties :
826884    escaping_calls  =  find_escaping_api_calls (op )
827885    has_free  =  (
@@ -843,7 +901,7 @@ def compute_properties(op: parser.CodeDef) -> Properties:
843901        )
844902    error_with_pop  =  has_error_with_pop (op )
845903    error_without_pop  =  has_error_without_pop (op )
846-     escapes  =  bool ( escaping_calls )  or   variable_used ( op ,  "DECREF_INPUTS" )
904+     escapes  =  stmt_escapes ( op . block )
847905    pure  =  False  if  isinstance (op , parser .LabelDef ) else  "pure"  in  op .annotations 
848906    no_save_ip  =  False  if  isinstance (op , parser .LabelDef ) else  "no_save_ip"  in  op .annotations 
849907    return  Properties (
0 commit comments