Skip to content

Commit 1676bdd

Browse files
committed
[mypyc] Speed up for loop over native generator
1 parent 35d8c69 commit 1676bdd

File tree

1 file changed

+67
-2
lines changed

1 file changed

+67
-2
lines changed

mypyc/irbuild/for_helpers.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,23 @@
2424
TypeAlias,
2525
)
2626
from mypyc.ir.ops import (
27+
ERR_NEVER,
2728
BasicBlock,
2829
Branch,
2930
Integer,
3031
IntOp,
3132
LoadAddress,
33+
LoadErrorValue,
3234
LoadMem,
35+
MethodCall,
3336
RaiseStandardError,
3437
Register,
3538
TupleGet,
3639
TupleSet,
3740
Value,
3841
)
3942
from mypyc.ir.rtypes import (
43+
RInstance,
4044
RTuple,
4145
RType,
4246
bool_rprimitive,
@@ -48,6 +52,8 @@
4852
is_short_int_rprimitive,
4953
is_str_rprimitive,
5054
is_tuple_rprimitive,
55+
object_pointer_rprimitive,
56+
object_rprimitive,
5157
pointer_rprimitive,
5258
short_int_rprimitive,
5359
)
@@ -62,7 +68,7 @@
6268
dict_next_value_op,
6369
dict_value_iter_op,
6470
)
65-
from mypyc.primitives.exc_ops import no_err_occurred_op
71+
from mypyc.primitives.exc_ops import no_err_occurred_op, propagate_if_error_op
6672
from mypyc.primitives.generic_ops import aiter_op, anext_op, iter_op, next_op
6773
from mypyc.primitives.list_ops import list_append_op, list_get_item_unsafe_op, new_list_set_item_op
6874
from mypyc.primitives.misc_ops import stop_async_iteration_op
@@ -511,7 +517,16 @@ def make_for_loop_generator(
511517
# Default to a generic for loop.
512518
if iterable_expr_reg is None:
513519
iterable_expr_reg = builder.accept(expr)
514-
for_obj = ForIterable(builder, index, body_block, loop_exit, line, nested)
520+
521+
helper_method = "__mypyc_generator_helper__"
522+
it = iterable_expr_reg.type
523+
for_obj: ForNativeGenerator | ForIterable
524+
if isinstance(it, RInstance) and it.class_ir.has_method(helper_method):
525+
# Directly call generator object methods if iterating over a native generator.
526+
for_obj = ForNativeGenerator(builder, index, body_block, loop_exit, line, nested)
527+
else:
528+
# Generic implementation that works of arbitrary iterables.
529+
for_obj = ForIterable(builder, index, body_block, loop_exit, line, nested)
515530
item_type = builder._analyze_iterable_item_type(expr)
516531
item_rtype = builder.type_to_rtype(item_type)
517532
for_obj.init(iterable_expr_reg, item_rtype)
@@ -623,6 +638,56 @@ def gen_cleanup(self) -> None:
623638
self.builder.call_c(no_err_occurred_op, [], self.line)
624639

625640

641+
class ForNativeGenerator(ForGenerator):
642+
"""Generate IR for a for loop over a native generator."""
643+
644+
def need_cleanup(self) -> bool:
645+
# Create a new cleanup block for when the loop is finished.
646+
return True
647+
648+
def init(self, expr_reg: Value, target_type: RType) -> None:
649+
# Define targets to contain the expression, along with the iterator that will be used
650+
# for the for-loop. If we are inside of a generator function, spill these into the
651+
# environment class.
652+
builder = self.builder
653+
self.iter_target = builder.maybe_spill(expr_reg)
654+
self.target_type = target_type
655+
656+
def gen_condition(self) -> None:
657+
builder = self.builder
658+
line = self.line
659+
helper_method = "__mypyc_generator_helper__"
660+
self.return_value = Register(object_rprimitive)
661+
err = builder.add(LoadErrorValue(object_rprimitive, undefines=True))
662+
builder.assign(self.return_value, err, line)
663+
ptr = builder.add(LoadAddress(object_pointer_rprimitive, self.return_value))
664+
nn = builder.none_object()
665+
helper_call =(
666+
MethodCall(builder.read(self.iter_target), helper_method, [nn, nn, nn, nn, ptr], line)
667+
)
668+
# We provide custom handling for error values.
669+
helper_call.error_kind = ERR_NEVER
670+
self.next_reg = builder.add(helper_call)
671+
builder.add(Branch(self.next_reg, self.loop_exit, self.body_block, Branch.IS_ERROR))
672+
673+
def begin_body(self) -> None:
674+
# Assign the value obtained from __next__ to the
675+
# lvalue so that it can be referenced by code in the body of the loop.
676+
builder = self.builder
677+
line = self.line
678+
# We unbox here so that iterating with tuple unpacking generates a tuple based
679+
# unpack instead of an iterator based one.
680+
next_reg = builder.coerce(self.next_reg, self.target_type, line)
681+
builder.assign(builder.get_assignment_target(self.index), next_reg, line)
682+
683+
def gen_step(self) -> None:
684+
# Nothing to do here, since we get the next item as part of gen_condition().
685+
pass
686+
687+
def gen_cleanup(self) -> None:
688+
self.builder.primitive_op(propagate_if_error_op, [self.return_value], self.line)
689+
690+
626691
class ForAsyncIterable(ForGenerator):
627692
"""Generate IR for an async for loop."""
628693

0 commit comments

Comments
 (0)