11"""Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)""" 
2+ 
23from  collections  import  defaultdict 
34from  pathlib  import  Path 
45from  typing  import  Union 
1213class  LineProfilerDecoratorAdder (cst .CSTTransformer ):
1314    """Transformer that adds a decorator to a function with a specific qualified name.""" 
1415
15-     #TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure 
16+     #  TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure 
1617    def  __init__ (self , qualified_name : str , decorator_name : str ):
1718        """Initialize the transformer. 
1819
@@ -45,24 +46,19 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
4546        function_name  =  original_node .name .value 
4647
4748        # Check if the current context path matches our target qualified name 
48-         if  self .context_stack == self .qualified_name_parts :
49+         if  self .context_stack   ==   self .qualified_name_parts :
4950            # Check if the decorator is already present 
5051            has_decorator  =  any (
51-                 self ._is_target_decorator (decorator .decorator )
52-                 for  decorator  in  original_node .decorators 
52+                 self ._is_target_decorator (decorator .decorator ) for  decorator  in  original_node .decorators 
5353            )
5454
5555            # Only add the decorator if it's not already there 
5656            if  not  has_decorator :
57-                 new_decorator  =  cst .Decorator (
58-                     decorator = cst .Name (value = self .decorator_name )
59-                 )
57+                 new_decorator  =  cst .Decorator (decorator = cst .Name (value = self .decorator_name ))
6058
6159                # Add our new decorator to the existing decorators 
6260                updated_decorators  =  [new_decorator ] +  list (updated_node .decorators )
63-                 updated_node  =  updated_node .with_changes (
64-                     decorators = tuple (updated_decorators )
65-                 )
61+                 updated_node  =  updated_node .with_changes (decorators = tuple (updated_decorators ))
6662
6763        # Pop the context when we leave a function 
6864        self .context_stack .pop ()
@@ -76,8 +72,9 @@ def _is_target_decorator(self, decorator_node: Union[cst.Name, cst.Attribute, cs
7672            return  decorator_node .func .value  ==  self .decorator_name 
7773        return  False 
7874
75+ 
7976class  ProfileEnableTransformer (cst .CSTTransformer ):
80-     def  __init__ (self ,filename ):
77+     def  __init__ (self ,  filename ):
8178        # Flag to track if we found the import statement 
8279        self .found_import  =  False 
8380        # Track indentation of the import statement 
@@ -86,12 +83,14 @@ def __init__(self,filename):
8683
8784    def  leave_ImportFrom (self , original_node : cst .ImportFrom , updated_node : cst .ImportFrom ) ->  cst .ImportFrom :
8885        # Check if this is the line profiler import statement 
89-         if  (isinstance (original_node .module , cst .Name ) and 
90-                 original_node .module .value  ==  "line_profiler"  and 
91-                 any (name .name .value  ==  "profile"  and 
92-                     (not  name .asname  or  name .asname .name .value  ==  "codeflash_line_profile" )
93-                     for  name  in  original_node .names )):
94- 
86+         if  (
87+             isinstance (original_node .module , cst .Name )
88+             and  original_node .module .value  ==  "line_profiler" 
89+             and  any (
90+                 name .name .value  ==  "profile"  and  (not  name .asname  or  name .asname .name .value  ==  "codeflash_line_profile" )
91+                 for  name  in  original_node .names 
92+             )
93+         ):
9594            self .found_import  =  True 
9695            # Get the indentation from the original node 
9796            if  hasattr (original_node , "leading_lines" ):
@@ -113,28 +112,31 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
113112            if  isinstance (stmt , cst .SimpleStatementLine ):
114113                for  small_stmt  in  stmt .body :
115114                    if  isinstance (small_stmt , cst .ImportFrom ):
116-                         if  (isinstance (small_stmt .module , cst .Name ) and 
117-                                 small_stmt .module .value  ==  "line_profiler"  and 
118-                                 any (name .name .value  ==  "profile"  and 
119-                                     (not  name .asname  or  name .asname .name .value  ==  "codeflash_line_profile" )
120-                                     for  name  in  small_stmt .names )):
115+                         if  (
116+                             isinstance (small_stmt .module , cst .Name )
117+                             and  small_stmt .module .value  ==  "line_profiler" 
118+                             and  any (
119+                                 name .name .value  ==  "profile" 
120+                                 and  (not  name .asname  or  name .asname .name .value  ==  "codeflash_line_profile" )
121+                                 for  name  in  small_stmt .names 
122+                             )
123+                         ):
121124                            import_index  =  i 
122125                            break 
123126                if  import_index  is  not   None :
124127                    break 
125128
126129        if  import_index  is  not   None :
127130            # Create the new enable statement to insert after the import 
128-             enable_statement  =  cst .parse_statement (
129-                 f"codeflash_line_profile.enable(output_prefix='{ self .filename }  ')" 
130-             )
131+             enable_statement  =  cst .parse_statement (f"codeflash_line_profile.enable(output_prefix='{ self .filename }  ')" )
131132
132133            # Insert the new statement after the import statement 
133134            new_body .insert (import_index  +  1 , enable_statement )
134135
135136        # Create a new module with the updated body 
136137        return  updated_node .with_changes (body = new_body )
137138
139+ 
138140def  add_decorator_to_qualified_function (module , qualified_name , decorator_name ):
139141    """Add a decorator to a function with the exact qualified name in the source code. 
140142
@@ -156,6 +158,7 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
156158    # Convert the modified CST back to source code 
157159    return  modified_module 
158160
161+ 
159162def  add_profile_enable (original_code : str , line_profile_output_file : str ) ->  str :
160163    # TODO modify by using a libcst transformer 
161164    module  =  cst .parse_module (original_code )
@@ -178,9 +181,7 @@ def leave_Module(self, original_node, updated_node):
178181        import_node  =  cst .parse_statement (self .import_statement )
179182
180183        # Add the import to the module's body 
181-         return  updated_node .with_changes (
182-             body = [import_node ] +  list (updated_node .body )
183-         )
184+         return  updated_node .with_changes (body = [import_node ] +  list (updated_node .body ))
184185
185186    def  visit_ImportFrom (self , node ):
186187        # Check if the profile is already imported from line_profiler 
@@ -192,15 +193,15 @@ def visit_ImportFrom(self, node):
192193
193194def  add_decorator_imports (function_to_optimize , code_context ):
194195    """Adds a profile decorator to a function in a Python file and all its helper functions.""" 
195-     #self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root 
196-     #grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile 
196+     #  self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root 
197+     #  grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile 
197198    file_paths  =  defaultdict (list )
198199    line_profile_output_file  =  get_run_tmp_file (Path ("baseline_lprof" ))
199200    file_paths [function_to_optimize .file_path ].append (function_to_optimize .qualified_name )
200201    for  elem  in  code_context .helper_functions :
201202        file_paths [elem .file_path ].append (elem .qualified_name )
202-     for  file_path ,fns_present  in  file_paths .items ():
203-         #open file 
203+     for  file_path ,  fns_present  in  file_paths .items ():
204+         #  open file 
204205        file_contents  =  file_path .read_text ("utf-8" )
205206        # parse to cst 
206207        module_node  =  cst .parse_module (file_contents )
@@ -216,8 +217,8 @@ def add_decorator_imports(function_to_optimize, code_context):
216217        # write to file 
217218        with  open (file_path , "w" , encoding = "utf-8" ) as  file :
218219            file .write (modified_code )
219-     #Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files 
220+     #  Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files 
220221    file_contents  =  function_to_optimize .file_path .read_text ("utf-8" )
221-     modified_code  =  add_profile_enable (file_contents ,str (line_profile_output_file ))
222-     function_to_optimize .file_path .write_text (modified_code ,"utf-8" )
222+     modified_code  =  add_profile_enable (file_contents ,  str (line_profile_output_file ))
223+     function_to_optimize .file_path .write_text (modified_code ,  "utf-8" )
223224    return  line_profile_output_file 
0 commit comments