Skip to content

Commit 1142de5

Browse files
committed
linting / formatting
1 parent e119085 commit 1142de5

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

codeflash/code_utils/codeflash_wrap_decorator.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class VerificationType(str, Enum): # moved from codeflash/verification/codeflas
2929
F = TypeVar("F", bound=Callable[..., Any])
3030

3131

32-
def get_run_tmp_file(file_path: Path) -> Path: # moved from codeflash/code_utils/code_utils.py
32+
def get_run_tmp_file(file_path: Path) -> Path: # moved from codeflash/code_utils/code_utils.py
3333
if not hasattr(get_run_tmp_file, "tmpdir"):
3434
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
3535
return Path(get_run_tmp_file.tmpdir.name) / file_path
@@ -181,10 +181,14 @@ def extract_test_context_from_frame() -> tuple[str, str | None, str]:
181181

182182
# Framework integration detection
183183
if (
184-
function_name in ["runTest", "_runTest", "run", "_testMethodName"]
185-
or "pytest" in str(frame_obj.f_globals.get("__file__", ""))
186-
or "unittest" in str(frame_obj.f_globals.get("__file__", ""))
187-
) and class_name and (class_name.startswith("Test") or "test" in class_name.lower()):
184+
(
185+
function_name in ["runTest", "_runTest", "run", "_testMethodName"]
186+
or "pytest" in str(frame_obj.f_globals.get("__file__", ""))
187+
or "unittest" in str(frame_obj.f_globals.get("__file__", ""))
188+
)
189+
and class_name
190+
and (class_name.startswith("Test") or "test" in class_name.lower())
191+
):
188192
test_method = function_name
189193
if "self" in frame_locals:
190194
with contextlib.suppress(AttributeError, TypeError):

codeflash/code_utils/edit_generated_tests.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ def __init__(
3333
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
3434
self.context_stack.append(node.name)
3535
for inner_node in node.body:
36-
if isinstance(inner_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
36+
if isinstance(inner_node, ast.FunctionDef):
3737
self.visit_FunctionDef(inner_node)
38+
elif isinstance(inner_node, ast.AsyncFunctionDef):
39+
self.visit_AsyncFunctionDef(inner_node)
3840
self.context_stack.pop()
3941
return node
4042

@@ -50,14 +52,14 @@ def get_comment(self, match_key: str) -> str:
5052
return f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
5153

5254
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
53-
return self._process_function_def(node)
55+
self._process_function_def_common(node)
56+
return node
5457

5558
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
56-
return self._process_function_def(node)
59+
self._process_function_def_common(node)
60+
return node
5761

58-
def _process_function_def(
59-
self, node: ast.FunctionDef | ast.AsyncFunctionDef
60-
) -> ast.FunctionDef | ast.AsyncFunctionDef:
62+
def _process_function_def_common(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
6163
self.context_stack.append(node.name)
6264
i = len(node.body) - 1
6365
test_qualified_name = ".".join(self.context_stack)
@@ -83,7 +85,6 @@ def _process_function_def(
8385
self.results[line_node.lineno] = self.get_comment(match_key)
8486
i -= 1
8587
self.context_stack.pop()
86-
return node
8788

8889

8990
def get_fn_call_linenos(

0 commit comments

Comments
 (0)