Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Include/internal/pycore_opcode_metadata.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions Include/internal/pycore_uop_metadata.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

145 changes: 102 additions & 43 deletions Tools/cases_generator/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
from typing import Optional, Callable

from parser import Stmt, SimpleStmt, BlockStmt, IfStmt, WhileStmt
from parser import Stmt, SimpleStmt, BlockStmt, IfStmt, WhileStmt, ForStmt, MacroIfStmt

@dataclass
class EscapingCall:
Expand Down Expand Up @@ -723,53 +723,57 @@ def visit(stmt: Stmt) -> None:
if error is not None:
raise analysis_error(f"Escaping call '{error.text} in condition", error)

def escaping_call_in_simple_stmt(stmt: SimpleStmt, result: dict[SimpleStmt, EscapingCall]) -> None:
    """Record in *result* an escaping call found in the simple statement *stmt*.

    Every token that is immediately followed by a left parenthesis is
    examined.  Identifiers that look like simple macros or match one of the
    non-escaping name patterns are skipped, as are parenthesised casts.
    Anything left over is stored as ``result[stmt]``; when several
    candidates occur in the same statement the last one found is kept.

    For ``PyStackRef_CLOSE`` / ``PyStackRef_XCLOSE`` the token right after the
    opening parenthesis is recorded as the ``kills`` argument and must be an
    identifier, otherwise an analysis error is raised.
    """
    tokens = stmt.contents
    token_count = len(tokens)
    for idx, tkn in enumerate(tokens):
        if idx + 1 >= token_count:
            # The final token has no successor, so it cannot start a call.
            break
        next_tkn = tokens[idx + 1]
        if next_tkn.kind != lexer.LPAREN:
            continue

        if tkn.kind == lexer.IDENTIFIER:
            name = tkn.text
            # NOTE: a prefix filter on ("Py", "_Py", "monitor") is currently
            # disabled here.
            if (
                name.upper() == name  # simple macro
                or name.startswith(("sym_", "optimize_", "PyJitRef", "Py_Is"))
                or name.endswith(("Check", "CheckExact"))
                or name in NON_ESCAPING_FUNCTIONS
            ):
                continue
        elif tkn.kind == "RPAREN":
            # ")" directly followed by "(": skip it when it closes a cast.
            before = tokens[idx - 1]
            if before.text.endswith("_t") or before.text in ("*", "int"):
                continue
        elif tkn.kind != "RBRACKET":
            continue

        kills = None
        if tkn.text in ("PyStackRef_CLOSE", "PyStackRef_XCLOSE"):
            if token_count <= idx + 2:
                raise analysis_error("Unexpected end of file", next_tkn)
            kills = tokens[idx + 2]
            if kills.kind != "IDENTIFIER":
                raise analysis_error(f"Expected identifier, got '{kills.text}'", kills)
        result[stmt] = EscapingCall(stmt, tkn, kills)


def find_escaping_api_calls(instr: parser.CodeDef) -> dict[SimpleStmt, EscapingCall]:
result: dict[SimpleStmt, EscapingCall] = {}

def visit(stmt: Stmt) -> None:
if not isinstance(stmt, SimpleStmt):
return
tokens = stmt.contents
for idx, tkn in enumerate(tokens):
try:
next_tkn = tokens[idx+1]
except IndexError:
break
if next_tkn.kind != lexer.LPAREN:
continue
if tkn.kind == lexer.IDENTIFIER:
if tkn.text.upper() == tkn.text:
# simple macro
continue
#if not tkn.text.startswith(("Py", "_Py", "monitor")):
# continue
if tkn.text.startswith(("sym_", "optimize_", "PyJitRef")):
# Optimize functions
continue
if tkn.text.endswith("Check"):
continue
if tkn.text.startswith("Py_Is"):
continue
if tkn.text.endswith("CheckExact"):
continue
if tkn.text in NON_ESCAPING_FUNCTIONS:
continue
elif tkn.kind == "RPAREN":
prev = tokens[idx-1]
if prev.text.endswith("_t") or prev.text == "*" or prev.text == "int":
#cast
continue
elif tkn.kind != "RBRACKET":
continue
if tkn.text in ("PyStackRef_CLOSE", "PyStackRef_XCLOSE"):
if len(tokens) <= idx+2:
raise analysis_error("Unexpected end of file", next_tkn)
kills = tokens[idx+2]
if kills.kind != "IDENTIFIER":
raise analysis_error(f"Expected identifier, got '{kills.text}'", kills)
else:
kills = None
result[stmt] = EscapingCall(stmt, tkn, kills)
escaping_call_in_simple_stmt(stmt, result)

