Skip to content

Commit 3a7095a

Browse files
h-joocopybara-github
authored andcommitted
Change the iteration order for functions containing async for and yield from.
The iteration order is determined by pytype constructing control flow information by looking at the bytecode instructions. It determines which instructions to elide, and how to connect BB with each other to decide how it should run within pytype's VM. The thing is, the implementation before was a bit incomplete in terms of detecting all control flow including exceptions. It was making some assumptions on what instructions or group of instructions comes after another, which did not hold anymore for python 3.12. In python 3.12 the instruction order around async construct has changed, also some new instructions were added (END_SEND) and how the instructions jump to one another too has changed. Due to this reason, pytype starts to break in 3.12 because of the iteration order being different compared to the real runtime, and it fails due to the wrong order of execution, and the result is that it fails due to insufficient stack elements when it's expecting some elements to be present at a moment. We can try to fix it to make pytype comprehend the full control graph, but I think that's going to take a bit longer to implement. Rather than doing that, with this change we group the basic blocks which are coming from async constructs into a single basic block, to prevent from getting split by the regular BB analyzer so that it runs sequentially without accidentally following the wrong control flow which never happens in the python runtime. PiperOrigin-RevId: 748587745
1 parent be15999 commit 3a7095a

File tree

4 files changed

+254
-34
lines changed

4 files changed

+254
-34
lines changed

pytype/blocks/blocks.py

Lines changed: 162 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Functions for computing the execution order of bytecode."""
22

33
from collections.abc import Iterator
4-
from typing import Any, cast
4+
from typing import Any, Sequence, cast
55
from pycnite import bytecode as pyc_bytecode
66
from pycnite import marshal as pyc_marshal
77
import pycnite.types
@@ -316,7 +316,9 @@ def add_pop_block_targets(bytecode: list[opcodes.Opcode]) -> None:
316316
todo.append((op.next, block_stack))
317317

318318

319-
def _split_bytecode(bytecode: list[opcodes.Opcode]) -> list[Block]:
319+
def _split_bytecode(
320+
bytecode: list[opcodes.Opcode], processed_blocks: set[Block]
321+
) -> list[Block]:
320322
"""Given a sequence of bytecodes, return basic blocks.
321323
322324
This will split the code at "basic block boundaries". These occur at
@@ -333,20 +335,169 @@ def _split_bytecode(bytecode: list[opcodes.Opcode]) -> list[Block]:
333335
targets = {op.target for op in bytecode if op.target}
334336
blocks = []
335337
code = []
336-
for op in bytecode:
338+
prev_block: Block = None
339+
i = 0
340+
while i < len(bytecode):
341+
op = bytecode[i]
342+
# SEND is only used in the context of async for and `yield from`.
343+
# These instructions are not used in other context, so it's safe to process
344+
# it assuming that these are the only constructs they're being used.
345+
if isinstance(op, opcodes.SEND):
346+
if code:
347+
prev_block = Block(code)
348+
blocks.append(prev_block)
349+
code = []
350+
new_blocks, i = _preprocess_async_for_and_yield(
351+
i, bytecode, prev_block, processed_blocks
352+
)
353+
blocks.extend(new_blocks)
354+
prev_block = blocks[-1]
355+
continue
356+
337357
code.append(op)
338358
if (
339359
op.no_next()
340360
or op.does_jump()
341361
or op.pops_block()
342362
or op.next is None
343-
or op.next in targets
363+
or (op.next in targets)
364+
and not isinstance(op.next, opcodes.GET_ANEXT)
344365
):
345-
blocks.append(Block(code))
366+
prev_block = Block(code)
367+
blocks.append(prev_block)
346368
code = []
369+
i += 1
370+
347371
return blocks
348372

349373

