@@ -3201,8 +3201,8 @@ class ASTOptimiziationTests(unittest.TestCase):
32013201 def wrap_expr (self , expr ):
32023202 return ast .Module (body = [ast .Expr (value = expr )])
32033203
3204- def wrap_for (self , for_statement ):
3205- return ast .Module (body = [for_statement ])
3204+ def wrap_statement (self , statement ):
3205+ return ast .Module (body = [statement ])
32063206
32073207 def assert_ast (self , code , non_optimized_target , optimized_target ):
32083208
@@ -3230,16 +3230,16 @@ def assert_ast(self, code, non_optimized_target, optimized_target):
32303230 f"{ ast .dump (optimized_tree )} " ,
32313231 )
32323232
3233+ def create_binop (self , operand , left = ast .Constant (1 ), right = ast .Constant (1 )):
3234+ return ast .BinOp (left = left , op = self .binop [operand ], right = right )
3235+
32333236 def test_folding_binop (self ):
32343237 code = "1 %s 1"
32353238 operators = self .binop .keys ()
32363239
3237- def create_binop (operand , left = ast .Constant (1 ), right = ast .Constant (1 )):
3238- return ast .BinOp (left = left , op = self .binop [operand ], right = right )
3239-
32403240 for op in operators :
32413241 result_code = code % op
3242- non_optimized_target = self .wrap_expr (create_binop (op ))
3242+ non_optimized_target = self .wrap_expr (self . create_binop (op ))
32433243 optimized_target = self .wrap_expr (ast .Constant (value = eval (result_code )))
32443244
32453245 with self .subTest (
@@ -3251,7 +3251,7 @@ def create_binop(operand, left=ast.Constant(1), right=ast.Constant(1)):
32513251
32523252 # Multiplication of constant tuples must be folded
32533253 code = "(1,) * 3"
3254- non_optimized_target = self .wrap_expr (create_binop ("*" , ast .Tuple (elts = [ast .Constant (value = 1 )]), ast .Constant (value = 3 )))
3254+ non_optimized_target = self .wrap_expr (self . create_binop ("*" , ast .Tuple (elts = [ast .Constant (value = 1 )]), ast .Constant (value = 3 )))
32553255 optimized_target = self .wrap_expr (ast .Constant (eval (code )))
32563256
32573257 self .assert_ast (code , non_optimized_target , optimized_target )
@@ -3362,12 +3362,12 @@ def test_folding_iter(self):
33623362 ]
33633363
33643364 for left , right , ast_cls , optimized_iter in braces :
3365- non_optimized_target = self .wrap_for (ast .For (
3365+ non_optimized_target = self .wrap_statement (ast .For (
33663366 target = ast .Name (id = "_" , ctx = ast .Store ()),
33673367 iter = ast_cls (elts = [ast .Constant (1 )]),
33683368 body = [ast .Pass ()]
33693369 ))
3370- optimized_target = self .wrap_for (ast .For (
3370+ optimized_target = self .wrap_statement (ast .For (
33713371 target = ast .Name (id = "_" , ctx = ast .Store ()),
33723372 iter = ast .Constant (value = optimized_iter ),
33733373 body = [ast .Pass ()]
@@ -3385,6 +3385,92 @@ def test_folding_subscript(self):
33853385
33863386 self .assert_ast (code , non_optimized_target , optimized_target )
33873387
3388+ def test_folding_type_param_in_function_def (self ):
3389+ code = "def foo[%s = 1 + 1](): pass"
3390+
3391+ unoptimized_binop = self .create_binop ("+" )
3392+ unoptimized_type_params = [
3393+ ("T" , "T" , ast .TypeVar ),
3394+ ("**P" , "P" , ast .ParamSpec ),
3395+ ("*Ts" , "Ts" , ast .TypeVarTuple ),
3396+ ]
3397+
3398+ for type , name , type_param in unoptimized_type_params :
3399+ result_code = code % type
3400+ optimized_target = self .wrap_statement (
3401+ ast .FunctionDef (
3402+ name = 'foo' ,
3403+ args = ast .arguments (),
3404+ body = [ast .Pass ()],
3405+ type_params = [type_param (name = name , default_value = ast .Constant (2 ))]
3406+ )
3407+ )
3408+ non_optimized_target = self .wrap_statement (
3409+ ast .FunctionDef (
3410+ name = 'foo' ,
3411+ args = ast .arguments (),
3412+ body = [ast .Pass ()],
3413+ type_params = [type_param (name = name , default_value = unoptimized_binop )]
3414+ )
3415+ )
3416+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3417+
3418+ def test_folding_type_param_in_class_def (self ):
3419+ code = "class foo[%s = 1 + 1]: pass"
3420+
3421+ unoptimized_binop = self .create_binop ("+" )
3422+ unoptimized_type_params = [
3423+ ("T" , "T" , ast .TypeVar ),
3424+ ("**P" , "P" , ast .ParamSpec ),
3425+ ("*Ts" , "Ts" , ast .TypeVarTuple ),
3426+ ]
3427+
3428+ for type , name , type_param in unoptimized_type_params :
3429+ result_code = code % type
3430+ optimized_target = self .wrap_statement (
3431+ ast .ClassDef (
3432+ name = 'foo' ,
3433+ body = [ast .Pass ()],
3434+ type_params = [type_param (name = name , default_value = ast .Constant (2 ))]
3435+ )
3436+ )
3437+ non_optimized_target = self .wrap_statement (
3438+ ast .ClassDef (
3439+ name = 'foo' ,
3440+ body = [ast .Pass ()],
3441+ type_params = [type_param (name = name , default_value = unoptimized_binop )]
3442+ )
3443+ )
3444+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3445+
3446+ def test_folding_type_param_in_type_alias (self ):
3447+ code = "type foo[%s = 1 + 1] = 1"
3448+
3449+ unoptimized_binop = self .create_binop ("+" )
3450+ unoptimized_type_params = [
3451+ ("T" , "T" , ast .TypeVar ),
3452+ ("**P" , "P" , ast .ParamSpec ),
3453+ ("*Ts" , "Ts" , ast .TypeVarTuple ),
3454+ ]
3455+
3456+ for type , name , type_param in unoptimized_type_params :
3457+ result_code = code % type
3458+ optimized_target = self .wrap_statement (
3459+ ast .TypeAlias (
3460+ name = ast .Name (id = 'foo' , ctx = ast .Store ()),
3461+ type_params = [type_param (name = name , default_value = ast .Constant (2 ))],
3462+ value = ast .Constant (value = 1 ),
3463+ )
3464+ )
3465+ non_optimized_target = self .wrap_statement (
3466+ ast .TypeAlias (
3467+ name = ast .Name (id = 'foo' , ctx = ast .Store ()),
3468+ type_params = [type_param (name = name , default_value = unoptimized_binop )],
3469+ value = ast .Constant (value = 1 ),
3470+ )
3471+ )
3472+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3473+
33883474
33893475if __name__ == "__main__" :
33903476 unittest .main ()
0 commit comments