|
8 | 8 | import isort |
9 | 9 | import libcst as cst |
10 | 10 | import libcst.matchers as m |
| 11 | +from libcst.metadata import PositionProvider |
11 | 12 |
|
12 | 13 | from codeflash.cli_cmds.console import logger |
13 | 14 | from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module |
@@ -37,6 +38,34 @@ def normalize_code(code: str) -> str: |
37 | 38 | return ast.unparse(normalize_node(ast.parse(code))) |
38 | 39 |
|
39 | 40 |
|
| 41 | +class AddRequestArgument(cst.CSTTransformer): |
| 42 | + METADATA_DEPENDENCIES = (PositionProvider,) |
| 43 | + |
| 44 | + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002 |
| 45 | + args = updated_node.params.params |
| 46 | + arg_names = {arg.name.value for arg in args} |
| 47 | + |
| 48 | + # Skip if 'request' is already present |
| 49 | + if "request" in arg_names: |
| 50 | + return updated_node |
| 51 | + |
| 52 | + # Create a new 'request' param |
| 53 | + request_param = cst.Param(name=cst.Name("request")) |
| 54 | + |
| 55 | + # Add 'request' as the first argument (after 'self' or 'cls' if needed) |
| 56 | + if args: |
| 57 | + first_arg = args[0].name.value |
| 58 | + if first_arg in {"self", "cls"}: |
| 59 | + new_params = [args[0], request_param] + list(args[1:]) # noqa: RUF005 |
| 60 | + else: |
| 61 | + new_params = [request_param] + list(args) # noqa: RUF005 |
| 62 | + else: |
| 63 | + new_params = [request_param] |
| 64 | + |
| 65 | + new_param_list = updated_node.params.with_changes(params=new_params) |
| 66 | + return updated_node.with_changes(params=new_param_list) |
| 67 | + |
| 68 | + |
40 | 69 | class PytestMarkAdder(cst.CSTTransformer): |
41 | 70 | """Transformer that adds pytest marks to test functions.""" |
42 | 71 |
|
@@ -139,8 +168,10 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu |
139 | 168 | def disable_autouse(test_path: Path) -> str: |
140 | 169 | file_content = test_path.read_text(encoding="utf-8") |
141 | 170 | module = cst.parse_module(file_content) |
| 171 | + add_request_argument = AddRequestArgument() |
142 | 172 | disable_autouse_fixture = AutouseFixtureModifier() |
143 | | - modified_module = module.visit(disable_autouse_fixture) |
| 173 | + modified_module = module.visit(add_request_argument) |
| 174 | + modified_module = modified_module.visit(disable_autouse_fixture) |
144 | 175 | test_path.write_text(modified_module.code, encoding="utf-8") |
145 | 176 | return file_content |
146 | 177 |
|
|
0 commit comments