374+
def _preprocess_async_for_and_yield(
375+
idx: int,
376+
bytecode: Sequence[opcodes.Opcode],
377+
prev_block: Block,
378+
processed_blocks: set[Block],
379+
) -> tuple[list[Block], int]:
380+
"""Process bytecode instructions for yield and async for in a way that pytype can iterate correctly.
381+
382+
'Async for' and yield statements, contains instructions that starts with SEND
383+
and ends with END_SEND.
384+
385+
The reason why we need to pre process async for is because the control flow of
386+
async for is drastically different from regular control flows also due to the
387+
fact that the termination of the loop happens by STOP_ASYNC_ITERATION
388+
exception, not a regular control flow. So we need to split (or merge) the
389+
basic blocks in a way that pytype executes in the order that what'd happen in
390+
the runtime, so that it doesn't fail with wrong order of execution, which can
391+
result in a stack underrun.
392+
393+
Args:
394+
idx: The index of the SEND instruction.
395+
bytecode: A list of instances of opcodes.Opcode
396+
prev_block: The previous block that we want to connect the new blocks to.
397+
processed_blocks: Blocks that has been processed so that it doesn't get
398+
processed again by compute_order.
399+
400+
Returns:
401+
A tuple of (list[Block], int), where the Block is the block containing the
402+
iteration part of the async for construct, and the int is the index of the
403+
END_SEND instruction.
404+
"""
405+
assert isinstance(bytecode[idx], opcodes.SEND)
406+
i = next(
407+
i
408+
for i in range(idx + 1, len(bytecode))
409+
if isinstance(bytecode[i], opcodes.JUMP_BACKWARD_NO_INTERRUPT)
410+
)
411+
412+
end_block_idx = i + 1
413+
# In CLEANUP_THROW can be present after JUMP_BACKWARD_NO_INTERRUPT
414+
# depending on how the control flow graph is constructed.
415+
# Usually, CLEANUP_THROW comes way after
416+
if isinstance(bytecode[end_block_idx], opcodes.CLEANUP_THROW):
417+
end_block_idx += 1
418+
419+
# Somehow pytype expects the SEND and YIELD_VALUE to be in different
420+
# blocks, so we need to split.
421+
send_block = Block(bytecode[idx : idx + 1])
422+
yield_value_block = Block(bytecode[idx + 1 : end_block_idx])
423+
prev_block.connect_outgoing(send_block)
424+
send_block.connect_outgoing(yield_value_block)
425+
processed_blocks.update(send_block, yield_value_block)
426+
return [send_block, yield_value_block], end_block_idx
427+
428+
429+
def _remove_jmp_to_get_anext_and_merge(
430+
blocks: list[Block], processed_blocks: set[Block]
431+
) -> list[Block]:
432+
"""Remove JUMP_BACKWARD instructions to GET_ANEXT instructions.
433+
434+
And also merge the block that contains the END_ASYNC_FOR which is part of the
435+
same loop of the GET_ANEXT and JUMP_BACKWARD construct, to the JUMP_BACKWARD
436+
instruction. This is to ignore the JUMP_BACKWARD because in pytype's eyes it's
437+
useless (as it'll jump back to block that it already executed), and also
438+
this is the way to make pytype run the code of END_ASYNC_FOR and whatever
439+
comes afterwards.
440+
441+
Args:
442+
blocks: A list of Block instances.
443+
444+
Returns:
445+
A list of Block instances after the removal and merge.
446+
"""
447+
op_to_block = {}
448+
merge_list = []
449+
for block_idx, block in enumerate(blocks):
450+
for code in block.code:
451+
op_to_block[code] = block_idx
452+
453+
for block_idx, block in enumerate(blocks):
454+
for code in block.code:
455+
if code.end_async_for_target:
456+
merge_list.append((block_idx, op_to_block[code.end_async_for_target]))
457+
map_target = {}
458+
for block_idx, block_idx_to_merge in merge_list:
459+
# Remove JUMP_BACKWARD instruction as we don't want to execute it.
460+
jump_back_op = blocks[block_idx].code.pop()
461+
blocks[block_idx].code.extend(blocks[block_idx_to_merge].code)
462+
map_target[jump_back_op] = blocks[block_idx_to_merge].code[0]
463+
464+
if block_idx_to_merge < len(blocks) - 1:
465+
blocks[block_idx].connect_outgoing(blocks[block_idx_to_merge + 1])
466+
processed_blocks.add(blocks[block_idx])
467+
468+
to_delete = sorted({to_idx for _, to_idx in merge_list}, reverse=True)
469+
470+
for block_idx in to_delete:
471+
del blocks[block_idx]
472+
473+
for block in blocks:
474+
replace_op = map_target.get(block.code[-1].target, None)
475+
if replace_op:
476+
block.code[-1].target = replace_op
477+
478+
return blocks
479+
480+
481+
def _remove_jump_back_block(blocks: list[Block]):
482+
"""Remove JUMP_BACKWARD instructions which are exception handling for async for.
483+
484+
These are not needed as a regular pytype control flow.
485+
"""
486+
new_blocks = []
487+
for block in blocks:
488+
last_op = block.code[-1]
489+
if (
490+
isinstance(last_op, opcodes.JUMP_BACKWARD)
491+
and isinstance(last_op.target, opcodes.END_SEND)
492+
and len(block.code) >= 2
493+
and isinstance(block.code[-2], opcodes.CLEANUP_THROW)
494+
):
495+
continue
496+
new_blocks.append(block)
497+
498+
return new_blocks
499+
500+
350501
def compute_order(bytecode: list[opcodes.Opcode]) -> list[Block]:
351502
"""Split bytecode into blocks and order the blocks.
352503
@@ -359,10 +510,15 @@ def compute_order(bytecode: list[opcodes.Opcode]) -> list[Block]:
359510
Returns:
360511
A list of Block instances.
361512
"""
362-
blocks = _split_bytecode(bytecode)
513+
processed_blocks = set()
514+
blocks = _split_bytecode(bytecode, processed_blocks)
515+
blocks = _remove_jump_back_block(blocks)
516+
blocks = _remove_jmp_to_get_anext_and_merge(blocks, processed_blocks)
363517
first_op_to_block = {block.code[0]: block for block in blocks}
364518
for i, block in enumerate(blocks):
365519
next_block = blocks[i + 1] if i < len(blocks) - 1 else None
520+
if block in processed_blocks:
521+
continue
366522
first_op, last_op = block.code[0], block.code[-1]
367523
if next_block and not last_op.no_next():
368524
block.connect_outgoing(next_block)

