5
5
import re
6
6
from typing import Optional , Callable
7
7
8
- from parser import Stmt , SimpleStmt , BlockStmt , IfStmt , WhileStmt
8
+ from parser import Stmt , SimpleStmt , BlockStmt , IfStmt , WhileStmt , ForStmt , MacroIfStmt
9
9
10
10
@dataclass
11
11
class EscapingCall :
@@ -723,53 +723,57 @@ def visit(stmt: Stmt) -> None:
723
723
if error is not None :
724
724
raise analysis_error (f"Escaping call '{ error .text } in condition" , error )
725
725
726
+ def escaping_call_in_simple_stmt (stmt : SimpleStmt , result : dict [SimpleStmt , EscapingCall ]) -> None :
727
+ tokens = stmt .contents
728
+ for idx , tkn in enumerate (tokens ):
729
+ try :
730
+ next_tkn = tokens [idx + 1 ]
731
+ except IndexError :
732
+ break
733
+ if next_tkn .kind != lexer .LPAREN :
734
+ continue
735
+ if tkn .kind == lexer .IDENTIFIER :
736
+ if tkn .text .upper () == tkn .text :
737
+ # simple macro
738
+ continue
739
+ #if not tkn.text.startswith(("Py", "_Py", "monitor")):
740
+ # continue
741
+ if tkn .text .startswith (("sym_" , "optimize_" , "PyJitRef" )):
742
+ # Optimize functions
743
+ continue
744
+ if tkn .text .endswith ("Check" ):
745
+ continue
746
+ if tkn .text .startswith ("Py_Is" ):
747
+ continue
748
+ if tkn .text .endswith ("CheckExact" ):
749
+ continue
750
+ if tkn .text in NON_ESCAPING_FUNCTIONS :
751
+ continue
752
+ elif tkn .kind == "RPAREN" :
753
+ prev = tokens [idx - 1 ]
754
+ if prev .text .endswith ("_t" ) or prev .text == "*" or prev .text == "int" :
755
+ #cast
756
+ continue
757
+ elif tkn .kind != "RBRACKET" :
758
+ continue
759
+ if tkn .text in ("PyStackRef_CLOSE" , "PyStackRef_XCLOSE" ):
760
+ if len (tokens ) <= idx + 2 :
761
+ raise analysis_error ("Unexpected end of file" , next_tkn )
762
+ kills = tokens [idx + 2 ]
763
+ if kills .kind != "IDENTIFIER" :
764
+ raise analysis_error (f"Expected identifier, got '{ kills .text } '" , kills )
765
+ else :
766
+ kills = None
767
+ result [stmt ] = EscapingCall (stmt , tkn , kills )
768
+
769
+
726
770
def find_escaping_api_calls (instr : parser .CodeDef ) -> dict [SimpleStmt , EscapingCall ]:
727
771
result : dict [SimpleStmt , EscapingCall ] = {}
728
772
729
773
def visit (stmt : Stmt ) -> None :
730
774
if not isinstance (stmt , SimpleStmt ):
731
775
return
732
- tokens = stmt .contents
733
- for idx , tkn in enumerate (tokens ):
734
- try :
735
- next_tkn = tokens [idx + 1 ]
736
- except IndexError :
737
- break
738
- if next_tkn .kind != lexer .LPAREN :
739
- continue
740
- if tkn .kind == lexer .IDENTIFIER :
741
- if tkn .text .upper () == tkn .text :
742
- # simple macro
743
- continue
744
- #if not tkn.text.startswith(("Py", "_Py", "monitor")):
745
- # continue
746
- if tkn .text .startswith (("sym_" , "optimize_" , "PyJitRef" )):
747
- # Optimize functions
748
- continue
749
- if tkn .text .endswith ("Check" ):
750
- continue
751
- if tkn .text .startswith ("Py_Is" ):
752
- continue
753
- if tkn .text .endswith ("CheckExact" ):
754
- continue
755
- if tkn .text in NON_ESCAPING_FUNCTIONS :
756
- continue
757
- elif tkn .kind == "RPAREN" :
758
- prev = tokens [idx - 1 ]
759
- if prev .text .endswith ("_t" ) or prev .text == "*" or prev .text == "int" :
760
- #cast
761
- continue
762
- elif tkn .kind != "RBRACKET" :
763
- continue
764
- if tkn .text in ("PyStackRef_CLOSE" , "PyStackRef_XCLOSE" ):
765
- if len (tokens ) <= idx + 2 :
766
- raise analysis_error ("Unexpected end of file" , next_tkn )
767
- kills = tokens [idx + 2 ]
768
- if kills .kind != "IDENTIFIER" :
769
- raise analysis_error (f"Expected identifier, got '{ kills .text } '" , kills )
770
- else :
771
- kills = None
772
- result [stmt ] = EscapingCall (stmt , tkn , kills )
776
+ escaping_call_in_simple_stmt (stmt , result )
773
777
774
778
instr .block .accept (visit )
775
779
check_escaping_calls (instr , result )
@@ -822,6 +826,60 @@ def stack_effect_only_peeks(instr: parser.InstDef) -> bool:
822
826
)
823
827
824
828
829
+ def stmt_is_simple_exit (stmt : Stmt ) -> bool :
830
+ if not isinstance (stmt , SimpleStmt ):
831
+ return False
832
+ tokens = stmt .contents
833
+ if len (tokens ) < 4 :
834
+ return False
835
+ return (
836
+ tokens [0 ].text in ("ERROR_IF" , "DEOPT_IF" , "EXIT_IF" )
837
+ and
838
+ tokens [1 ].text == "("
839
+ and
840
+ tokens [2 ].text in ("true" , "1" )
841
+ and
842
+ tokens [3 ].text == ")"
843
+ )
844
+
845
+
846
+ def stmt_list_escapes (stmts : list [Stmt ]) -> bool :
847
+ if not stmts :
848
+ return False
849
+ if stmt_is_simple_exit (stmts [- 1 ]):
850
+ return False
851
+ for stmt in stmts :
852
+ if stmt_escapes (stmt ):
853
+ return True
854
+ return False
855
+
856
+
857
+ def stmt_escapes (stmt : Stmt ) -> bool :
858
+ if isinstance (stmt , BlockStmt ):
859
+ return stmt_list_escapes (stmt .body )
860
+ elif isinstance (stmt , SimpleStmt ):
861
+ for tkn in stmt .contents :
862
+ if tkn .text == "DECREF_INPUTS" :
863
+ return True
864
+ d : dict [SimpleStmt , EscapingCall ] = {}
865
+ escaping_call_in_simple_stmt (stmt , d )
866
+ return bool (d )
867
+ elif isinstance (stmt , IfStmt ):
868
+ if stmt .else_body and stmt_escapes (stmt .else_body ):
869
+ return True
870
+ return stmt_escapes (stmt .body )
871
+ elif isinstance (stmt , MacroIfStmt ):
872
+ if stmt .else_body and stmt_list_escapes (stmt .else_body ):
873
+ return True
874
+ return stmt_list_escapes (stmt .body )
875
+ elif isinstance (stmt , ForStmt ):
876
+ return stmt_escapes (stmt .body )
877
+ elif isinstance (stmt , WhileStmt ):
878
+ return stmt_escapes (stmt .body )
879
+ else :
880
+ assert False , "Unexpected statement type"
881
+
882
+
825
883
def compute_properties (op : parser .CodeDef ) -> Properties :
826
884
escaping_calls = find_escaping_api_calls (op )
827
885
has_free = (
@@ -843,7 +901,7 @@ def compute_properties(op: parser.CodeDef) -> Properties:
843
901
)
844
902
error_with_pop = has_error_with_pop (op )
845
903
error_without_pop = has_error_without_pop (op )
846
- escapes = bool ( escaping_calls ) or variable_used ( op , "DECREF_INPUTS" )
904
+ escapes = stmt_escapes ( op . block )
847
905
pure = False if isinstance (op , parser .LabelDef ) else "pure" in op .annotations
848
906
no_save_ip = False if isinstance (op , parser .LabelDef ) else "no_save_ip" in op .annotations
849
907
return Properties (
0 commit comments