7
7
8
8
from __future__ import annotations
9
9
10
+ import contextlib
10
11
from typing import Callable , ClassVar
11
12
12
13
from mypy .nodes import (
68
69
short_int_rprimitive ,
69
70
)
70
71
from mypyc .irbuild .builder import IRBuilder
72
+ from mypyc .irbuild .constant_fold import constant_fold_expr
71
73
from mypyc .irbuild .prepare import GENERATOR_HELPER_NAME
72
74
from mypyc .irbuild .targets import AssignmentTarget , AssignmentTargetTuple
73
75
from mypyc .primitives .dict_ops import (
@@ -1204,18 +1206,18 @@ def gen_cleanup(self) -> None:
1204
1206
gen .gen_cleanup ()
1205
1207
1206
1208
1207
- def get_expr_length (expr : Expression ) -> int | None :
1209
+ def get_expr_length (builder : IRBuilder , expr : Expression ) -> int | None :
1208
1210
if isinstance (expr , (StrExpr , BytesExpr )):
1209
1211
return len (expr .value )
1210
1212
elif isinstance (expr , (ListExpr , TupleExpr )):
1211
1213
# if there are no star expressions, or we know the length of them,
1212
1214
# we know the length of the expression
1213
- stars = [get_expr_length (i ) for i in expr .items if isinstance (i , StarExpr )]
1215
+ stars = [get_expr_length (builder , i ) for i in expr .items if isinstance (i , StarExpr )]
1214
1216
if None not in stars :
1215
1217
other = sum (not isinstance (i , StarExpr ) for i in expr .items )
1216
1218
return other + sum (stars ) # type: ignore [arg-type]
1217
1219
elif isinstance (expr , StarExpr ):
1218
- return get_expr_length (expr .expr )
1220
+ return get_expr_length (builder , expr .expr )
1219
1221
elif (
1220
1222
isinstance (expr , RefExpr )
1221
1223
and isinstance (expr .node , Var )
@@ -1241,19 +1243,20 @@ def get_expr_length(expr: Expression) -> int | None:
1241
1243
)
1242
1244
and len (expr .args ) == 1
1243
1245
):
1244
- return get_expr_length (expr .args [0 ])
1246
+ return get_expr_length (builder , expr .args [0 ])
1245
1247
elif fullname == "builtins.map" and len (expr .args ) == 2 :
1246
- return get_expr_length (expr .args [1 ])
1248
+ return get_expr_length (builder , expr .args [1 ])
1247
1249
elif fullname == "builtins.zip" and expr .args :
1248
- arg_lengths = [get_expr_length (arg ) for arg in expr .args ]
1250
+ arg_lengths = [get_expr_length (builder , arg ) for arg in expr .args ]
1249
1251
if all (arg is not None for arg in arg_lengths ):
1250
1252
return min (arg_lengths ) # type: ignore [type-var]
1251
- elif (
1252
- fullname == "builtins.range"
1253
- and len (expr .args ) <= 3
1254
- and all (isinstance (arg , IntExpr ) for arg in expr .args )
1255
- ):
1256
- return len (range (* (arg .value for arg in expr .args ))) # type: ignore [attr-defined]
1253
+ elif fullname == "builtins.range" and len (expr .args ) <= 3 :
1254
+ folded_args = [constant_fold_expr (builder , arg ) for arg in args ]
1255
+ if all (isinstance (arg , int ) for arg in folded_args ):
1256
+ try :
1257
+ return len (range (* folded_args ))
1258
+ except ValueError : # prevent crash if invalid args
1259
+ pass
1257
1260
1258
1261
# TODO: extend this, passing length of listcomp and genexp should have worthwhile
1259
1262
# performance boost and can be (sometimes) figured out pretty easily. set and dict
0 commit comments