| 
1 | 1 | import ast  | 
 | 2 | + | 
 | 3 | + | 
2 | 4 | class PytestRaisesRemover(ast.NodeTransformer):  | 
3 | 5 |     """Replaces 'with pytest.raises()' blocks with the content inside them."""  | 
4 | 6 | 
 
  | 
5 | 7 |     def visit_With(self, node: ast.With) -> ast.AST | list[ast.AST]:  | 
6 |  | -        # Process any nested with blocks first by recursively visiting children  | 
7 |  | -        node = self.generic_visit(node)  | 
8 |  | - | 
 | 8 | +        # Directly visit children and check if they are nested with blocks  | 
9 | 9 |         for item in node.items:  | 
10 |  | -            # Check if this is a pytest.raises block  | 
11 |  | -            if (isinstance(item.context_expr, ast.Call) and  | 
12 |  | -                    isinstance(item.context_expr.func, ast.Attribute) and  | 
13 |  | -                    isinstance(item.context_expr.func.value, ast.Name) and  | 
14 |  | -                    item.context_expr.func.value.id == "pytest" and  | 
15 |  | -                    item.context_expr.func.attr == "raises"):  | 
16 |  | - | 
 | 10 | +            if (  | 
 | 11 | +                isinstance(item.context_expr, ast.Call)  | 
 | 12 | +                and isinstance(item.context_expr.func, ast.Attribute)  | 
 | 13 | +                and isinstance(item.context_expr.func.value, ast.Name)  | 
 | 14 | +                and item.context_expr.func.value.id == "pytest"  | 
 | 15 | +                and item.context_expr.func.attr == "raises"  | 
 | 16 | +            ):  | 
17 | 17 |                 # Return the body contents instead of the with block  | 
18 |  | -                # If there's multiple statements in the body, return them all  | 
19 |  | -                if len(node.body) == 1:  | 
20 |  | -                    return node.body[0]  | 
21 |  | -                return node.body  | 
 | 18 | +                return self._unwrap_body(node.body)  | 
 | 19 | + | 
 | 20 | +        # Generic visit for other types of 'with' blocks  | 
 | 21 | +        return self.generic_visit(node)  | 
22 | 22 | 
 
  | 
23 |  | -        return node  | 
 | 23 | +    def _unwrap_body(self, body: list[ast.stmt]) -> ast.AST | list[ast.AST]:  | 
 | 24 | +        # Unwrap the body either as a single statement or a list of statements  | 
 | 25 | +        if len(body) == 1:  | 
 | 26 | +            return body[0]  | 
 | 27 | +        return body  | 
24 | 28 | 
 
  | 
25 | 29 | 
 
  | 
26 | 30 | def remove_pytest_raises(tree: ast.AST) -> ast.AST:  | 
 | 
0 commit comments