Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from mypy.nodes import (
ARG_POS,
LDEF,
BytesExpr,
CallExpr,
DictionaryComprehension,
Expand All @@ -28,7 +29,7 @@
TypeAlias,
Var,
)
from mypy.types import LiteralType, TupleType, get_proper_type, get_proper_types
from mypy.types import LiteralType, TupleType, Type, get_proper_type, get_proper_types
from mypyc.ir.ops import (
ERR_NEVER,
BasicBlock,
Expand Down Expand Up @@ -1241,3 +1242,28 @@ def get_expr_length_value(
return builder.builder.builtin_len(expr_reg, line, use_pyssize_t=use_pyssize_t)
# The expression result is known at compile time, so we can use a constant.
return Integer(length, c_pyssize_t_rprimitive if use_pyssize_t else short_int_rprimitive)


def _is_supported_forloop_iter(builder: IRBuilder, expr: Expression) -> bool:
if is_sequence_rprimitive(builder.node_type(expr)):
return True
if not isinstance(expr, CallExpr):
return False
if isinstance(expr.callee, RefExpr):
return expr.callee.fullname in {
"builtins.range",
"builtins.enumerate",
"builtins.zip",
"builtins.reversed",
}
elif isinstance(expr.callee, MemberExpr):
return expr.callee.fullname in {"keys", "values", "items"}
return False


def _create_iterable_lexpr(index_name: str, index_type: Type) -> NameExpr:
"""This helper spoofs a NameExpr to use as the lvalue in one of the for loop helpers."""
index = NameExpr(index_name)
index.kind = LDEF
index.node = Var(index_name, index_type)
return index
146 changes: 111 additions & 35 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
bytes_rprimitive,
c_int_rprimitive,
dict_rprimitive,
float_rprimitive,
int16_rprimitive,
int32_rprimitive,
int64_rprimitive,
Expand All @@ -69,6 +70,7 @@
is_int64_rprimitive,
is_int_rprimitive,
is_list_rprimitive,
is_object_rprimitive,
is_uint8_rprimitive,
list_rprimitive,
object_rprimitive,
Expand All @@ -79,7 +81,10 @@
from mypyc.irbuild.builder import IRBuilder
from mypyc.irbuild.constant_fold import constant_fold_expr
from mypyc.irbuild.for_helpers import (
_create_iterable_lexpr,
_is_supported_forloop_iter,
comprehension_helper,
for_loop_helper,
sequence_from_generator_preallocate_helper,
translate_list_comprehension,
translate_set_comprehension,
Expand Down Expand Up @@ -412,29 +417,70 @@ def translate_safe_generator_call(

@specialize_function("builtins.any")
def translate_any_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
if (
len(expr.args) == 1
and expr.arg_kinds == [ARG_POS]
and isinstance(expr.args[0], GeneratorExpr)
):
return any_all_helper(builder, expr.args[0], builder.false, lambda x: x, builder.true)
if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]:
arg = expr.args[0]
if isinstance(arg, GeneratorExpr):
return any_all_helper(builder, arg, builder.false, lambda x: x, builder.true)
elif _is_supported_forloop_iter(builder, arg):
retval = Register(bool_rprimitive)
builder.assign(retval, builder.false(), -1)
loop_exit = BasicBlock()
index_name = "__mypyc_any_item__"

def body_insts() -> None:
true_block = BasicBlock()
false_block = BasicBlock()
builder.add_bool_branch(builder.read(index_reg), true_block, false_block)
builder.activate_block(true_block)
builder.assign(retval, builder.true(), -1)
builder.goto(loop_exit)
builder.activate_block(false_block)

index_type = builder._analyze_iterable_item_type(arg)
index = _create_iterable_lexpr(index_name, index_type)
index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type)) # type: ignore [arg-type]

for_loop_helper(builder, index, arg, body_insts, None, is_async=False, line=expr.line)
builder.goto_and_activate(loop_exit)
return retval
return None


@specialize_function("builtins.all")
def translate_all_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
if (
len(expr.args) == 1
and expr.arg_kinds == [ARG_POS]
and isinstance(expr.args[0], GeneratorExpr)
):
return any_all_helper(
builder,
expr.args[0],
builder.true,
lambda x: builder.unary_op(x, "not", expr.line),
builder.false,
)
if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]:
arg = expr.args[0]
if isinstance(arg, GeneratorExpr):
return any_all_helper(
builder,
arg,
builder.true,
lambda x: builder.unary_op(x, "not", expr.line),
builder.false,
)

