Skip to content

Commit 801cf3f

Browse files
authored
GH-137276: Don't mark uop as escaping if the escaping call is on an exit branch (GH-137277)
1 parent 7475887 commit 801cf3f

File tree

3 files changed

+108
-50
lines changed

3 files changed

+108
-50
lines changed

Include/internal/pycore_opcode_metadata.h

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Include/internal/pycore_uop_metadata.h

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Tools/cases_generator/analyzer.py

Lines changed: 101 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import re
66
from typing import Optional, Callable
77

8-
from parser import Stmt, SimpleStmt, BlockStmt, IfStmt, WhileStmt
8+
from parser import Stmt, SimpleStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, MacroIfStmt
99

1010
@dataclass
1111
class EscapingCall:
@@ -723,53 +723,57 @@ def visit(stmt: Stmt) -> None:
723723
if error is not None:
724724
raise analysis_error(f"Escaping call '{error.text} in condition", error)
725725

726+
def escaping_call_in_simple_stmt(stmt: SimpleStmt, result: dict[SimpleStmt, EscapingCall]) -> None:
727+
tokens = stmt.contents
728+
for idx, tkn in enumerate(tokens):
729+
try:
730+
next_tkn = tokens[idx+1]
731+
except IndexError:
732+
break
733+
if next_tkn.kind != lexer.LPAREN:
734+
continue
735+
if tkn.kind == lexer.IDENTIFIER:
736+
if tkn.text.upper() == tkn.text:
737+
# simple macro
738+
continue
739+
#if not tkn.text.startswith(("Py", "_Py", "monitor")):
740+
# continue
741+
if tkn.text.startswith(("sym_", "optimize_", "PyJitRef")):
742+
# Optimize functions
743+
continue
744+
if tkn.text.endswith("Check"):
745+
continue
746+
if tkn.text.startswith("Py_Is"):
747+
continue
748+
if tkn.text.endswith("CheckExact"):
749+
continue
750+
if tkn.text in NON_ESCAPING_FUNCTIONS:
751+
continue
752+
elif tkn.kind == "RPAREN":
753+
prev = tokens[idx-1]
754+
if prev.text.endswith("_t") or prev.text == "*" or prev.text == "int":
755+
#cast
756+
continue
757+
elif tkn.kind != "RBRACKET":
758+
continue
759+
if tkn.text in ("PyStackRef_CLOSE", "PyStackRef_XCLOSE"):
760+
if len(tokens) <= idx+2:
761+
raise analysis_error("Unexpected end of file", next_tkn)
762+
kills = tokens[idx+2]
763+
if kills.kind != "IDENTIFIER":
764+
raise analysis_error(f"Expected identifier, got '{kills.text}'", kills)
765+
else:
766+
kills = None
767+
result[stmt] = EscapingCall(stmt, tkn, kills)
768+
769+
726770
def find_escaping_api_calls(instr: parser.CodeDef) -> dict[SimpleStmt, EscapingCall]:
727771
result: dict[SimpleStmt, EscapingCall] = {}
728772

729773
def visit(stmt: Stmt) -> None:
730774
if not isinstance(stmt, SimpleStmt):
731775
return
732-
tokens = stmt.contents
733-
for idx, tkn in enumerate(tokens):
734-
try:
735-
next_tkn = tokens[idx+1]
736-
except IndexError:
737-
break
738-
if next_tkn.kind != lexer.LPAREN:
739-
continue
740-
if tkn.kind == lexer.IDENTIFIER:
741-
if tkn.text.upper() == tkn.text:
742-
# simple macro
743-
continue
744-
#if not tkn.text.startswith(("Py", "_Py", "monitor")):
745-
# continue
746-
if tkn.text.startswith(("sym_", "optimize_", "PyJitRef")):
747-
# Optimize functions
748-
continue
749-
if tkn.text.endswith("Check"):
750-
continue
751-
if tkn.text.startswith("Py_Is"):
752-
continue
753-
if tkn.text.endswith("CheckExact"):
754-
continue
755-
if tkn.text in NON_ESCAPING_FUNCTIONS:
756-
continue
757-
elif tkn.kind == "RPAREN":
758-
prev = tokens[idx-1]
759-
if prev.text.endswith("_t") or prev.text == "*" or prev.text == "int":
760-
#cast
761-
continue
762-
elif tkn.kind != "RBRACKET":
763-
continue
764-
if tkn.text in ("PyStackRef_CLOSE", "PyStackRef_XCLOSE"):
765-
if len(tokens) <= idx+2:
766-
raise analysis_error("Unexpected end of file", next_tkn)
767-
kills = tokens[idx+2]
768-
if kills.kind != "IDENTIFIER":
769-
raise analysis_error(f"Expected identifier, got '{kills.text}'", kills)
770-
else:
771-
kills = None
772-
result[stmt] = EscapingCall(stmt, tkn, kills)
776+
escaping_call_in_simple_stmt(stmt, result)
773777

