Skip to content

Commit e9539c1

Browse files
committed
ruff format
1 parent 53406de commit e9539c1

File tree

1 file changed

+32
-28
lines changed

1 file changed

+32
-28
lines changed

codeflash/tracer.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
import sys
2323
import threading
2424
import time
25+
from argparse import ArgumentParser
2526
from collections import defaultdict
2627
from pathlib import Path
27-
from typing import TYPE_CHECKING, Any, ClassVar
28+
from typing import TYPE_CHECKING, Any, Callable, ClassVar
2829

2930
import dill
3031
import isort
@@ -46,6 +47,24 @@
4647
from types import FrameType, TracebackType
4748

4849

50+
class FakeCode:
51+
def __init__(self, filename: str, line: int, name: str) -> None:
52+
self.co_filename = filename
53+
self.co_line = line
54+
self.co_name = name
55+
self.co_firstlineno = 0
56+
57+
def __repr__(self) -> str:
58+
return repr((self.co_filename, self.co_line, self.co_name, None))
59+
60+
61+
class FakeFrame:
62+
def __init__(self, code: FakeCode, prior: FakeFrame | None) -> None:
63+
self.f_code = code
64+
self.f_back = prior
65+
self.f_locals: dict = {}
66+
67+
4968
# Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger.
5069
class Tracer:
5170
"""Use this class as a 'with' context manager to trace a function call.
@@ -75,7 +94,9 @@ def __init__(
7594
if functions is None:
7695
functions = []
7796
if os.environ.get("CODEFLASH_TRACER_DISABLE", "0") == "1":
78-
console.rule("Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE", style="bold red")
97+
console.rule(
98+
"Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE", style="bold red"
99+
)
79100
disable = True
80101
self.disable = disable
81102
if self.disable:
@@ -210,7 +231,8 @@ def __exit__(
210231
test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay"
211232
)
212233
replay_test = isort.code(replay_test)
213-
with open(test_file_path, "w", encoding="utf8") as file:
234+
235+
with Path(test_file_path).open("w", encoding="utf8") as file:
214236
file.write(replay_test)
215237

216238
console.print(
@@ -248,7 +270,7 @@ def tracer_logic(self, frame: FrameType, event: str) -> None:
248270
class_name = arguments["self"].__class__.__name__
249271
elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
250272
class_name = arguments["cls"].__name__
251-
except:
273+
except: # noqa: E722
252274
# someone can override the getattr method and raise an exception. I'm looking at you wrapt
253275
return
254276
function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"
@@ -354,7 +376,7 @@ def trace_dispatch_call(self, frame: FrameType, t: int) -> int:
354376

355377
# Only attempt to handle the frame mismatch if we have a valid rframe
356378
if (
357-
not isinstance(rframe, Tracer.fake_frame)
379+
not isinstance(rframe, FakeFrame)
358380
and hasattr(rframe, "f_back")
359381
and hasattr(frame, "f_back")
360382
and rframe.f_back is frame.f_back
@@ -460,7 +482,7 @@ def trace_dispatch_return(self, frame: FrameType, t: int) -> int:
460482

461483
return 1
462484

463-
dispatch: ClassVar[dict[str, callable]] = {
485+
dispatch: ClassVar[dict[str, Callable[[Tracer, FrameType, int], int]]] = {
464486
"call": trace_dispatch_call,
465487
"exception": trace_dispatch_exception,
466488
"return": trace_dispatch_return,
@@ -469,26 +491,10 @@ def trace_dispatch_return(self, frame: FrameType, t: int) -> int:
469491
"c_return": trace_dispatch_return,
470492
}
471493

472-
class fake_code:
473-
def __init__(self, filename, line, name) -> None:
474-
self.co_filename = filename
475-
self.co_line = line
476-
self.co_name = name
477-
self.co_firstlineno = 0
478-
479-
def __repr__(self) -> str:
480-
return repr((self.co_filename, self.co_line, self.co_name, None))
481-
482-
class fake_frame:
483-
def __init__(self, code, prior) -> None:
484-
self.f_code = code
485-
self.f_back = prior
486-
self.f_locals = {}
487-
488-
def simulate_call(self, name) -> None:
489-
code = self.fake_code("profiler", 0, name)
494+
def simulate_call(self, name: str) -> None:
495+
code = FakeCode("profiler", 0, name)
490496
pframe = self.cur[-2] if self.cur else None
491-
frame = self.fake_frame(code, pframe)
497+
frame = FakeFrame(code, pframe)
492498
self.dispatch["call"](self, frame, 0)
493499

494500
def simulate_cmd_complete(self) -> None:
@@ -709,9 +715,7 @@ def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, An
709715
return self
710716

711717

712-
def main():
713-
from argparse import ArgumentParser
714-
718+
def main() -> ArgumentParser:
715719
parser = ArgumentParser(allow_abbrev=False)
716720
parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to <outfile>", required=True)
717721
parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None)

0 commit comments

Comments
 (0)