Skip to content

Commit e376ec3

Browse files
constant fold range args
1 parent 29d1df0 commit e376ec3

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

mypyc/irbuild/for_helpers.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from __future__ import annotations
99

10+
import contextlib
1011
from typing import Callable, ClassVar
1112

1213
from mypy.nodes import (
@@ -68,6 +69,7 @@
6869
short_int_rprimitive,
6970
)
7071
from mypyc.irbuild.builder import IRBuilder
72+
from mypyc.irbuild.constant_fold import constant_fold_expr
7173
from mypyc.irbuild.prepare import GENERATOR_HELPER_NAME
7274
from mypyc.irbuild.targets import AssignmentTarget, AssignmentTargetTuple
7375
from mypyc.primitives.dict_ops import (
@@ -1204,18 +1206,18 @@ def gen_cleanup(self) -> None:
12041206
gen.gen_cleanup()
12051207

12061208

1207-
def get_expr_length(expr: Expression) -> int | None:
1209+
def get_expr_length(builder: IRBuilder, expr: Expression) -> int | None:
12081210
if isinstance(expr, (StrExpr, BytesExpr)):
12091211
return len(expr.value)
12101212
elif isinstance(expr, (ListExpr, TupleExpr)):
12111213
# if there are no star expressions, or we know the length of them,
12121214
# we know the length of the expression
1213-
stars = [get_expr_length(i) for i in expr.items if isinstance(i, StarExpr)]
1215+
stars = [get_expr_length(builder, i) for i in expr.items if isinstance(i, StarExpr)]
12141216
if None not in stars:
12151217
other = sum(not isinstance(i, StarExpr) for i in expr.items)
12161218
return other + sum(stars) # type: ignore [arg-type]
12171219
elif isinstance(expr, StarExpr):
1218-
return get_expr_length(expr.expr)
1220+
return get_expr_length(builder, expr.expr)
12191221
elif (
12201222
isinstance(expr, RefExpr)
12211223
and isinstance(expr.node, Var)
@@ -1241,19 +1243,20 @@ def get_expr_length(expr: Expression) -> int | None:
12411243
)
12421244
and len(expr.args) == 1
12431245
):
1244-
return get_expr_length(expr.args[0])
1246+
return get_expr_length(builder, expr.args[0])
12451247
elif fullname == "builtins.map" and len(expr.args) == 2:
1246-
return get_expr_length(expr.args[1])
1248+
return get_expr_length(builder, expr.args[1])
12471249
elif fullname == "builtins.zip" and expr.args:
1248-
arg_lengths = [get_expr_length(arg) for arg in expr.args]
1250+
arg_lengths = [get_expr_length(builder, arg) for arg in expr.args]
12491251
if all(arg is not None for arg in arg_lengths):
12501252
return min(arg_lengths) # type: ignore [type-var]
1251-
elif (
1252-
fullname == "builtins.range"
1253-
and len(expr.args) <= 3
1254-
and all(isinstance(arg, IntExpr) for arg in expr.args)
1255-
):
1256-
return len(range(*(arg.value for arg in expr.args))) # type: ignore [attr-defined]
1253+
elif fullname == "builtins.range" and len(expr.args) <= 3:
1254+
folded_args = [constant_fold_expr(builder, arg) for arg in args]
1255+
if all(isinstance(arg, int) for arg in folded_args):
1256+
try:
1257+
return len(range(*folded_args))
1258+
except ValueError: # prevent crash if invalid args
1259+
pass
12571260

12581261
# TODO: extend this, passing length of listcomp and genexp should have worthwhile
12591262
# performance boost and can be (sometimes) figured out pretty easily. set and dict

0 commit comments

Comments
 (0)