Skip to content

Commit 091b3e0

Browse files
committed
[mypyc] feat: extend get_expr_length to work with RTuple and TupleType
1 parent 19697af commit 091b3e0

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

mypyc/irbuild/for_helpers.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
TypeAlias,
2929
Var,
3030
)
31+
from mypy.types import (
32+
TupleType,
33+
)
3134
from mypyc.ir.ops import (
3235
ERR_NEVER,
3336
BasicBlock,
@@ -1180,18 +1183,18 @@ def gen_cleanup(self) -> None:
11801183
gen.gen_cleanup()
11811184

11821185

1183-
def get_expr_length(expr: Expression) -> int | None:
1186+
def get_expr_length(builder: IRBuilder, expr: Expression) -> int | None:
11841187
if isinstance(expr, (StrExpr, BytesExpr)):
11851188
return len(expr.value)
11861189
elif isinstance(expr, (ListExpr, TupleExpr)):
11871190
# if there are no star expressions, or we know the length of them,
11881191
# we know the length of the expression
1189-
stars = [get_expr_length(i) for i in expr.items if isinstance(i, StarExpr)]
1192+
stars = [get_expr_length(builder, i) for i in expr.items if isinstance(i, StarExpr)]
11901193
if None not in stars:
11911194
other = sum(not isinstance(i, StarExpr) for i in expr.items)
11921195
return other + sum(stars) # type: ignore [arg-type]
11931196
elif isinstance(expr, StarExpr):
1194-
return get_expr_length(expr.expr)
1197+
return get_expr_length(builder, expr.expr)
11951198
elif (
11961199
isinstance(expr, RefExpr)
11971200
and isinstance(expr.node, Var)
@@ -1204,6 +1207,14 @@ def get_expr_length(expr: Expression) -> int | None:
12041207
# performance boost and can be (sometimes) figured out pretty easily. set and dict
12051208
# comps *can* be done as well but will need special logic to consider the possibility
12061209
# of key conflicts. Range, enumerate, zip are all simple logic.
1210+
1211+
# we might still be able to get the length direcly from the type
1212+
expr_rtype = builder.node_type(expr)
1213+
if isinstance(expr_rtype, RTuple):
1214+
return len(expr_rtype.types)
1215+
expr_type = builder.types[expr]
1216+
if isinstance(expr_type, TupleType):
1217+
return len(expr_type.items)
12071218
return None
12081219

12091220

@@ -1212,7 +1223,7 @@ def get_expr_length_value(
12121223
) -> Value:
12131224
rtype = builder.node_type(expr)
12141225
assert is_sequence_rprimitive(rtype), rtype
1215-
length = get_expr_length(expr)
1226+
length = get_expr_length(builder, expr)
12161227
if length is None:
12171228
# We cannot compute the length at compile time, so we will fetch it.
12181229
return builder.builder.builtin_len(expr_reg, line, use_pyssize_t=use_pyssize_t)

0 commit comments

Comments
 (0)