Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from mypy.nodes import (
ARG_POS,
LDEF,
BytesExpr,
CallExpr,
DictionaryComprehension,
Expand All @@ -28,7 +29,7 @@
TypeAlias,
Var,
)
from mypy.types import LiteralType, TupleType, get_proper_type, get_proper_types
from mypy.types import LiteralType, TupleType, Type, get_proper_type, get_proper_types
from mypyc.ir.ops import (
ERR_NEVER,
BasicBlock,
Expand Down Expand Up @@ -1241,3 +1242,27 @@ def get_expr_length_value(
return builder.builder.builtin_len(expr_reg, line, use_pyssize_t=use_pyssize_t)
# The expression result is known at compile time, so we can use a constant.
return Integer(length, c_pyssize_t_rprimitive if use_pyssize_t else short_int_rprimitive)


def expr_has_specialized_for_helper(builder: IRBuilder, expr: Expression) -> bool:
    """Return whether ``expr`` is an iterable shape recognized by this check.

    The result is True when either:

    * the rtype of ``expr`` (per ``builder.node_type``) is a sequence
      primitive, or
    * ``expr`` is a call whose callee is a reference expression with one of
      the recognized fully qualified names (``range``, ``enumerate``, ``zip``
      or the ``dict`` view methods).

    Everything else yields False.
    """
    if is_sequence_rprimitive(builder.node_type(expr)):
        return True
    if not isinstance(expr, CallExpr):
        return False
    callee = expr.callee
    # MemberExpr is a subclass of RefExpr, so a single RefExpr check covers
    # both plain names and attribute accesses. Testing RefExpr first and
    # MemberExpr in a following ``elif`` made the dict method names
    # unreachable.
    if isinstance(callee, RefExpr):
        # NOTE(review): for a method call on a dict *instance* (``d.items()``)
        # the callee's ``fullname`` may not be populated -- confirm, and
        # consider matching on the receiver's rtype plus the member name.
        return callee.fullname in {
            "builtins.range",
            "builtins.enumerate",
            "builtins.zip",
            "builtins.dict.keys",
            "builtins.dict.values",
            "builtins.dict.items",
        }
    return False


def create_synthetic_nameexpr(index_name: str, index_type: Type) -> NameExpr:
    """Build a stand-in NameExpr to use as the lvalue in a for loop helper.

    The returned expression is marked as a local definition (LDEF) and is
    bound to a fresh Var carrying the given name and type.
    """
    # TODO(review): nothing here guards against name collisions; the author
    # flagged that some uniqueness scheme may be needed -- confirm.
    synthetic = NameExpr(index_name)
    synthetic.node = Var(index_name, index_type)
    synthetic.kind = LDEF
    return synthetic
80 changes: 62 additions & 18 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@
from mypyc.irbuild.builder import IRBuilder
from mypyc.irbuild.for_helpers import (
comprehension_helper,
create_synthetic_nameexpr,
expr_has_specialized_for_helper,
for_loop_helper,
sequence_from_generator_preallocate_helper,
translate_list_comprehension,
translate_set_comprehension,
Expand Down Expand Up @@ -412,29 +415,70 @@ def translate_safe_generator_call(

def _translate_any_all_over_iterable(
    builder: IRBuilder, expr: CallExpr, index_name: str, *, is_any: bool
) -> Value:
    """Emit a short-circuiting loop implementing any()/all() over expr.args[0].

    Shared by the ``any`` and ``all`` specializers, which differ only in the
    initial result and in which branch of the truth test leaves the loop.

    Args:
        builder: IR builder to emit into.
        expr: The ``any(...)``/``all(...)`` call; its single argument is the
            iterable that gets looped over.
        index_name: Name given to the synthetic loop variable.
        is_any: True for ``any`` semantics, False for ``all`` semantics.

    Returns:
        The bool register holding the result once the loop has exited.
    """
    arg = expr.args[0]

    # any() starts out False and flips to True on the first truthy item;
    # all() starts out True and flips to False on the first falsy item.
    retval = Register(bool_rprimitive)
    builder.assign(retval, builder.false() if is_any else builder.true(), -1)
    loop_exit = BasicBlock()

    index_type = builder._analyze_iterable_item_type(arg)
    index = create_synthetic_nameexpr(index_name, index_type)
    index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type))  # type: ignore [arg-type]

    def body_insts() -> None:
        true_block = BasicBlock()
        false_block = BasicBlock()
        builder.add_bool_branch(builder.read(index_reg), true_block, false_block)
        if is_any:
            exit_block, continue_block = true_block, false_block
        else:
            exit_block, continue_block = false_block, true_block
        # The short-circuit block is activated first so that blocks are
        # emitted in the same order as the hand-written any/all versions.
        builder.activate_block(exit_block)
        builder.assign(retval, builder.true() if is_any else builder.false(), -1)
        builder.goto(loop_exit)
        builder.activate_block(continue_block)

    for_loop_helper(builder, index, arg, body_insts, None, is_async=False, line=expr.line)
    builder.goto_and_activate(loop_exit)
    return retval


@specialize_function("builtins.any")
def translate_any_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
    """Specialize ``any(x)`` for a single positional argument.

    Generator expressions go through ``any_all_helper``; other iterables
    accepted by ``expr_has_specialized_for_helper`` are lowered to an explicit
    short-circuiting loop. Returns None when no specialization applies.
    """
    if len(expr.args) != 1 or expr.arg_kinds != [ARG_POS]:
        return None
    arg = expr.args[0]
    if isinstance(arg, GeneratorExpr):
        return any_all_helper(builder, arg, builder.false, lambda x: x, builder.true)
    if expr_has_specialized_for_helper(builder, arg):
        return _translate_any_all_over_iterable(
            builder, expr, "__mypyc_any_item__", is_any=True
        )
    return None


@specialize_function("builtins.all")
def translate_all_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
    """Specialize ``all(x)`` for a single positional argument.

    Generator expressions go through ``any_all_helper`` with a negated
    condition; other iterables accepted by ``expr_has_specialized_for_helper``
    are lowered to an explicit short-circuiting loop. Returns None when no
    specialization applies.
    """
    if len(expr.args) != 1 or expr.arg_kinds != [ARG_POS]:
        return None
    arg = expr.args[0]
    if isinstance(arg, GeneratorExpr):
        return any_all_helper(
            builder,
            arg,
            builder.true,
            lambda x: builder.unary_op(x, "not", expr.line),
            builder.false,
        )
    if expr_has_specialized_for_helper(builder, arg):
        return _translate_any_all_over_iterable(
            builder, expr, "__mypyc_all_item__", is_any=False
        )
    return None


Expand Down
118 changes: 118 additions & 0 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -2812,6 +2812,12 @@ def call_any(l: Iterable[int]) -> bool:
def call_all(l: Iterable[int]) -> bool:
return all(i == 0 for i in l)

# Passes a list comprehension of str values (not a generator expression) to any().
def call_any_helper(l: list[Iterable[int]]) -> bool:
    return any([str(i) for i in l])

# Passes a list comprehension of str values (not a generator expression) to all().
def call_all_helper(l: list[Iterable[int]]) -> bool:
    return all([str(i) for i in l])

[out]
def call_any(l):
l :: object
Expand Down Expand Up @@ -2870,6 +2876,118 @@ L6:
L7:
L8:
return r0
def call_any_helper(l):
l :: list
r0 :: bool
r1 :: native_int
r2 :: list
r3, r4 :: native_int
r5 :: bit
r6, i :: object
r7 :: str
r8, r9, r10 :: native_int
r11 :: bit
r12 :: object
r13, __mypyc_any_item__ :: str
r14 :: bit
r15 :: native_int
L0:
r0 = 0
r1 = var_object_size l
r2 = PyList_New(r1)
r3 = 0
L1:
r4 = var_object_size l
r5 = r3 < r4 :: signed
if r5 goto L2 else goto L4 :: bool
L2:
r6 = list_get_item_unsafe l, r3
i = r6
r7 = PyObject_Str(i)
CPyList_SetItemUnsafe(r2, r3, r7)
L3:
r8 = r3 + 1
r3 = r8
goto L1
L4:
r9 = 0
L5:
r10 = var_object_size r2
r11 = r9 < r10 :: signed
if r11 goto L6 else goto L10 :: bool
L6:
r12 = list_get_item_unsafe r2, r9
r13 = cast(str, r12)
__mypyc_any_item__ = r13
r14 = CPyStr_IsTrue(__mypyc_any_item__)
Copy link
Contributor Author

@BobTheBuidler BobTheBuidler Sep 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The truthiness check for the `any` call is now specialized for the iterable's item type (here `CPyStr_IsTrue` for `str` items), and the input list is indexed directly instead of going through Python's iterator protocol.

if r14 goto L7 else goto L8 :: bool
L7:
r0 = 1
goto L11
L8:
L9:
r15 = r9 + 1
r9 = r15
goto L5
L10:
L11:
return r0
def call_all_helper(l):
l :: list
r0 :: bool
r1 :: native_int
r2 :: list
r3, r4 :: native_int
r5 :: bit
r6, i :: object
r7 :: str
r8, r9, r10 :: native_int
r11 :: bit
r12 :: object
r13, __mypyc_all_item__ :: str
r14 :: bit
r15 :: native_int
L0:
r0 = 1
r1 = var_object_size l
r2 = PyList_New(r1)
r3 = 0
L1:
r4 = var_object_size l
r5 = r3 < r4 :: signed
if r5 goto L2 else goto L4 :: bool
L2:
r6 = list_get_item_unsafe l, r3
i = r6
r7 = PyObject_Str(i)
CPyList_SetItemUnsafe(r2, r3, r7)
L3:
r8 = r3 + 1
r3 = r8
goto L1
L4:
r9 = 0
L5:
r10 = var_object_size r2
r11 = r9 < r10 :: signed
if r11 goto L6 else goto L10 :: bool
L6:
r12 = list_get_item_unsafe r2, r9
r13 = cast(str, r12)
__mypyc_all_item__ = r13
r14 = CPyStr_IsTrue(__mypyc_all_item__)
Copy link
Contributor Author

@BobTheBuidler BobTheBuidler Sep 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The truthiness check for the `all` call is now specialized for the iterable's item type (here `CPyStr_IsTrue` for `str` items), and the input list is indexed directly instead of going through Python's iterator protocol.

if r14 goto L8 else goto L7 :: bool
L7:
r0 = 0
goto L11
L8:
L9:
r15 = r9 + 1
r9 = r15
goto L5
L10:
L11:
return r0

[case testSum]
from typing import Callable, Iterable
Expand Down
30 changes: 29 additions & 1 deletion mypyc/test-data/run-misc.test
Original file line number Diff line number Diff line change
Expand Up @@ -794,8 +794,18 @@ def call_all(l: Iterable[int], val: int = 0) -> int:
res = all(i == val for i in l)
return 0 if res else 1

def call_any_with_for_helper(l: Iterable[int], val: int = 0) -> int:
    """Return 0 if any item of ``l`` equals ``val``, otherwise 1."""
    # A list comprehension is not a realistic argument for any(), but it is
    # passed on purpose: it proves the for-loop specializer for any() works.
    found = any([i == val for i in l])
    if found:
        return 0
    return 1

def call_all_with_for_helper(l: Iterable[int], val: int = 0) -> int:
    """Return 0 if every item of ``l`` equals ``val``, otherwise 1."""
    # A list comprehension is not a realistic argument for all(), but it is
    # passed on purpose: it proves the for-loop specializer for all() works.
    every = all([i == val for i in l])
    if every:
        return 0
    return 1

[file driver.py]
from native import call_any, call_all, call_any_nested
from native import call_any, call_all, call_any_nested, call_any_with_for_helper, call_all_with_for_helper

zeros = [0, 0, 0]
ones = [1, 1, 1]
Expand All @@ -807,24 +817,42 @@ mixed_101 = [1, 0, 1]
mixed_110 = [1, 1, 0]

assert call_any([]) == 1
assert call_any_with_for_helper([]) == 1
assert call_any(zeros) == 0
assert call_any_with_for_helper(zeros) == 0
assert call_any(ones) == 1
assert call_any_with_for_helper(ones) == 1
assert call_any(mixed_001) == 0
assert call_any_with_for_helper(mixed_001) == 0
assert call_any(mixed_010) == 0
assert call_any_with_for_helper(mixed_010) == 0
assert call_any(mixed_100) == 0
assert call_any_with_for_helper(mixed_100) == 0
assert call_any(mixed_011) == 0
assert call_any_with_for_helper(mixed_011) == 0
assert call_any(mixed_101) == 0
assert call_any_with_for_helper(mixed_101) == 0
assert call_any(mixed_110) == 0
assert call_any_with_for_helper(mixed_110) == 0

assert call_all([]) == 0
assert call_all_with_for_helper([]) == 0
assert call_all(zeros) == 0
assert call_all_with_for_helper(zeros) == 0
assert call_all(ones) == 1
assert call_all_with_for_helper(ones) == 1
assert call_all(mixed_001) == 1
assert call_all_with_for_helper(mixed_001) == 1
assert call_all(mixed_010) == 1
assert call_all_with_for_helper(mixed_010) == 1
assert call_all(mixed_100) == 1
assert call_all_with_for_helper(mixed_100) == 1
assert call_all(mixed_011) == 1
assert call_all_with_for_helper(mixed_011) == 1
assert call_all(mixed_101) == 1
assert call_all_with_for_helper(mixed_101) == 1
assert call_all(mixed_110) == 1
assert call_all_with_for_helper(mixed_110) == 1

assert call_any_nested([[1, 1, 1], [1, 1], []]) == 1
assert call_any_nested([[1, 1, 1], [0, 1], []]) == 0
Expand Down