@@ -3062,8 +3062,8 @@ class ASTOptimiziationTests(unittest.TestCase):
30623062 def wrap_expr (self , expr ):
30633063 return ast .Module (body = [ast .Expr (value = expr )])
30643064
3065- def wrap_for (self , for_statement ):
3066- return ast .Module (body = [for_statement ])
3065+ def wrap_statement (self , statement ):
3066+ return ast .Module (body = [statement ])
30673067
30683068 def assert_ast (self , code , non_optimized_target , optimized_target ):
30693069 non_optimized_tree = ast .parse (code , optimize = - 1 )
@@ -3090,16 +3090,16 @@ def assert_ast(self, code, non_optimized_target, optimized_target):
30903090 f"{ ast .dump (optimized_tree )} " ,
30913091 )
30923092
3093+ def create_binop (self , operand , left = ast .Constant (1 ), right = ast .Constant (1 )):
3094+ return ast .BinOp (left = left , op = self .binop [operand ], right = right )
3095+
30933096 def test_folding_binop (self ):
30943097 code = "1 %s 1"
30953098 operators = self .binop .keys ()
30963099
3097- def create_binop (operand , left = ast .Constant (1 ), right = ast .Constant (1 )):
3098- return ast .BinOp (left = left , op = self .binop [operand ], right = right )
3099-
31003100 for op in operators :
31013101 result_code = code % op
3102- non_optimized_target = self .wrap_expr (create_binop (op ))
3102+ non_optimized_target = self .wrap_expr (self . create_binop (op ))
31033103 optimized_target = self .wrap_expr (ast .Constant (value = eval (result_code )))
31043104
31053105 with self .subTest (
@@ -3111,7 +3111,7 @@ def create_binop(operand, left=ast.Constant(1), right=ast.Constant(1)):
31113111
31123112 # Multiplication of constant tuples must be folded
31133113 code = "(1,) * 3"
3114- non_optimized_target = self .wrap_expr (create_binop ("*" , ast .Tuple (elts = [ast .Constant (value = 1 )]), ast .Constant (value = 3 )))
3114+ non_optimized_target = self .wrap_expr (self . create_binop ("*" , ast .Tuple (elts = [ast .Constant (value = 1 )]), ast .Constant (value = 3 )))
31153115 optimized_target = self .wrap_expr (ast .Constant (eval (code )))
31163116
31173117 self .assert_ast (code , non_optimized_target , optimized_target )
@@ -3222,12 +3222,12 @@ def test_folding_iter(self):
32223222 ]
32233223
32243224 for left , right , ast_cls , optimized_iter in braces :
3225- non_optimized_target = self .wrap_for (ast .For (
3225+ non_optimized_target = self .wrap_statement (ast .For (
32263226 target = ast .Name (id = "_" , ctx = ast .Store ()),
32273227 iter = ast_cls (elts = [ast .Constant (1 )]),
32283228 body = [ast .Pass ()]
32293229 ))
3230- optimized_target = self .wrap_for (ast .For (
3230+ optimized_target = self .wrap_statement (ast .For (
32313231 target = ast .Name (id = "_" , ctx = ast .Store ()),
32323232 iter = ast .Constant (value = optimized_iter ),
32333233 body = [ast .Pass ()]
@@ -3245,6 +3245,92 @@ def test_folding_subscript(self):
32453245
32463246 self .assert_ast (code , non_optimized_target , optimized_target )
32473247
3248+ def test_folding_type_param_in_function_def (self ):
3249+ code = "def foo[%s = 1 + 1](): pass"
3250+
3251+ unoptimized_binop = self .create_binop ("+" )
3252+ unoptimized_type_params = [
3253+ ("T" , "T" , ast .TypeVar ),
3254+ ("**P" , "P" , ast .ParamSpec ),
3255+ ("*Ts" , "Ts" , ast .TypeVarTuple ),
3256+ ]
3257+
3258+ for type , name , type_param in unoptimized_type_params :
3259+ result_code = code % type
3260+ optimized_target = self .wrap_statement (
3261+ ast .FunctionDef (
3262+ name = 'foo' ,
3263+ args = ast .arguments (),
3264+ body = [ast .Pass ()],
3265+ type_params = [type_param (name = name , default_value = ast .Constant (2 ))]
3266+ )
3267+ )
3268+ non_optimized_target = self .wrap_statement (
3269+ ast .FunctionDef (
3270+ name = 'foo' ,
3271+ args = ast .arguments (),
3272+ body = [ast .Pass ()],
3273+ type_params = [type_param (name = name , default_value = unoptimized_binop )]
3274+ )
3275+ )
3276+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3277+
3278+ def test_folding_type_param_in_class_def (self ):
3279+ code = "class foo[%s = 1 + 1]: pass"
3280+
3281+ unoptimized_binop = self .create_binop ("+" )
3282+ unoptimized_type_params = [
3283+ ("T" , "T" , ast .TypeVar ),
3284+ ("**P" , "P" , ast .ParamSpec ),
3285+ ("*Ts" , "Ts" , ast .TypeVarTuple ),
3286+ ]
3287+
3288+ for type , name , type_param in unoptimized_type_params :
3289+ result_code = code % type
3290+ optimized_target = self .wrap_statement (
3291+ ast .ClassDef (
3292+ name = 'foo' ,
3293+ body = [ast .Pass ()],
3294+ type_params = [type_param (name = name , default_value = ast .Constant (2 ))]
3295+ )
3296+ )
3297+ non_optimized_target = self .wrap_statement (
3298+ ast .ClassDef (
3299+ name = 'foo' ,
3300+ body = [ast .Pass ()],
3301+ type_params = [type_param (name = name , default_value = unoptimized_binop )]
3302+ )
3303+ )
3304+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3305+
3306+ def test_folding_type_param_in_type_alias (self ):
3307+ code = "type foo[%s = 1 + 1] = 1"
3308+
3309+ unoptimized_binop = self .create_binop ("+" )
3310+ unoptimized_type_params = [
3311+ ("T" , "T" , ast .TypeVar ),
3312+ ("**P" , "P" , ast .ParamSpec ),
3313+ ("*Ts" , "Ts" , ast .TypeVarTuple ),
3314+ ]
3315+
3316+ for type , name , type_param in unoptimized_type_params :
3317+ result_code = code % type
3318+ optimized_target = self .wrap_statement (
3319+ ast .TypeAlias (
3320+ name = ast .Name (id = 'foo' , ctx = ast .Store ()),
3321+ type_params = [type_param (name = name , default_value = ast .Constant (2 ))],
3322+ value = ast .Constant (value = 1 ),
3323+ )
3324+ )
3325+ non_optimized_target = self .wrap_statement (
3326+ ast .TypeAlias (
3327+ name = ast .Name (id = 'foo' , ctx = ast .Store ()),
3328+ type_params = [type_param (name = name , default_value = unoptimized_binop )],
3329+ value = ast .Constant (value = 1 ),
3330+ )
3331+ )
3332+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3333+
32483334
32493335if __name__ == "__main__" :
32503336 unittest .main ()
0 commit comments