774778
instr.block.accept(visit)
775779
check_escaping_calls(instr, result)
@@ -822,6 +826,60 @@ def stack_effect_only_peeks(instr: parser.InstDef) -> bool:
822826
)
823827

824828

829+
def stmt_is_simple_exit(stmt: Stmt) -> bool:
830+
if not isinstance(stmt, SimpleStmt):
831+
return False
832+
tokens = stmt.contents
833+
if len(tokens) < 4:
834+
return False
835+
return (
836+
tokens[0].text in ("ERROR_IF", "DEOPT_IF", "EXIT_IF")
837+
and
838+
tokens[1].text == "("
839+
and
840+
tokens[2].text in ("true", "1")
841+
and
842+
tokens[3].text == ")"
843+
)
844+
845+
846+
def stmt_list_escapes(stmts: list[Stmt]) -> bool:
847+
if not stmts:
848+
return False
849+
if stmt_is_simple_exit(stmts[-1]):
850+
return False
851+
for stmt in stmts:
852+
if stmt_escapes(stmt):
853+
return True
854+
return False
855+
856+
857+
def stmt_escapes(stmt: Stmt) -> bool:
858+
if isinstance(stmt, BlockStmt):
859+
return stmt_list_escapes(stmt.body)
860+
elif isinstance(stmt, SimpleStmt):
861+
for tkn in stmt.contents:
862+
if tkn.text == "DECREF_INPUTS":
863+
return True
864+
d: dict[SimpleStmt, EscapingCall] = {}
865+
escaping_call_in_simple_stmt(stmt, d)
866+
return bool(d)
867+
elif isinstance(stmt, IfStmt):
868+
if stmt.else_body and stmt_escapes(stmt.else_body):
869+
return True
870+
return stmt_escapes(stmt.body)
871+
elif isinstance(stmt, MacroIfStmt):
872+
if stmt.else_body and stmt_list_escapes(stmt.else_body):
873+
return True
874+
return stmt_list_escapes(stmt.body)
875+
elif isinstance(stmt, ForStmt):
876+
return stmt_escapes(stmt.body)
877+
elif isinstance(stmt, WhileStmt):
878+
return stmt_escapes(stmt.body)
879+
else:
880+
assert False, "Unexpected statement type"
881+
882+
825883
def compute_properties(op: parser.CodeDef) -> Properties:
826884
escaping_calls = find_escaping_api_calls(op)
827885
has_free = (
@@ -843,7 +901,7 @@ def compute_properties(op: parser.CodeDef) -> Properties:
843901
)
844902
error_with_pop = has_error_with_pop(op)
845903
error_without_pop = has_error_without_pop(op)
846-
escapes = bool(escaping_calls) or variable_used(op, "DECREF_INPUTS")
904+
escapes = stmt_escapes(op.block)
847905
pure = False if isinstance(op, parser.LabelDef) else "pure" in op.annotations
848906
no_save_ip = False if isinstance(op, parser.LabelDef) else "no_save_ip" in op.annotations
849907
return Properties(

0 commit comments

Comments
 (0)