Skip to content

Commit ef7a432

Browse files
[refactor to use match] AssertionRewriter.visit_Compare()
1 parent 2499ce2 commit ef7a432

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

src/_pytest/assertion/rewrite.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,12 +1112,13 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
11121112
def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
11131113
self.push_format_context()
11141114
# We first check if we have overwritten a variable in the previous assert
1115-
if isinstance(
1116-
comp.left, ast.Name
1117-
) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
1118-
comp.left = self.variables_overwrite[self.scope][comp.left.id] # type:ignore[assignment]
1119-
if isinstance(comp.left, ast.NamedExpr):
1120-
self.variables_overwrite[self.scope][comp.left.target.id] = comp.left # type:ignore[assignment]
1115+
match comp.left:
1116+
case ast.Name(id=name_id) if name_id in self.variables_overwrite.get(
1117+
self.scope, {}
1118+
):
1119+
comp.left = self.variables_overwrite[self.scope][name_id] # type: ignore[assignment]
1120+
case ast.NamedExpr(target=ast.Name(id=target_id)):
1121+
self.variables_overwrite[self.scope][target_id] = comp.left # type: ignore[assignment]
11211122
left_res, left_expl = self.visit(comp.left)
11221123
if isinstance(comp.left, ast.Compare | ast.BoolOp):
11231124
left_expl = f"({left_expl})"
@@ -1129,13 +1130,14 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
11291130
syms: list[ast.expr] = []
11301131
results = [left_res]
11311132
for i, op, next_operand in it:
1132-
if (
1133-
isinstance(next_operand, ast.NamedExpr)
1134-
and isinstance(left_res, ast.Name)
1135-
and next_operand.target.id == left_res.id
1136-
):
1137-
next_operand.target.id = self.variable()
1138-
self.variables_overwrite[self.scope][left_res.id] = next_operand # type:ignore[assignment]
1133+
match (next_operand, left_res):
1134+
case (
1135+
ast.NamedExpr(target=ast.Name(id=target_id)),
1136+
ast.Name(id=name_id),
1137+
) if target_id == name_id:
1138+
next_operand.target.id = self.variable()
1139+
self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment]
1140+
11391141
next_res, next_expl = self.visit(next_operand)
11401142
if isinstance(next_operand, ast.Compare | ast.BoolOp):
11411143
next_expl = f"({next_expl})"

0 commit comments

Comments
 (0)