33from typing import List , Optional , Tuple
44
55import libcst as cst
6- from libcst import matchers as m
76from libcst ._position import CodeRange
87from libcst .metadata import ParentNodeProvider , ScopeProvider
98
@@ -88,8 +87,9 @@ def on_visit(self, node: cst.CSTNode) -> Optional[bool]:
8887 continue
8988
9089 assign , target , value = found_assign
91- # If test can be a comparison expression
90+
9291 match if_test :
92+ # If test can be a comparison expression
9393 case cst .Comparison (
9494 left = cst .Name () as left ,
9595 comparisons = [
@@ -100,15 +100,21 @@ def on_visit(self, node: cst.CSTNode) -> Optional[bool]:
100100 )
101101 ],
102102 ):
103+
103104 if left .value == target .value :
104105 named_expr = self ._build_named_expr (target , value , parens = True )
105106 self .assigns [assign ] = named_expr
106- # If test can also be a bare name
107107 case cst .Name () as name :
108+ # If test can also be a bare name
108109 if name .value == target .value :
109110 named_expr = self ._build_named_expr (target , value , parens = False )
110111 self .assigns [assign ] = named_expr
111-
112+ case cst .UnaryOperation (
113+ operator = cst .Not (), expression = cst .Name () as name
114+ ):
115+ if name .value == target .value :
116+ named_expr = self ._build_named_expr (target , value , parens = True )
117+ self .assigns [assign ] = named_expr
112118 return super ().on_visit (node )
113119
114120 def visit_If (self , node : cst .If ):
@@ -120,17 +126,21 @@ def visit_If(self, node: cst.If):
120126 def leave_If (self , original_node , updated_node ):
121127 # TODO: add filter by include or exclude that works for nodes
122128 # that that have different start/end numbers.
129+
123130 if (result := self ._if_stack .pop ()) is not None :
124131 position , named_expr = result
125- is_name = m .matches (updated_node .test , m .Name ())
126132 self .add_change_from_position (position , self .change_description )
127- return (
128- updated_node .with_changes (test = named_expr )
129- if is_name
130- else updated_node .with_changes (
131- test = updated_node .test .with_changes (left = named_expr )
132- )
133- )
133+ match updated_node .test :
134+ case cst .Name ():
135+ return updated_node .with_changes (test = named_expr )
136+ case cst .UnaryOperation ():
137+ return updated_node .with_changes (
138+ test = updated_node .test .with_changes (expression = named_expr )
139+ )
140+ case _:
141+ return updated_node .with_changes (
142+ test = updated_node .test .with_changes (left = named_expr )
143+ )
134144
135145 return original_node
136146
0 commit comments