-
-
Notifications
You must be signed in to change notification settings - Fork 3k
[mypyc] feat: specialize any
and all
using for loop helpers if possible [1/2]
#19948
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
05b3d71
841878c
02d4739
b0edaeb
7d98f34
7ef9f18
9c98174
819efe2
2f15dc3
88ab77c
24a4d1f
5fdc197
f5480bd
13d5aca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,6 +80,9 @@ | |
from mypyc.irbuild.builder import IRBuilder | ||
from mypyc.irbuild.for_helpers import ( | ||
comprehension_helper, | ||
create_synthetic_nameexpr, | ||
expr_has_specialized_for_helper, | ||
for_loop_helper, | ||
sequence_from_generator_preallocate_helper, | ||
translate_list_comprehension, | ||
translate_set_comprehension, | ||
|
@@ -412,29 +415,70 @@ def translate_safe_generator_call( | |
|
||
@specialize_function("builtins.any") | ||
def translate_any_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: | ||
if ( | ||
len(expr.args) == 1 | ||
and expr.arg_kinds == [ARG_POS] | ||
and isinstance(expr.args[0], GeneratorExpr) | ||
): | ||
return any_all_helper(builder, expr.args[0], builder.false, lambda x: x, builder.true) | ||
if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]: | ||
arg = expr.args[0] | ||
if isinstance(arg, GeneratorExpr): | ||
return any_all_helper(builder, arg, builder.false, lambda x: x, builder.true) | ||
elif expr_has_specialized_for_helper(builder, arg): | ||
retval = Register(bool_rprimitive) | ||
builder.assign(retval, builder.false(), -1) | ||
loop_exit = BasicBlock() | ||
index_name = "__mypyc_any_item__" | ||
|
||
def body_insts() -> None: | ||
true_block = BasicBlock() | ||
false_block = BasicBlock() | ||
builder.add_bool_branch(builder.read(index_reg), true_block, false_block) | ||
builder.activate_block(true_block) | ||
builder.assign(retval, builder.true(), -1) | ||
builder.goto(loop_exit) | ||
builder.activate_block(false_block) | ||
|
||
index_type = builder._analyze_iterable_item_type(arg) | ||
index = create_synthetic_nameexpr(index_name, index_type) | ||
index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type)) # type: ignore [arg-type] | ||
|
||
for_loop_helper(builder, index, arg, body_insts, None, is_async=False, line=expr.line) | ||
builder.goto_and_activate(loop_exit) | ||
return retval | ||
return None | ||
|
||
|
||
@specialize_function("builtins.all") | ||
def translate_all_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: | ||
if ( | ||
len(expr.args) == 1 | ||
and expr.arg_kinds == [ARG_POS] | ||
and isinstance(expr.args[0], GeneratorExpr) | ||
): | ||
return any_all_helper( | ||
builder, | ||
expr.args[0], | ||
builder.true, | ||
lambda x: builder.unary_op(x, "not", expr.line), | ||
builder.false, | ||
) | ||
if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]: | ||
arg = expr.args[0] | ||
if isinstance(arg, GeneratorExpr): | ||
return any_all_helper( | ||
builder, | ||
arg, | ||
builder.true, | ||
lambda x: builder.unary_op(x, "not", expr.line), | ||
builder.false, | ||
) | ||
|
||
elif expr_has_specialized_for_helper(builder, arg): | ||
retval = Register(bool_rprimitive) | ||
builder.assign(retval, builder.true(), -1) | ||
loop_exit = BasicBlock() | ||
index_name = "__mypyc_all_item__" | ||
|
||
def body_insts() -> None: | ||
true_block = BasicBlock() | ||
false_block = BasicBlock() | ||
builder.add_bool_branch(builder.read(index_reg), true_block, false_block) | ||
builder.activate_block(false_block) | ||
builder.assign(retval, builder.false(), -1) | ||
builder.goto(loop_exit) | ||
builder.activate_block(true_block) | ||
|
||
index_type = builder._analyze_iterable_item_type(arg) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should I deduplicate this block or is this fine? |
||
index = create_synthetic_nameexpr(index_name, index_type) | ||
index_reg = builder.add_local_reg(index.node, builder.type_to_rtype(index_type)) # type: ignore [arg-type] | ||
|
||
for_loop_helper(builder, index, arg, body_insts, None, is_async=False, line=expr.line) | ||
builder.goto_and_activate(loop_exit) | ||
return retval | ||
return None | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2812,6 +2812,12 @@ def call_any(l: Iterable[int]) -> bool: | |
def call_all(l: Iterable[int]) -> bool: | ||
return all(i == 0 for i in l) | ||
|
||
def call_any_helper(l: list[Iterable[int]]) -> bool: | ||
return any([str(i) for i in l]) | ||
|
||
def call_all_helper(l: list[Iterable[int]]) -> bool: | ||
return all([str(i) for i in l]) | ||
|
||
[out] | ||
def call_any(l): | ||
l :: object | ||
|
@@ -2870,6 +2876,118 @@ L6: | |
L7: | ||
L8: | ||
return r0 | ||
def call_any_helper(l): | ||
l :: list | ||
r0 :: bool | ||
r1 :: native_int | ||
r2 :: list | ||
r3, r4 :: native_int | ||
r5 :: bit | ||
r6, i :: object | ||
r7 :: str | ||
r8, r9, r10 :: native_int | ||
r11 :: bit | ||
r12 :: object | ||
r13, __mypyc_any_item__ :: str | ||
r14 :: bit | ||
r15 :: native_int | ||
L0: | ||
r0 = 0 | ||
r1 = var_object_size l | ||
r2 = PyList_New(r1) | ||
r3 = 0 | ||
L1: | ||
r4 = var_object_size l | ||
r5 = r3 < r4 :: signed | ||
if r5 goto L2 else goto L4 :: bool | ||
L2: | ||
r6 = list_get_item_unsafe l, r3 | ||
i = r6 | ||
r7 = PyObject_Str(i) | ||
CPyList_SetItemUnsafe(r2, r3, r7) | ||
L3: | ||
r8 = r3 + 1 | ||
r3 = r8 | ||
goto L1 | ||
L4: | ||
r9 = 0 | ||
L5: | ||
r10 = var_object_size r2 | ||
r11 = r9 < r10 :: signed | ||
if r11 goto L6 else goto L10 :: bool | ||
L6: | ||
r12 = list_get_item_unsafe r2, r9 | ||
r13 = cast(str, r12) | ||
__mypyc_any_item__ = r13 | ||
r14 = CPyStr_IsTrue(__mypyc_any_item__) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The boolean check for the |
||
if r14 goto L7 else goto L8 :: bool | ||
L7: | ||
r0 = 1 | ||
goto L11 | ||
L8: | ||
L9: | ||
r15 = r9 + 1 | ||
r9 = r15 | ||
goto L5 | ||
L10: | ||
L11: | ||
return r0 | ||
def call_all_helper(l): | ||
l :: list | ||
r0 :: bool | ||
r1 :: native_int | ||
r2 :: list | ||
r3, r4 :: native_int | ||
r5 :: bit | ||
r6, i :: object | ||
r7 :: str | ||
r8, r9, r10 :: native_int | ||
r11 :: bit | ||
r12 :: object | ||
r13, __mypyc_all_item__ :: str | ||
r14 :: bit | ||
r15 :: native_int | ||
L0: | ||
r0 = 1 | ||
r1 = var_object_size l | ||
r2 = PyList_New(r1) | ||
r3 = 0 | ||
L1: | ||
r4 = var_object_size l | ||
r5 = r3 < r4 :: signed | ||
if r5 goto L2 else goto L4 :: bool | ||
L2: | ||
r6 = list_get_item_unsafe l, r3 | ||
i = r6 | ||
r7 = PyObject_Str(i) | ||
CPyList_SetItemUnsafe(r2, r3, r7) | ||
L3: | ||
r8 = r3 + 1 | ||
r3 = r8 | ||
goto L1 | ||
L4: | ||
r9 = 0 | ||
L5: | ||
r10 = var_object_size r2 | ||
r11 = r9 < r10 :: signed | ||
if r11 goto L6 else goto L10 :: bool | ||
L6: | ||
r12 = list_get_item_unsafe r2, r9 | ||
r13 = cast(str, r12) | ||
__mypyc_all_item__ = r13 | ||
r14 = CPyStr_IsTrue(__mypyc_all_item__) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The boolean check for the |
||
if r14 goto L8 else goto L7 :: bool | ||
L7: | ||
r0 = 0 | ||
goto L11 | ||
L8: | ||
L9: | ||
r15 = r9 + 1 | ||
r9 = r15 | ||
goto L5 | ||
L10: | ||
L11: | ||
return r0 | ||
|
||
[case testSum] | ||
from typing import Callable, Iterable | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I probably need something here to ensure no name collisions globally, but not sure what