elif _is_supported_forloop_iter(builder, arg):
retval = Register(bool_rprimitive)
builder.assign(retval, builder.true(), -1)
loop_exit = BasicBlock()
index_name = "__mypyc_all_item__"

def body_insts() -> None:
true_block = BasicBlock()
false_block = BasicBlock()
builder.add_bool_branch(builder.read(index_reg), true_block, false_block)
builder.activate_block(false_block)
builder.assign(retval, builder.false(), -1)
builder.goto(loop_exit)
builder.activate_block(true_block)

index_type = builder._analyze_iterable_item_type(arg)
index = _create_iterable_lexpr(index_name, index_type)
index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type)) # type: ignore [arg-type]

for_loop_helper(builder, index, arg, body_insts, None, is_async=False, line=expr.line)
builder.goto_and_activate(loop_exit)
return retval
return None


Expand Down Expand Up @@ -470,11 +516,11 @@ def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> V
# - only one or two arguments given (if not, sum() has been given invalid arguments)
# - first argument is a Generator (there is no benefit to optimizing the performance of eg.
# sum([1, 2, 3]), so non-Generator Iterables are not handled)
if not (
len(expr.args) in (1, 2)
and expr.arg_kinds[0] == ARG_POS
and isinstance(expr.args[0], GeneratorExpr)
):
if not (len(expr.args) in (1, 2) and expr.arg_kinds[0] == ARG_POS):
return None

arg = expr.args[0]
if not isinstance(arg, GeneratorExpr) and not _is_supported_forloop_iter(builder, arg):
return None

# handle 'start' argument, if given
Expand All @@ -486,21 +532,51 @@ def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> V
else:
start_expr = IntExpr(0)

gen_expr = expr.args[0]
target_type = builder.node_type(expr)
retval = Register(target_type)
builder.assign(retval, builder.coerce(builder.accept(start_expr), target_type, -1), -1)
item_type = builder._analyze_iterable_item_type(arg)
item_rtype = builder.type_to_rtype(item_type)
start_rtype = builder.node_type(start_expr)

def gen_inner_stmts() -> None:
call_expr = builder.accept(gen_expr.left_expr)
builder.assign(retval, builder.binary_op(retval, call_expr, "+", -1), -1)
if item_rtype is start_rtype:
acc_rtype = item_rtype
elif is_float_rprimitive(item_rtype) and is_int_rprimitive(start_rtype):
acc_rtype = float_rprimitive
elif is_bool_rprimitive(item_rtype) and is_int_rprimitive(start_rtype):
acc_rtype = int_rprimitive
elif is_object_rprimitive(item_rtype) and is_int_rprimitive(start_rtype):
acc_rtype = object_rprimitive

loop_params = list(
zip(gen_expr.indices, gen_expr.sequences, gen_expr.condlists, gen_expr.is_async)
)
comprehension_helper(builder, loop_params, gen_inner_stmts, gen_expr.line)
else:
# escape hatch, maybe figure out a better way to handle this whole block
# seeking ideas in review
return None

return retval
retval = Register(acc_rtype)
builder.assign(retval, builder.coerce(builder.accept(start_expr), acc_rtype, -1), -1)

if isinstance(arg, GeneratorExpr):

def gen_inner_stmts() -> None:
call_expr = builder.accept(arg.left_expr)
builder.assign(retval, builder.binary_op(retval, call_expr, "+", -1), -1)

loop_params = list(zip(arg.indices, arg.sequences, arg.condlists, arg.is_async))
comprehension_helper(builder, loop_params, gen_inner_stmts, arg.line)

return retval

else:
index_name = "__mypyc_sum_item__"

def body_insts() -> None:
total = builder.binary_op(retval, builder.read(index_reg), "+", expr.line)
builder.assign(retval, total, expr.line)

index_type = builder._analyze_iterable_item_type(arg)
index = _create_iterable_lexpr(index_name, index_type)
index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type)) # type: ignore [arg-type]

for_loop_helper(builder, index, arg, body_insts, None, is_async=False, line=expr.line)
return retval


@specialize_function("dataclasses.field")
Expand Down
Loading