instr.block.accept(visit)
check_escaping_calls(instr, result)
Expand Down Expand Up @@ -822,6 +826,61 @@ def stack_effect_only_peeks(instr: parser.InstDef) -> bool:
)


def op_escapes(op: parser.CodeDef) -> bool:
    """Return True if the body of *op* contains an escaping call.

    A simple statement escapes when one of its tokens is ``DECREF_INPUTS``
    or when ``escaping_call_in_simple_stmt`` records a call for it.
    Block, if, macro-if, for and while statements escape when one of their
    bodies does.  A statement list whose last statement is an unconditional
    ``ERROR_IF`` / ``DEOPT_IF`` / ``EXIT_IF`` (condition ``true`` or ``1``)
    is treated as non-escaping as a whole.

    Raises:
        AssertionError: if a statement of an unrecognised type is found.
    """

    def is_simple_exit(stmt: Stmt) -> bool:
        # Matches statements that start with e.g. ``EXIT_IF ( true )``.
        if not isinstance(stmt, SimpleStmt):
            return False
        tokens = stmt.contents
        if len(tokens) < 4:
            return False
        return (
            tokens[0].text in ("ERROR_IF", "DEOPT_IF", "EXIT_IF")
            and tokens[1].text == "("
            and tokens[2].text in ("true", "1")
            and tokens[3].text == ")"
        )

    def escapes_list(stmts: list[Stmt]) -> bool:
        # A list that ends in an unconditional exit never counts as
        # escaping, whatever the statements before the exit do.
        if not stmts or is_simple_exit(stmts[-1]):
            return False
        return any(escapes(stmt) for stmt in stmts)

    def escapes(stmt: Stmt) -> bool:
        if isinstance(stmt, BlockStmt):
            return escapes_list(stmt.body)
        if isinstance(stmt, SimpleStmt):
            if any(tkn.text == "DECREF_INPUTS" for tkn in stmt.contents):
                return True
            found: dict[SimpleStmt, EscapingCall] = {}
            escaping_call_in_simple_stmt(stmt, found)
            return bool(found)
        if isinstance(stmt, IfStmt):
            if stmt.else_body and escapes(stmt.else_body):
                return True
            return escapes(stmt.body)
        if isinstance(stmt, MacroIfStmt):
            if stmt.else_body and escapes_list(stmt.else_body):
                return True
            return escapes_list(stmt.body)
        if isinstance(stmt, (ForStmt, WhileStmt)):
            return escapes(stmt.body)
        # Raise explicitly rather than ``assert False``: asserts are removed
        # under ``python -O``, which would let an unknown statement type fall
        # through and return None.
        raise AssertionError("Unexpected statement type")

    return escapes(op.block)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way these functions are nested is sort of weird. Is there a reason they're not just defined at the top level?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just grouping. I've flattened the hierarchy to be more consistent with the rest of the code.


def compute_properties(op: parser.CodeDef) -> Properties:
escaping_calls = find_escaping_api_calls(op)
has_free = (
Expand All @@ -843,7 +902,7 @@ def compute_properties(op: parser.CodeDef) -> Properties:
)
error_with_pop = has_error_with_pop(op)
error_without_pop = has_error_without_pop(op)
escapes = bool(escaping_calls) or variable_used(op, "DECREF_INPUTS")
escapes = op_escapes(op)
pure = False if isinstance(op, parser.LabelDef) else "pure" in op.annotations
no_save_ip = False if isinstance(op, parser.LabelDef) else "no_save_ip" in op.annotations
return Properties(
Expand Down
Loading