Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Include/internal/pycore_opcode_metadata.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions Include/internal/pycore_uop_metadata.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

144 changes: 101 additions & 43 deletions Tools/cases_generator/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
from typing import Optional, Callable

from parser import Stmt, SimpleStmt, BlockStmt, IfStmt, WhileStmt
from parser import Stmt, SimpleStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, MacroIfStmt

@dataclass
class EscapingCall:
Expand Down Expand Up @@ -723,53 +723,57 @@ def visit(stmt: Stmt) -> None:
if error is not None:
raise analysis_error(f"Escaping call '{error.text} in condition", error)

def escaping_call_in_simple_stmt(stmt: SimpleStmt, result: dict[SimpleStmt, EscapingCall]) -> None:
tokens = stmt.contents
for idx, tkn in enumerate(tokens):
try:
next_tkn = tokens[idx+1]
except IndexError:
break
if next_tkn.kind != lexer.LPAREN:
continue
if tkn.kind == lexer.IDENTIFIER:
if tkn.text.upper() == tkn.text:
# simple macro
continue
#if not tkn.text.startswith(("Py", "_Py", "monitor")):
# continue
if tkn.text.startswith(("sym_", "optimize_", "PyJitRef")):
# Optimize functions
continue
if tkn.text.endswith("Check"):
continue
if tkn.text.startswith("Py_Is"):
continue
if tkn.text.endswith("CheckExact"):
continue
if tkn.text in NON_ESCAPING_FUNCTIONS:
continue
elif tkn.kind == "RPAREN":
prev = tokens[idx-1]
if prev.text.endswith("_t") or prev.text == "*" or prev.text == "int":
#cast
continue
elif tkn.kind != "RBRACKET":
continue
if tkn.text in ("PyStackRef_CLOSE", "PyStackRef_XCLOSE"):
if len(tokens) <= idx+2:
raise analysis_error("Unexpected end of file", next_tkn)
kills = tokens[idx+2]
if kills.kind != "IDENTIFIER":
raise analysis_error(f"Expected identifier, got '{kills.text}'", kills)
else:
kills = None
result[stmt] = EscapingCall(stmt, tkn, kills)


def find_escaping_api_calls(instr: parser.CodeDef) -> dict[SimpleStmt, EscapingCall]:
result: dict[SimpleStmt, EscapingCall] = {}

def visit(stmt: Stmt) -> None:
if not isinstance(stmt, SimpleStmt):
return
tokens = stmt.contents
for idx, tkn in enumerate(tokens):
try:
next_tkn = tokens[idx+1]
except IndexError:
break
if next_tkn.kind != lexer.LPAREN:
continue
if tkn.kind == lexer.IDENTIFIER:
if tkn.text.upper() == tkn.text:
# simple macro
continue
#if not tkn.text.startswith(("Py", "_Py", "monitor")):
# continue
if tkn.text.startswith(("sym_", "optimize_", "PyJitRef")):
# Optimize functions
continue
if tkn.text.endswith("Check"):
continue
if tkn.text.startswith("Py_Is"):
continue
if tkn.text.endswith("CheckExact"):
continue
if tkn.text in NON_ESCAPING_FUNCTIONS:
continue
elif tkn.kind == "RPAREN":
prev = tokens[idx-1]
if prev.text.endswith("_t") or prev.text == "*" or prev.text == "int":
#cast
continue
elif tkn.kind != "RBRACKET":
continue
if tkn.text in ("PyStackRef_CLOSE", "PyStackRef_XCLOSE"):
if len(tokens) <= idx+2:
raise analysis_error("Unexpected end of file", next_tkn)
kills = tokens[idx+2]
if kills.kind != "IDENTIFIER":
raise analysis_error(f"Expected identifier, got '{kills.text}'", kills)
else:
kills = None
result[stmt] = EscapingCall(stmt, tkn, kills)
escaping_call_in_simple_stmt(stmt, result)

instr.block.accept(visit)
check_escaping_calls(instr, result)
Expand Down Expand Up @@ -822,6 +826,60 @@ def stack_effect_only_peeks(instr: parser.InstDef) -> bool:
)


def stmt_is_simple_exit(stmt: Stmt) -> bool:
if not isinstance(stmt, SimpleStmt):
return False
tokens = stmt.contents
if len(tokens) < 4:
return False
return (
tokens[0].text in ("ERROR_IF", "DEOPT_IF", "EXIT_IF")
and
tokens[1].text == "("
and
tokens[2].text in ("true", "1")
and
tokens[3].text == ")"
)


def stmt_list_escapes(stmts: list[Stmt]) -> bool:
if not stmts:
return False
if stmt_is_simple_exit(stmts[-1]):
return False
for stmt in stmts:
if stmt_escapes(stmt):
return True
return False


def stmt_escapes(stmt: Stmt) -> bool:
if isinstance(stmt, BlockStmt):
return stmt_list_escapes(stmt.body)
elif isinstance(stmt, SimpleStmt):
for tkn in stmt.contents:
if tkn.text == "DECREF_INPUTS":
return True
d: dict[SimpleStmt, EscapingCall] = {}
escaping_call_in_simple_stmt(stmt, d)
return bool(d)
elif isinstance(stmt, IfStmt):
if stmt.else_body and stmt_escapes(stmt.else_body):
return True
return stmt_escapes(stmt.body)
elif isinstance(stmt, MacroIfStmt):
if stmt.else_body and stmt_list_escapes(stmt.else_body):
return True
return stmt_list_escapes(stmt.body)
elif isinstance(stmt, ForStmt):
return stmt_escapes(stmt.body)
elif isinstance(stmt, WhileStmt):
return stmt_escapes(stmt.body)
else:
assert False, "Unexpected statement type"


def compute_properties(op: parser.CodeDef) -> Properties:
escaping_calls = find_escaping_api_calls(op)
has_free = (
Expand All @@ -843,7 +901,7 @@ def compute_properties(op: parser.CodeDef) -> Properties:
)
error_with_pop = has_error_with_pop(op)
error_without_pop = has_error_without_pop(op)
escapes = bool(escaping_calls) or variable_used(op, "DECREF_INPUTS")
escapes = stmt_escapes(op.block)
pure = False if isinstance(op, parser.LabelDef) else "pure" in op.annotations
no_save_ip = False if isinstance(op, parser.LabelDef) else "no_save_ip" in op.annotations
return Properties(
Expand Down
Loading