pytype/pyc/opcodes.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class Opcode:
4848
"prev",
4949
"next",
5050
"target",
51+
"end_async_for_target",
5152
"block_target",
5253
"code",
5354
"annotation",
@@ -67,6 +68,9 @@ def __init__(self, index, line, endline=None, col=None, endcol=None):
6768
self.prev = None
6869
self.next = None
6970
self.target = None
71+
# The END_ASYNC_FOR instruction of which we want to make pytype jump to for
72+
# this instruction.
73+
self.end_async_for_target = None
7074
self.block_target = None
7175
self.code = None # If we have a CodeType or OrderedCode parent
7276
self.annotation = None
@@ -1306,30 +1310,6 @@ def _should_elide_opcode(
13061310
and isinstance(op_items[i + 1][1], END_ASYNC_FOR)
13071311
)
13081312

1309-
# In 3.12 all generators are compiled into infinite loops, too. In addition,
1310-
# YIELD_VALUE inserts exception handling instructions:
1311-
# CLEANUP_THROW
1312-
# JUMP_BACKWARD
1313-
# These can appear on their own or they can be inserted between JUMP_BACKWARD
1314-
# and END_ASYNC_FOR, possibly many times. We keep eliding the `async for` jump
1315-
# and also elide the exception handling cleanup codes because they're not
1316-
# relevant for pytype and complicate the block graph.
1317-
if python_version == (3, 12):
1318-
return (
1319-
isinstance(op, CLEANUP_THROW)
1320-
or (
1321-
isinstance(op, JUMP_BACKWARD)
1322-
and i >= 1
1323-
and isinstance(op_items[i - 1][1], CLEANUP_THROW)
1324-
)
1325-
or (
1326-
isinstance(op, JUMP_BACKWARD)
1327-
and isinstance(
1328-
_get_opcode_following_cleanup_throw_jump_pairs(op_items, i + 1),
1329-
END_ASYNC_FOR,
1330-
)
1331-
)
1332-
)
13331313
return False
13341314

13351315

@@ -1372,13 +1352,41 @@ def _add_jump_targets(ops, offset_to_index):
13721352
op.target = ops[op.arg]
13731353

13741354

