@@ -3227,115 +3227,65 @@ def test_folding_type_param_in_type_alias(self):
32273227            self .assert_ast (result_code , non_optimized_target , optimized_target )
32283228
32293229    def  test_folding_match_case_allowed_expressions (self ):
3230-         source  =  textwrap .dedent (""" 
3231-         match 0: 
3232-             case -0:                                   pass 
3233-             case -0.1:                                 pass 
3234-             case -0j:                                  pass 
3235-             case -0.1j:                                pass 
3236-             case 1 + 2j:                               pass 
3237-             case 1 - 2j:                               pass 
3238-             case 1.1 + 2.1j:                           pass 
3239-             case 1.1 - 2.1j:                           pass 
3240-             case -0 + 1j:                              pass 
3241-             case -0 - 1j:                              pass 
3242-             case -0.1 + 1.1j:                          pass 
3243-             case -0.1 - 1.1j:                          pass 
3244-             case {-0: 0}:                              pass 
3245-             case {-0.1: 0}:                            pass 
3246-             case {-0j: 0}:                             pass 
3247-             case {-0.1j: 0}:                           pass 
3248-             case {1 + 2j: 0}:                          pass 
3249-             case {1 - 2j: 0}:                          pass 
3250-             case {1.1 + 2.1j: 0}:                      pass 
3251-             case {1.1 - 2.1j: 0}:                      pass 
3252-             case {-0 + 1j: 0}:                         pass 
3253-             case {-0 - 1j: 0}:                         pass 
3254-             case {-0.1 + 1.1j: 0}:                     pass 
3255-             case {-0.1 - 1.1j: 0}:                     pass 
3256-             case {-0: 0, 0 + 1j: 0, 0.1 + 1j: 0}:      pass 
3257-             case [-0, -0.1, -0j, -0.1j]:               pass 
3258-             case (-0, -0.1, -0j, -0.1j):               pass 
3259-             case [[-0, -0.1], [-0j, -0.1j]]:           pass 
3260-             case ((-0, -0.1), (-0j, -0.1j)):           pass 
3261-         """ )
3262-         expected_constants  =  (
3263-             0 ,
3264-             - 0.1 ,
3265-             complex (0 , - 0 ),
3266-             complex (0 , - 0.1 ),
3267-             complex (1 , 2 ),
3268-             complex (1 , - 2 ),
3269-             complex (1.1 , 2.1 ),
3270-             complex (1.1 , - 2.1 ),
3271-             complex (- 0 , 1 ),
3272-             complex (- 0 , - 1 ),
3273-             complex (- 0.1 , 1.1 ),
3274-             complex (- 0.1 , - 1.1 ),
3275-             (0 , ),
3276-             (- 0.1 , ),
3277-             (complex (0 , - 0 ), ),
3278-             (complex (0 , - 0.1 ), ),
3279-             (complex (1 , 2 ), ),
3280-             (complex (1 , - 2 ), ),
3281-             (complex (1.1 , 2.1 ), ),
3282-             (complex (1.1 , - 2.1 ), ),
3283-             (complex (- 0 , 1 ), ),
3284-             (complex (- 0 , - 1 ), ),
3285-             (complex (- 0.1 , 1.1 ), ),
3286-             (complex (- 0.1 , - 1.1 ), ),
3287-             (0 , complex (0 , 1 ), complex (0.1 , 1 )),
3288-             (
3289-                 0 ,
3290-                 - 0.1 ,
3291-                 complex (0 , - 0 ),
3292-                 complex (0 , - 0.1 ),
3293-             ),
3294-             (
3295-                 0 ,
3296-                 - 0.1 ,
3297-                 complex (0 , - 0 ),
3298-                 complex (0 , - 0.1 ),
3299-             ),
3300-             (
3301-                 0 ,
3302-                 - 0.1 ,
3303-                 complex (0 , - 0 ),
3304-                 complex (0 , - 0.1 ),
3305-             ),
3306-             (
3307-                 0 ,
3308-                 - 0.1 ,
3309-                 complex (0 , - 0 ),
3310-                 complex (0 , - 0.1 ),
3311-             )
3312-         )
3313-         consts  =  iter (expected_constants )
3314-         tree  =  ast .parse (source , optimize = 1 )
3315-         match_stmt  =  tree .body [0 ]
3316-         for  case  in  match_stmt .cases :
3317-             pattern  =  case .pattern 
3318-             if  isinstance (pattern , ast .MatchValue ):
3319-                 self .assertIsInstance (pattern .value , ast .Constant )
3320-                 self .assertEqual (pattern .value .value , next (consts ))
3321-             elif  isinstance (pattern , ast .MatchMapping ):
3322-                 keys  =  iter (next (consts ))
3323-                 for  key  in  pattern .keys :
3324-                     self .assertIsInstance (key , ast .Constant )
3325-                     self .assertEqual (key .value , next (keys ))
3326-             elif  isinstance (pattern , ast .MatchSequence ):
3327-                 values  =  iter (next (consts ))
3328-                 for  pat  in  pattern .patterns :
3329-                     if  isinstance (pat , ast .MatchValue ):
3330-                         self .assertEqual (pat .value .value , next (values ))
3331-                     elif  isinstance (pat , ast .MatchSequence ):
3332-                         for  p  in  pat .patterns :
3333-                             self .assertIsInstance (p , ast .MatchValue )
3334-                             self .assertEqual (p .value .value , next (values ))
3335-                     else :
3336-                         self .fail (f"Expected ast.MatchValue or ast.MatchSequence, found: { type (pat )}  )
3230+         def  get_match_case_values (node ):
3231+             result  =  []
3232+             if  isinstance (node , ast .Constant ):
3233+                 result .append (node .value )
3234+             elif  isinstance (node , ast .MatchValue ):
3235+                 result .extend (get_match_case_values (node .value ))
3236+             elif  isinstance (node , ast .MatchMapping ):
3237+                 for  key  in  node .keys :
3238+                     result .extend (get_match_case_values (key ))
3239+             elif  isinstance (node , ast .MatchSequence ):
3240+                 for  pat  in  node .patterns :
3241+                     result .extend (get_match_case_values (pat ))
33373242            else :
3338-                 self .fail (f"Expected ast.MatchValue or ast.MatchMapping, found: { type (pattern )}  )
3243+                 self .fail (f"Unexpected node { node }  )
3244+             return  result 
3245+ 
3246+         tests  =  [
3247+             ("-0" , [0 ]),
3248+             ("-0.1" , [- 0.1 ]),
3249+             ("-0j" , [complex (0 , 0 )]),
3250+             ("-0.1j" , [complex (0 , - 0.1 )]),
3251+             ("1 + 2j" , [complex (1 , 2 )]),
3252+             ("1 - 2j" , [complex (1 , - 2 )]),
3253+             ("1.1 + 2.1j" , [complex (1.1 , 2.1 )]),
3254+             ("1.1 - 2.1j" , [complex (1.1 , - 2.1 )]),
3255+             ("-0 + 1j" , [complex (0 , 1 )]),
3256+             ("-0 - 1j" , [complex (0 , - 1 )]),
3257+             ("-0.1 + 1.1j" , [complex (- 0.1 , 1.1 )]),
3258+             ("-0.1 - 1.1j" , [complex (- 0.1 , - 1.1 )]),
3259+             ("{-0: 0}" , [0 ]),
3260+             ("{-0.1: 0}" , [- 0.1 ]),
3261+             ("{-0j: 0}" , [complex (0 , 0 )]),
3262+             ("{-0.1j: 0}" , [complex (0 , - 0.1 )]),
3263+             ("{1 + 2j: 0}" , [complex (1 , 2 )]),
3264+             ("{1 - 2j: 0}" , [complex (1 , - 2 )]),
3265+             ("{1.1 + 2.1j: 0}" , [complex (1.1 , 2.1 )]),
3266+             ("{1.1 - 2.1j: 0}" , [complex (1.1 , - 2.1 )]),
3267+             ("{-0 + 1j: 0}" , [complex (0 , 1 )]),
3268+             ("{-0 - 1j: 0}" , [complex (0 , - 1 )]),
3269+             ("{-0.1 + 1.1j: 0}" , [complex (- 0.1 , 1.1 )]),
3270+             ("{-0.1 - 1.1j: 0}" , [complex (- 0.1 , - 1.1 )]),
3271+             ("{-0: 0, 0 + 1j: 0, 0.1 + 1j: 0}" , [0 , complex (0 , 1 ), complex (0.1 , 1 )]),
3272+             ("[-0, -0.1, -0j, -0.1j]" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3273+             ("[[[[-0, -0.1, -0j, -0.1j]]]]" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3274+             ("[[-0, -0.1], -0j, -0.1j]" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3275+             ("[[-0, -0.1], [-0j, -0.1j]]" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3276+             ("(-0, -0.1, -0j, -0.1j)" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3277+             ("((((-0, -0.1, -0j, -0.1j))))" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3278+             ("((-0, -0.1), -0j, -0.1j)" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3279+             ("((-0, -0.1), (-0j, -0.1j))" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3280+         ]
3281+         for  match_expr , constants  in  tests :
3282+             with  self .subTest (match_expr ):
3283+                 src  =  f"match 0:\n \t  case { match_expr }  
3284+                 tree  =  ast .parse (src , optimize = 1 )
3285+                 match_stmt  =  tree .body [0 ]
3286+                 case  =  match_stmt .cases [0 ]
3287+                 values  =  get_match_case_values (case .pattern )
3288+                 self .assertListEqual (constants , values )
33393289
33403290
33413291if  __name__  ==  '__main__' :
0 commit comments