11"""Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)""" 
2+ 
23from  collections  import  defaultdict 
4+ from  pathlib  import  Path 
5+ from  typing  import  Union 
36
47import  isort 
58import  libcst  as  cst 
6- from  pathlib  import  Path 
7- from  typing  import  Union , List 
8- from  libcst  import  ImportFrom , ImportAlias , Name 
99
1010from  codeflash .code_utils .code_utils  import  get_run_tmp_file 
1111
1212
1313class  LineProfilerDecoratorAdder (cst .CSTTransformer ):
1414    """Transformer that adds a decorator to a function with a specific qualified name.""" 
15-      #Todo we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure 
15+ 
1616    def  __init__ (self , qualified_name : str , decorator_name : str ):
17-         """ 
18-         Initialize the transformer. 
17+         """Initialize the transformer. 
1918
2019        Args: 
21-             qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func. target_func"). 
20+             qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.target_func"). 
2221            decorator_name: The name of the decorator to add. 
22+ 
2323        """ 
2424        super ().__init__ ()
2525        self .qualified_name_parts  =  qualified_name .split ("." )
2626        self .decorator_name  =  decorator_name 
27- 
28-         # Track our current context path, only add when we encounter a class 
2927        self .context_stack  =  []
3028
3129    def  visit_ClassDef (self , node : cst .ClassDef ) ->  None :
@@ -48,47 +46,36 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
4846        if  self ._matches_qualified_path ():
4947            # Check if the decorator is already present 
5048            has_decorator  =  any (
51-                 self ._is_target_decorator (decorator .decorator )
52-                 for  decorator  in  original_node .decorators 
49+                 self ._is_target_decorator (decorator .decorator ) for  decorator  in  original_node .decorators 
5350            )
5451
5552            # Only add the decorator if it's not already there 
5653            if  not  has_decorator :
57-                 new_decorator  =  cst .Decorator (
58-                     decorator = cst .Name (value = self .decorator_name )
59-                 )
54+                 new_decorator  =  cst .Decorator (decorator = cst .Name (value = self .decorator_name ))
6055
6156                # Add our new decorator to the existing decorators 
6257                updated_decorators  =  [new_decorator ] +  list (updated_node .decorators )
63-                 updated_node  =  updated_node .with_changes (
64-                     decorators = tuple (updated_decorators )
65-                 )
58+                 updated_node  =  updated_node .with_changes (decorators = tuple (updated_decorators ))
6659
6760        # Pop the context when we leave a function 
6861        self .context_stack .pop ()
6962        return  updated_node 
7063
7164    def  _matches_qualified_path (self ) ->  bool :
7265        """Check if the current context stack matches the qualified name.""" 
73-         if  len (self .context_stack ) !=  len (self .qualified_name_parts ):
74-             return  False 
75- 
76-         for  i , name  in  enumerate (self .qualified_name_parts ):
77-             if  self .context_stack [i ] !=  name :
78-                 return  False 
79- 
80-         return  True 
66+         return  self .context_stack  ==  self .qualified_name_parts 
8167
8268    def  _is_target_decorator (self , decorator_node : Union [cst .Name , cst .Attribute , cst .Call ]) ->  bool :
8369        """Check if a decorator matches our target decorator name.""" 
8470        if  isinstance (decorator_node , cst .Name ):
8571            return  decorator_node .value  ==  self .decorator_name 
86-         elif  isinstance (decorator_node , cst .Call ) and  isinstance (decorator_node .func , cst .Name ):
72+         if  isinstance (decorator_node , cst .Call ) and  isinstance (decorator_node .func , cst .Name ):
8773            return  decorator_node .func .value  ==  self .decorator_name 
8874        return  False 
8975
76+ 
9077class  ProfileEnableTransformer (cst .CSTTransformer ):
91-     def  __init__ (self ,filename ):
78+     def  __init__ (self ,  filename ):
9279        # Flag to track if we found the import statement 
9380        self .found_import  =  False 
9481        # Track indentation of the import statement 
@@ -97,12 +84,14 @@ def __init__(self,filename):
9784
9885    def  leave_ImportFrom (self , original_node : cst .ImportFrom , updated_node : cst .ImportFrom ) ->  cst .ImportFrom :
9986        # Check if this is the line profiler import statement 
100-         if  (isinstance (original_node .module , cst .Name ) and 
101-                 original_node .module .value  ==  "line_profiler"  and 
102-                 any (name .name .value  ==  "profile"  and 
103-                     (not  name .asname  or  name .asname .name .value  ==  "codeflash_line_profile" )
104-                     for  name  in  original_node .names )):
105- 
87+         if  (
88+             isinstance (original_node .module , cst .Name )
89+             and  original_node .module .value  ==  "line_profiler" 
90+             and  any (
91+                 name .name .value  ==  "profile"  and  (not  name .asname  or  name .asname .name .value  ==  "codeflash_line_profile" )
92+                 for  name  in  original_node .names 
93+             )
94+         ):
10695            self .found_import  =  True 
10796            # Get the indentation from the original node 
10897            if  hasattr (original_node , "leading_lines" ):
@@ -124,31 +113,33 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
124113            if  isinstance (stmt , cst .SimpleStatementLine ):
125114                for  small_stmt  in  stmt .body :
126115                    if  isinstance (small_stmt , cst .ImportFrom ):
127-                         if  (isinstance (small_stmt .module , cst .Name ) and 
128-                                 small_stmt .module .value  ==  "line_profiler"  and 
129-                                 any (name .name .value  ==  "profile"  and 
130-                                     (not  name .asname  or  name .asname .name .value  ==  "codeflash_line_profile" )
131-                                     for  name  in  small_stmt .names )):
116+                         if  (
117+                             isinstance (small_stmt .module , cst .Name )
118+                             and  small_stmt .module .value  ==  "line_profiler" 
119+                             and  any (
120+                                 name .name .value  ==  "profile" 
121+                                 and  (not  name .asname  or  name .asname .name .value  ==  "codeflash_line_profile" )
122+                                 for  name  in  small_stmt .names 
123+                             )
124+                         ):
132125                            import_index  =  i 
133126                            break 
134127                if  import_index  is  not   None :
135128                    break 
136129
137130        if  import_index  is  not   None :
138131            # Create the new enable statement to insert after the import 
139-             enable_statement  =  cst .parse_statement (
140-                 f"codeflash_line_profile.enable(output_prefix='{ self .filename }  ')" 
141-             )
132+             enable_statement  =  cst .parse_statement (f"codeflash_line_profile.enable(output_prefix='{ self .filename }  ')" )
142133
143134            # Insert the new statement after the import statement 
144135            new_body .insert (import_index  +  1 , enable_statement )
145136
146137        # Create a new module with the updated body 
147138        return  updated_node .with_changes (body = new_body )
148139
140+ 
149141def  add_decorator_to_qualified_function (module , qualified_name , decorator_name ):
150-     """ 
151-     Add a decorator to a function with the exact qualified name in the source code. 
142+     """Add a decorator to a function with the exact qualified name in the source code. 
152143
153144    Args: 
154145        module: The Python source code as a string. 
@@ -157,6 +148,7 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
157148
158149    Returns: 
159150        The modified source code as a string. 
151+ 
160152    """ 
161153    # Parse the source code into a CST 
162154
@@ -167,8 +159,9 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
167159    # Convert the modified CST back to source code 
168160    return  modified_module 
169161
162+ 
170163def  add_profile_enable (original_code : str , line_profile_output_file : str ) ->  str :
171-     # todo  modify by using a libcst transformer 
164+     # TODO  modify by using a libcst transformer 
172165    module  =  cst .parse_module (original_code )
173166    transformer  =  ProfileEnableTransformer (line_profile_output_file )
174167    modified_module  =  module .visit (transformer )
@@ -189,9 +182,7 @@ def leave_Module(self, original_node, updated_node):
189182        import_node  =  cst .parse_statement (self .import_statement )
190183
191184        # Add the import to the module's body 
192-         return  updated_node .with_changes (
193-             body = [import_node ] +  list (updated_node .body )
194-         )
185+         return  updated_node .with_changes (body = [import_node ] +  list (updated_node .body ))
195186
196187    def  visit_ImportFrom (self , node ):
197188        # Check if the profile is already imported from line_profiler 
@@ -203,21 +194,22 @@ def visit_ImportFrom(self, node):
203194
204195def  add_decorator_imports (function_to_optimize , code_context ):
205196    """Adds a profile decorator to a function in a Python file and all its helper functions.""" 
206-     #self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root 
207-     #grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile 
197+     #  self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root 
198+     #  grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile 
208199    file_paths  =  defaultdict (list )
209200    line_profile_output_file  =  get_run_tmp_file (Path ("baseline_lprof" ))
210201    file_paths [function_to_optimize .file_path ].append (function_to_optimize .qualified_name )
211202    for  elem  in  code_context .helper_functions :
212203        file_paths [elem .file_path ].append (elem .qualified_name )
213-     for  file_path ,fns_present  in  file_paths .items ():
214-         #open file 
215-         file_contents  =  file_path .read_text ("utf-8" )
204+     for  file_path , fns_present  in  file_paths .items ():
205+         # open file 
206+         with  open (file_path , encoding = "utf-8" ) as  file :
207+             file_contents  =  file .read ()
216208        # parse to cst 
217209        module_node  =  cst .parse_module (file_contents )
218210        for  fn_name  in  fns_present :
219211            # add decorator 
220-             module_node  =  add_decorator_to_qualified_function (module_node , fn_name , ' codeflash_line_profile'  )
212+             module_node  =  add_decorator_to_qualified_function (module_node , fn_name , " codeflash_line_profile"  )
221213        # add imports 
222214        # Create a transformer to add the import 
223215        transformer  =  ImportAdder ("from line_profiler import profile as codeflash_line_profile" )
@@ -227,8 +219,10 @@ def add_decorator_imports(function_to_optimize, code_context):
227219        # write to file 
228220        with  open (file_path , "w" , encoding = "utf-8" ) as  file :
229221            file .write (modified_code )
230-     #Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files 
231-     file_contents  =  function_to_optimize .file_path .read_text ("utf-8" )
232-     modified_code  =  add_profile_enable (file_contents ,str (line_profile_output_file ))
233-     function_to_optimize .file_path .write_text (modified_code ,"utf-8" )
234-     return  line_profile_output_file 
222+     # Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files 
223+     with  open (function_to_optimize .file_path ) as  f :
224+         file_contents  =  f .read ()
225+     modified_code  =  add_profile_enable (file_contents , str (line_profile_output_file ))
226+     with  open (function_to_optimize .file_path , "w" ) as  f :
227+         f .write (modified_code )
228+     return  line_profile_output_file 
0 commit comments