1355+
def _add_async_for_jump_back_targets(
1356+
ops: list[Opcode],
1357+
offset_to_op: dict[int, Opcode],
1358+
exc_table: pycnite.types.ExceptionTable,
1359+
):
1360+
"""Find the END_ASYNC_FOR target of which is related to a JUMP_BACKWARD instruction.
1361+
1362+
Also, assign them in a attribute end_async_for_target so that we can process
1363+
it later.
1364+
"""
1365+
1366+
get_anext_incoming: dict[JUMP_BACKWARD, set[GET_ANEXT]] = {}
1367+
for op in ops:
1368+
if isinstance(op, JUMP_BACKWARD) and isinstance(op.target, GET_ANEXT):
1369+
if op.target not in get_anext_incoming:
1370+
get_anext_incoming[op.target] = set()
1371+
get_anext_incoming[op.target].add(op)
1372+
1373+
for e in exc_table.entries:
1374+
if e.start in offset_to_op and isinstance(offset_to_op[e.start], GET_ANEXT):
1375+
get_anext = offset_to_op[e.start]
1376+
if get_anext not in get_anext_incoming:
1377+
continue
1378+
for jump_backward in get_anext_incoming[get_anext]:
1379+
jump_backward.end_async_for_target = offset_to_op[e.target]
1380+
1381+
13751382
def build_opcodes(dis_code: pycnite.types.DisassembledCode) -> list[Opcode]:
13761383
"""Build a list of opcodes from pycnite opcodes."""
13771384
offset_to_op = _make_opcodes(dis_code.opcodes, dis_code.python_version)
13781385
if dis_code.exception_table:
13791386
_add_setup_except(offset_to_op, dis_code.exception_table)
13801387
ops, offset_to_idx = _make_opcode_list(offset_to_op, dis_code.python_version)
13811388
_add_jump_targets(ops, offset_to_idx)
1389+
_add_async_for_jump_back_targets(ops, offset_to_op, dis_code.exception_table)
13821390
return ops
13831391

13841392

pytype/state.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,6 @@ def merge_into(self, other):
206206
self.data_stack,
207207
other.data_stack,
208208
)
209-
assert len(self.block_stack) == len(other.block_stack), (
210-
self.block_stack,
211-
other.block_stack,
212-
)
213209
both = list(zip(self.data_stack, other.data_stack))
214210
if any(v1 is not v2 for v1, v2 in both):
215211
for v, o in both:

pytype/tests/test_async_generators.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,66 @@ async def gen():
406406
x4: Coroutine[Any, Any, None] = gen().aclose()
407407
""")
408408

409+
@test_utils.skipBeforePy((3, 11), "New in 3.11")
410+
def test_async_gen_coroutines_error(self):
411+
"""Test whether the async for within async with does not fail at runtime."""
412+
self.Check("""
413+
def outer(f):
414+
async def wrapper(t, *args, **kwargs):
415+
if t is None:
416+
async with f():
417+
async for c in f():
418+
yield c
419+
else:
420+
async for c in f():
421+
yield c
422+
return wrapper
423+
""")
424+
425+
@test_utils.skipBeforePy((3, 11), "New in 3.11")
426+
def test_abb_asdf(self):
427+
self.Check("""
428+
async def iterate(num):
429+
try:
430+
async for s in range(num): # pytype: disable=attribute-error
431+
if s > 3:
432+
yield ''
433+
except ValueError as e:
434+
yield ''
435+
yield ''
436+
""")
437+
438+
@test_utils.skipBeforePy((3, 11), "New in 3.11")
439+
def test_abb_asdf2(self):
440+
self.Check("""
441+
from typing import Any
442+
import random
443+
async def iterate(stream: Any):
444+
async for _ in stream:
445+
if (random.randint(0, 100) != 30 or random.randint(0, 100) != 40):
446+
continue
447+
yield random.randint(0, 100)
448+
""")
449+
450+
@test_utils.skipBeforePy((3, 11), "New in 3.11")
451+
def test_async_double_for_loop(self):
452+
self.Check("""
453+
def outer(f):
454+
async def wrapper(t, *args, **kwargs):
455+
if t is None:
456+
async with f():
457+
async for c in f():
458+
async for d in f():
459+
yield c + d
460+
yield c
461+
else:
462+
async for c in f():
463+
async for d in f():
464+
yield c + d
465+
yield c
466+
return wrapper
467+
""")
468+
409469

410470
if __name__ == "__main__":
411471
test_base.main()

0 commit comments

Comments
 (0)