Skip to content

Commit 99b565a

Browse files
committed
Perform stack spilling in labels. WIP
1 parent a29a9c0 commit 99b565a

File tree

11 files changed

+145
-74
lines changed

11 files changed

+145
-74
lines changed

Python/bytecodes.c

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5213,6 +5213,7 @@ dummy_func(
52135213
}
52145214

52155215
label(error) {
5216+
SAVE_STACK();
52165217
/* Double-check exception status. */
52175218
#ifdef NDEBUG
52185219
if (!_PyErr_Occurred(tstate)) {
@@ -5235,11 +5236,12 @@ dummy_func(
52355236
goto exception_unwind;
52365237
}
52375238

5238-
label(exception_unwind) {
5239+
spilled label(exception_unwind) {
52395240
/* We can't use frame->instr_ptr here, as RERAISE may have set it */
52405241
int offset = INSTR_OFFSET()-1;
52415242
int level, handler, lasti;
5242-
if (get_exception_handler(_PyFrame_GetCode(frame), offset, &level, &handler, &lasti) == 0) {
5243+
int handled = get_exception_handler(_PyFrame_GetCode(frame), offset, &level, &handler, &lasti);
5244+
if (handled == 0) {
52435245
// No handlers, so exit.
52445246
assert(_PyErr_Occurred(tstate));
52455247

@@ -5276,7 +5278,8 @@ dummy_func(
52765278
PUSH(PyStackRef_FromPyObjectSteal(exc));
52775279
next_instr = _PyFrame_GetBytecode(frame) + handler;
52785280

5279-
if (monitor_handled(tstate, frame, next_instr, exc) < 0) {
5281+
int err = monitor_handled(tstate, frame, next_instr, exc);
5282+
if (err < 0) {
52805283
goto exception_unwind;
52815284
}
52825285
/* Resume normal execution */
@@ -5285,10 +5288,11 @@ dummy_func(
52855288
lltrace_resume_frame(frame);
52865289
}
52875290
#endif
5291+
RELOAD_STACK();
52885292
DISPATCH();
52895293
}
52905294

5291-
label(exit_unwind) {
5295+
spilled label(exit_unwind) {
52925296
assert(_PyErr_Occurred(tstate));
52935297
_Py_LeaveRecursiveCallPy(tstate);
52945298
assert(frame->owner != FRAME_OWNED_BY_INTERPRETER);
@@ -5304,16 +5308,16 @@ dummy_func(
53045308
return NULL;
53055309
}
53065310
next_instr = frame->instr_ptr;
5307-
stack_pointer = _PyFrame_GetStackPointer(frame);
5311+
RELOAD_STACK();
53085312
goto error;
53095313
}
53105314

5311-
label(start_frame) {
5312-
if (_Py_EnterRecursivePy(tstate)) {
5315+
spilled label(start_frame) {
5316+
int too_deep = _Py_EnterRecursivePy(tstate);
5317+
if (too_deep) {
53135318
goto exit_unwind;
53145319
}
53155320
next_instr = frame->instr_ptr;
5316-
stack_pointer = _PyFrame_GetStackPointer(frame);
53175321

53185322
#ifdef LLTRACE
53195323
{
@@ -5332,6 +5336,7 @@ dummy_func(
53325336
assert(!_PyErr_Occurred(tstate));
53335337
#endif
53345338

5339+
RELOAD_STACK();
53355340
DISPATCH();
53365341
}
53375342

Python/generated_cases.c.h

Lines changed: 16 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Tools/cases_generator/analyzer.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ def size(self) -> int:
120120
return 0
121121

122122

123+
124+
123125
@dataclass
124126
class StackItem:
125127
name: str
@@ -218,7 +220,24 @@ def is_super(self) -> bool:
218220
return False
219221

220222

223+
class Label:
224+
225+
def __init__(self, name: str, spilled: bool, body: list[lexer.Token],properties: Properties):
226+
self.name = name
227+
self.spilled = spilled
228+
self.body = body
229+
self.properties = properties
230+
231+
size:int = 0
232+
output_stores: list[lexer.Token] = []
233+
instruction_size = None
234+
235+
def __str__(self) -> str:
236+
return f"label({self.name})"
237+
238+
221239
Part = Uop | Skip | Flush
240+
CodeSection = Uop | Label
222241

223242

224243
@dataclass
@@ -258,12 +277,6 @@ def is_super(self) -> bool:
258277
return False
259278

260279

261-
@dataclass
262-
class Label:
263-
name: str
264-
body: list[lexer.Token]
265-
266-
267280
@dataclass
268281
class PseudoInstruction:
269282
name: str
@@ -481,22 +494,24 @@ def in_frame_push(idx: int) -> bool:
481494
return refs
482495

483496

484-
def variable_used(node: parser.InstDef, name: str) -> bool:
497+
def variable_used(node: parser.CodeDef, name: str) -> bool:
485498
"""Determine whether a variable with a given name is used in a node."""
486499
return any(
487500
token.kind == "IDENTIFIER" and token.text == name for token in node.block.tokens
488501
)
489502

490503

491-
def oparg_used(node: parser.InstDef) -> bool:
504+
def oparg_used(node: parser.CodeDef) -> bool:
492505
"""Determine whether `oparg` is used in a node."""
493506
return any(
494507
token.kind == "IDENTIFIER" and token.text == "oparg" for token in node.tokens
495508
)
496509

497510

498-
def tier_variable(node: parser.InstDef) -> int | None:
511+
def tier_variable(node: parser.CodeDef) -> int | None:
499512
"""Determine whether a tier variable is used in a node."""
513+
if isinstance(node, parser.LabelDef):
514+
return None
500515
for token in node.tokens:
501516
if token.kind == "ANNOTATION":
502517
if token.text == "specializing":
@@ -506,15 +521,15 @@ def tier_variable(node: parser.InstDef) -> int | None:
506521
return None
507522

508523

509-
def has_error_with_pop(op: parser.InstDef) -> bool:
524+
def has_error_with_pop(op: parser.CodeDef) -> bool:
510525
return (
511526
variable_used(op, "ERROR_IF")
512527
or variable_used(op, "pop_1_error")
513528
or variable_used(op, "exception_unwind")
514529
)
515530

516531

517-
def has_error_without_pop(op: parser.InstDef) -> bool:
532+
def has_error_without_pop(op: parser.CodeDef) -> bool:
518533
return (
519534
variable_used(op, "ERROR_NO_POP")
520535
or variable_used(op, "pop_1_error")
@@ -644,7 +659,7 @@ def has_error_without_pop(op: parser.InstDef) -> bool:
644659
"restart_backoff_counter",
645660
)
646661

647-
def find_stmt_start(node: parser.InstDef, idx: int) -> lexer.Token:
662+
def find_stmt_start(node: parser.CodeDef, idx: int) -> lexer.Token:
648663
assert idx < len(node.block.tokens)
649664
while True:
650665
tkn = node.block.tokens[idx-1]
@@ -657,15 +672,15 @@ def find_stmt_start(node: parser.InstDef, idx: int) -> lexer.Token:
657672
return node.block.tokens[idx]
658673

659674

660-
def find_stmt_end(node: parser.InstDef, idx: int) -> lexer.Token:
675+
def find_stmt_end(node: parser.CodeDef, idx: int) -> lexer.Token:
661676
assert idx < len(node.block.tokens)
662677
while True:
663678
idx += 1
664679
tkn = node.block.tokens[idx]
665680
if tkn.kind == "SEMI":
666681
return node.block.tokens[idx+1]
667682

668-
def check_escaping_calls(instr: parser.InstDef, escapes: dict[lexer.Token, tuple[lexer.Token, lexer.Token]]) -> None:
683+
def check_escaping_calls(instr: parser.CodeDef, escapes: dict[lexer.Token, tuple[lexer.Token, lexer.Token]]) -> None:
669684
calls = {escapes[t][0] for t in escapes}
670685
in_if = 0
671686
tkn_iter = iter(instr.block.tokens)
@@ -684,7 +699,7 @@ def check_escaping_calls(instr: parser.InstDef, escapes: dict[lexer.Token, tuple
684699
elif tkn in calls and in_if:
685700
raise analysis_error(f"Escaping call '{tkn.text} in condition", tkn)
686701

687-
def find_escaping_api_calls(instr: parser.InstDef) -> dict[lexer.Token, tuple[lexer.Token, lexer.Token]]:
702+
def find_escaping_api_calls(instr: parser.CodeDef) -> dict[lexer.Token, tuple[lexer.Token, lexer.Token]]:
688703
result: dict[lexer.Token, tuple[lexer.Token, lexer.Token]] = {}
689704
tokens = instr.block.tokens
690705
for idx, tkn in enumerate(tokens):
@@ -736,7 +751,7 @@ def find_escaping_api_calls(instr: parser.InstDef) -> dict[lexer.Token, tuple[le
736751
}
737752

738753

739-
def always_exits(op: parser.InstDef) -> bool:
754+
def always_exits(op: parser.CodeDef) -> bool:
740755
depth = 0
741756
tkn_iter = iter(op.tokens)
742757
for tkn in tkn_iter:
@@ -795,7 +810,7 @@ def effect_depends_on_oparg_1(op: parser.InstDef) -> bool:
795810
return False
796811

797812

798-
def compute_properties(op: parser.InstDef) -> Properties:
813+
def compute_properties(op: parser.CodeDef) -> Properties:
799814
escaping_calls = find_escaping_api_calls(op)
800815
has_free = (
801816
variable_used(op, "PyCell_New")
@@ -826,6 +841,8 @@ def compute_properties(op: parser.InstDef) -> Properties:
826841
variable_used(op, "PyStackRef_CLEAR") or
827842
variable_used(op, "SETLOCAL")
828843
)
844+
pure = False if isinstance(op, parser.LabelDef) else "pure" in op.annotations
845+
no_save_ip = False if isinstance(op, parser.LabelDef) else "no_save_ip" in op.annotations
829846
return Properties(
830847
escaping_calls=escaping_calls,
831848
escapes=escapes,
@@ -844,8 +861,8 @@ def compute_properties(op: parser.InstDef) -> Properties:
844861
uses_locals=(variable_used(op, "GETLOCAL") or variable_used(op, "SETLOCAL"))
845862
and not has_free,
846863
has_free=has_free,
847-
pure="pure" in op.annotations,
848-
no_save_ip="no_save_ip" in op.annotations,
864+
pure=pure,
865+
no_save_ip=no_save_ip,
849866
tier=tier_variable(op),
850867
needs_prev=variable_used(op, "prev_instr"),
851868
)
@@ -1024,7 +1041,8 @@ def add_label(
10241041
label: parser.LabelDef,
10251042
labels: dict[str, Label],
10261043
) -> None:
1027-
labels[label.name] = Label(label.name, label.block.tokens)
1044+
properties = compute_properties(label)
1045+
labels[label.name] = Label(label.name, label.spilled, label.block.tokens, properties)
10281046

10291047

10301048
def assign_opcodes(

0 commit comments

Comments
 (0)