Skip to content

Commit db75b81

Browse files
authored
walrus if codemod can handle if not operations (#452)
1 parent 72899fb commit db75b81

File tree

2 files changed

+34
-12
lines changed

2 files changed

+34
-12
lines changed

src/core_codemods/use_walrus_if.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import List, Optional, Tuple
44

55
import libcst as cst
6-
from libcst import matchers as m
76
from libcst._position import CodeRange
87
from libcst.metadata import ParentNodeProvider, ScopeProvider
98

@@ -88,8 +87,9 @@ def on_visit(self, node: cst.CSTNode) -> Optional[bool]:
8887
continue
8988

9089
assign, target, value = found_assign
91-
# If test can be a comparison expression
90+
9291
match if_test:
92+
# If test can be a comparison expression
9393
case cst.Comparison(
9494
left=cst.Name() as left,
9595
comparisons=[
@@ -100,15 +100,21 @@ def on_visit(self, node: cst.CSTNode) -> Optional[bool]:
100100
)
101101
],
102102
):
103+
103104
if left.value == target.value:
104105
named_expr = self._build_named_expr(target, value, parens=True)
105106
self.assigns[assign] = named_expr
106-
# If test can also be a bare name
107107
case cst.Name() as name:
108+
# If test can also be a bare name
108109
if name.value == target.value:
109110
named_expr = self._build_named_expr(target, value, parens=False)
110111
self.assigns[assign] = named_expr
111-
112+
case cst.UnaryOperation(
113+
operator=cst.Not(), expression=cst.Name() as name
114+
):
115+
if name.value == target.value:
116+
named_expr = self._build_named_expr(target, value, parens=True)
117+
self.assigns[assign] = named_expr
112118
return super().on_visit(node)
113119

114120
def visit_If(self, node: cst.If):
@@ -120,17 +126,21 @@ def visit_If(self, node: cst.If):
120126
def leave_If(self, original_node, updated_node):
121127
# TODO: add filter by include or exclude that works for nodes
122128
# that that have different start/end numbers.
129+
123130
if (result := self._if_stack.pop()) is not None:
124131
position, named_expr = result
125-
is_name = m.matches(updated_node.test, m.Name())
126132
self.add_change_from_position(position, self.change_description)
127-
return (
128-
updated_node.with_changes(test=named_expr)
129-
if is_name
130-
else updated_node.with_changes(
131-
test=updated_node.test.with_changes(left=named_expr)
132-
)
133-
)
133+
match updated_node.test:
134+
case cst.Name():
135+
return updated_node.with_changes(test=named_expr)
136+
case cst.UnaryOperation():
137+
return updated_node.with_changes(
138+
test=updated_node.test.with_changes(expression=named_expr)
139+
)
140+
case _:
141+
return updated_node.with_changes(
142+
test=updated_node.test.with_changes(left=named_expr)
143+
)
134144

135145
return original_node
136146

tests/codemods/test_walrus_if.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,18 @@ def test_walrus_if_name_only(self, tmpdir):
4040
"""
4141
self.run_and_assert(tmpdir, input_code, expected_output)
4242

43+
def test_walrus_if_not(self, tmpdir):
44+
input_code = """
45+
val = do_something()
46+
if not val:
47+
do_something_else(val)
48+
"""
49+
expected_output = """
50+
if not (val := do_something()):
51+
do_something_else(val)
52+
"""
53+
self.run_and_assert(tmpdir, input_code, expected_output)
54+
4355
def test_walrus_if_preserve_comments(self, tmpdir):
4456
input_code = """
4557
val = do_something() # comment

0 commit comments

Comments
 (0)