Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 8 additions & 17 deletions mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from mypy.nodes import (
ARG_POS,
BytesExpr,
CallExpr,
DictionaryComprehension,
Expression,
Expand All @@ -23,10 +22,8 @@
RefExpr,
SetExpr,
StarExpr,
StrExpr,
TupleExpr,
TypeAlias,
Var,
)
from mypy.types import LiteralType, TupleType, get_proper_type, get_proper_types
from mypyc.ir.ops import (
Expand Down Expand Up @@ -67,6 +64,7 @@
short_int_rprimitive,
)
from mypyc.irbuild.builder import IRBuilder
from mypyc.irbuild.constant_fold import constant_fold_expr
from mypyc.irbuild.prepare import GENERATOR_HELPER_NAME
from mypyc.irbuild.targets import AssignmentTarget, AssignmentTargetTuple
from mypyc.primitives.dict_ops import (
Expand Down Expand Up @@ -1203,26 +1201,19 @@ def gen_cleanup(self) -> None:
gen.gen_cleanup()


def get_expr_length(expr: Expression) -> int | None:
if isinstance(expr, (StrExpr, BytesExpr)):
return len(expr.value)
def get_expr_length(builder: IRBuilder, expr: Expression) -> int | None:
folded = constant_fold_expr(builder, expr)
if isinstance(folded, (str, bytes)):
return len(folded)
elif isinstance(expr, (ListExpr, TupleExpr)):
# if there are no star expressions, or we know the length of them,
# we know the length of the expression
stars = [get_expr_length(i) for i in expr.items if isinstance(i, StarExpr)]
stars = [get_expr_length(builder, i) for i in expr.items if isinstance(i, StarExpr)]
if None not in stars:
other = sum(not isinstance(i, StarExpr) for i in expr.items)
return other + sum(stars) # type: ignore [arg-type]
elif isinstance(expr, StarExpr):
return get_expr_length(expr.expr)
elif (
isinstance(expr, RefExpr)
and isinstance(expr.node, Var)
and expr.node.is_final
and isinstance(expr.node.final_value, str)
and expr.node.has_explicit_value
):
return len(expr.node.final_value)
return get_expr_length(builder, expr.expr)
# TODO: extend this, passing length of listcomp and genexp should have worthwhile
# performance boost and can be (sometimes) figured out pretty easily. set and dict
# comps *can* be done as well but will need special logic to consider the possibility
Expand All @@ -1235,7 +1226,7 @@ def get_expr_length_value(
) -> Value:
rtype = builder.node_type(expr)
assert is_sequence_rprimitive(rtype) or isinstance(rtype, RTuple), rtype
length = get_expr_length(expr)
length = get_expr_length(builder, expr)
if length is None:
# We cannot compute the length at compile time, so we will fetch it.
return builder.builder.builtin_len(expr_reg, line, use_pyssize_t=use_pyssize_t)
Expand Down
48 changes: 48 additions & 0 deletions mypyc/test-data/irbuild-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,54 @@ L4:
a = r1
return 1

[case testTupleBuiltFromConstantFolding]
from typing import Final

c: Final = "c"

def f2(val: str) -> str:
return val + "f2"

def test() -> None:
# `"ab" + c` should constant fold to "abc"
a = tuple(f2(x) for x in "ab" + c)

[out]
def f2(val):
val, r0, r1 :: str
L0:
r0 = 'f2'
r1 = PyUnicode_Concat(val, r0)
return r1
def test():
r0 :: str
r1 :: tuple
r2 :: native_int
r3 :: bit
r4, x, r5 :: str
r6 :: native_int
a :: tuple
L0:
r0 = 'abc'
r1 = PyTuple_New(3)
r2 = 0
goto L2
L1:
r3 = r2 < 3 :: signed
if r3 goto L2 else goto L4 :: bool
L2:
r4 = CPyStr_GetItemUnsafe(r0, r2)
x = r4
r5 = f2(x)
CPySequenceTuple_SetItemUnsafe(r1, r2, r5)
L3:
r6 = r2 + 1
r2 = r6
goto L1
L4:
a = r1
return 1

[case testTupleBuiltFromFinalStr]
from typing import Final

Expand Down