2222import sys
2323import threading
2424import time
25+ from argparse import ArgumentParser
2526from collections import defaultdict
2627from pathlib import Path
27- from typing import TYPE_CHECKING , Any , ClassVar
28+ from typing import TYPE_CHECKING , Any , Callable , ClassVar
2829
2930import dill
3031import isort
4647 from types import FrameType , TracebackType
4748
4849
50+ class FakeCode :
51+ def __init__ (self , filename : str , line : int , name : str ) -> None :
52+ self .co_filename = filename
53+ self .co_line = line
54+ self .co_name = name
55+ self .co_firstlineno = 0
56+
57+ def __repr__ (self ) -> str :
58+ return repr ((self .co_filename , self .co_line , self .co_name , None ))
59+
60+
61+ class FakeFrame :
62+ def __init__ (self , code : FakeCode , prior : FakeFrame | None ) -> None :
63+ self .f_code = code
64+ self .f_back = prior
65+ self .f_locals : dict = {}
66+
67+
4968# Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger.
5069class Tracer :
5170 """Use this class as a 'with' context manager to trace a function call.
@@ -75,7 +94,9 @@ def __init__(
7594 if functions is None :
7695 functions = []
7796 if os .environ .get ("CODEFLASH_TRACER_DISABLE" , "0" ) == "1" :
78- console .rule ("Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE" , style = "bold red" )
97+ console .rule (
98+ "Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE" , style = "bold red"
99+ )
79100 disable = True
80101 self .disable = disable
81102 if self .disable :
@@ -210,7 +231,8 @@ def __exit__(
210231 test_dir = Path (self .config ["tests_root" ]), function_name = function_path , test_type = "replay"
211232 )
212233 replay_test = isort .code (replay_test )
213- with open (test_file_path , "w" , encoding = "utf8" ) as file :
234+
235+ with Path (test_file_path ).open ("w" , encoding = "utf8" ) as file :
214236 file .write (replay_test )
215237
216238 console .print (
@@ -248,7 +270,7 @@ def tracer_logic(self, frame: FrameType, event: str) -> None:
248270 class_name = arguments ["self" ].__class__ .__name__
249271 elif "cls" in arguments and hasattr (arguments ["cls" ], "__name__" ):
250272 class_name = arguments ["cls" ].__name__
251- except :
273+ except : # noqa: E722
252274 # someone can override the getattr method and raise an exception. I'm looking at you wrapt
253275 return
254276 function_qualified_name = f"{ file_name } :{ (class_name + ':' if class_name else '' )} { code .co_name } "
@@ -354,7 +376,7 @@ def trace_dispatch_call(self, frame: FrameType, t: int) -> int:
354376
355377 # Only attempt to handle the frame mismatch if we have a valid rframe
356378 if (
357- not isinstance (rframe , Tracer . fake_frame )
379+ not isinstance (rframe , FakeFrame )
358380 and hasattr (rframe , "f_back" )
359381 and hasattr (frame , "f_back" )
360382 and rframe .f_back is frame .f_back
@@ -460,7 +482,7 @@ def trace_dispatch_return(self, frame: FrameType, t: int) -> int:
460482
461483 return 1
462484
463- dispatch : ClassVar [dict [str , callable ]] = {
485+ dispatch : ClassVar [dict [str , Callable [[ Tracer , FrameType , int ], int ] ]] = {
464486 "call" : trace_dispatch_call ,
465487 "exception" : trace_dispatch_exception ,
466488 "return" : trace_dispatch_return ,
@@ -469,26 +491,10 @@ def trace_dispatch_return(self, frame: FrameType, t: int) -> int:
469491 "c_return" : trace_dispatch_return ,
470492 }
471493
472- class fake_code :
473- def __init__ (self , filename , line , name ) -> None :
474- self .co_filename = filename
475- self .co_line = line
476- self .co_name = name
477- self .co_firstlineno = 0
478-
479- def __repr__ (self ) -> str :
480- return repr ((self .co_filename , self .co_line , self .co_name , None ))
481-
482- class fake_frame :
483- def __init__ (self , code , prior ) -> None :
484- self .f_code = code
485- self .f_back = prior
486- self .f_locals = {}
487-
488- def simulate_call (self , name ) -> None :
489- code = self .fake_code ("profiler" , 0 , name )
494+ def simulate_call (self , name : str ) -> None :
495+ code = FakeCode ("profiler" , 0 , name )
490496 pframe = self .cur [- 2 ] if self .cur else None
491- frame = self . fake_frame (code , pframe )
497+ frame = FakeFrame (code , pframe )
492498 self .dispatch ["call" ](self , frame , 0 )
493499
494500 def simulate_cmd_complete (self ) -> None :
@@ -709,9 +715,7 @@ def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, An
709715 return self
710716
711717
712- def main ():
713- from argparse import ArgumentParser
714-
718+ def main () -> ArgumentParser :
715719 parser = ArgumentParser (allow_abbrev = False )
716720 parser .add_argument ("-o" , "--outfile" , dest = "outfile" , help = "Save trace to <outfile>" , required = True )
717721 parser .add_argument ("--only-functions" , help = "Trace only these functions" , nargs = "+" , default = None )
0 commit comments