1+ """Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)"""
2+ from collections import defaultdict
3+
14import isort
25import libcst as cst
36from pathlib import Path
4- from typing import Union
7+ from typing import Union , List
8+ from libcst import ImportFrom , ImportAlias , Name
59
610from codeflash .code_utils .code_utils import get_run_tmp_file
711
812
9- class DecoratorAdder (cst .CSTTransformer ):
13+ class LineProfilerDecoratorAdder (cst .CSTTransformer ):
1014 """Transformer that adds a decorator to a function with a specific qualified name."""
11-
15+ #Todo we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure
1216 def __init__ (self , qualified_name : str , decorator_name : str ):
1317 """
1418 Initialize the transformer.
@@ -21,7 +25,7 @@ def __init__(self, qualified_name: str, decorator_name: str):
2125 self .qualified_name_parts = qualified_name .split ("." )
2226 self .decorator_name = decorator_name
2327
24- # Track our current context path
28+ # Track our current context path, only add when we encounter a class
2529 self .context_stack = []
2630
2731 def visit_ClassDef (self , node : cst .ClassDef ) -> None :
@@ -83,6 +87,64 @@ def _is_target_decorator(self, decorator_node: Union[cst.Name, cst.Attribute, cs
8387 return decorator_node .func .value == self .decorator_name
8488 return False
8589
90+ class ProfileEnableTransformer (cst .CSTTransformer ):
91+ def __init__ (self ,filename ):
92+ # Flag to track if we found the import statement
93+ self .found_import = False
94+ # Track indentation of the import statement
95+ self .import_indentation = None
96+ self .filename = filename
97+
98+ def leave_ImportFrom (self , original_node : cst .ImportFrom , updated_node : cst .ImportFrom ) -> cst .ImportFrom :
99+ # Check if this is the line profiler import statement
100+ if (isinstance (original_node .module , cst .Name ) and
101+ original_node .module .value == "line_profiler" and
102+ any (name .name .value == "profile" and
103+ (not name .asname or name .asname .name .value == "codeflash_line_profile" )
104+ for name in original_node .names )):
105+
106+ self .found_import = True
107+ # Get the indentation from the original node
108+ if hasattr (original_node , "leading_lines" ):
109+ leading_whitespace = original_node .leading_lines [- 1 ].whitespace if original_node .leading_lines else ""
110+ self .import_indentation = leading_whitespace
111+
112+ return updated_node
113+
114+ def leave_Module (self , original_node : cst .Module , updated_node : cst .Module ) -> cst .Module :
115+ if not self .found_import :
116+ return updated_node
117+
118+ # Create a list of statements from the original module
119+ new_body = list (updated_node .body )
120+
121+ # Find the index of the import statement
122+ import_index = None
123+ for i , stmt in enumerate (new_body ):
124+ if isinstance (stmt , cst .SimpleStatementLine ):
125+ for small_stmt in stmt .body :
126+ if isinstance (small_stmt , cst .ImportFrom ):
127+ if (isinstance (small_stmt .module , cst .Name ) and
128+ small_stmt .module .value == "line_profiler" and
129+ any (name .name .value == "profile" and
130+ (not name .asname or name .asname .name .value == "codeflash_line_profile" )
131+ for name in small_stmt .names )):
132+ import_index = i
133+ break
134+ if import_index is not None :
135+ break
136+
137+ if import_index is not None :
138+ # Create the new enable statement to insert after the import
139+ enable_statement = cst .parse_statement (
140+ f"codeflash_line_profile.enable(output_prefix='{ self .filename } ')"
141+ )
142+
143+ # Insert the new statement after the import statement
144+ new_body .insert (import_index + 1 , enable_statement )
145+
146+ # Create a new module with the updated body
147+ return updated_node .with_changes (body = new_body )
86148
87149def add_decorator_to_qualified_function (module , qualified_name , decorator_name ):
88150 """
@@ -99,48 +161,22 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
99161 # Parse the source code into a CST
100162
101163 # Apply our transformer
102- transformer = DecoratorAdder (qualified_name , decorator_name )
164+ transformer = LineProfilerDecoratorAdder (qualified_name , decorator_name )
103165 modified_module = module .visit (transformer )
104166
105167 # Convert the modified CST back to source code
106168 return modified_module
107169
108- def add_profile_enable (original_code : str , db_file : str ) -> str :
170+ def add_profile_enable (original_code : str , line_profile_output_file : str ) -> str :
171+ # todo modify by using a libcst transformer
109172 module = cst .parse_module (original_code )
110- found_index = - 1
111-
112- for idx , statement in enumerate (module .body ):
113- if isinstance (statement , cst .SimpleStatementLine ):
114- for stmt in statement .body :
115- if isinstance (stmt , cst .ImportFrom ):
116- if stmt .module and stmt .module .value == 'line_profiler' :
117- for name in stmt .names :
118- if isinstance (name , cst .ImportAlias ):
119- if name .name .value == 'profile' and name .asname is None :
120- found_index = idx
121- break
122- if found_index != - 1 :
123- break
124- if found_index != - 1 :
125- break
126-
127- if found_index == - 1 :
128- return original_code # or raise an exception if the import is not found
129-
130- # Create the new line to insert
131- new_line = f"profile.enable(output_prefix='{ db_file } ')\n "
132- new_statement = cst .parse_statement (new_line )
133-
134- # Insert the new statement into the module's body
135- new_body = list (module .body )
136- new_body .insert (found_index + 1 , new_statement )
137- modified_module = module .with_changes (body = new_body )
138-
173+ transformer = ProfileEnableTransformer (line_profile_output_file )
174+ modified_module = module .visit (transformer )
139175 return modified_module .code
140176
141177
142178class ImportAdder (cst .CSTTransformer ):
143- def __init__ (self , import_statement = 'from line_profiler import profile' ):
179+ def __init__ (self , import_statement ):
144180 self .import_statement = import_statement
145181 self .has_import = False
146182
@@ -166,45 +202,36 @@ def visit_ImportFrom(self, node):
166202
167203
168204def add_decorator_imports (function_to_optimize , code_context ):
205+ """Adds a profile decorator to a function in a Python file and all its helper functions."""
169206 #self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
170- #todo change function signature to get filepaths of fn, helpers and db
171- # modify libcst parser to visit with qualified name
172- file_paths = list ()
173- fn_list = list ()
174- db_file = get_run_tmp_file (Path ("baseline" ))
175- file_paths .append (function_to_optimize .file_path )
176- fn_list .append (function_to_optimize .qualified_name )
207+ #grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile
208+ file_paths = defaultdict (list )
209+ line_profile_output_file = get_run_tmp_file (Path ("baseline_lprof" ))
210+ file_paths [function_to_optimize .file_path ].append (function_to_optimize .qualified_name )
177211 for elem in code_context .helper_functions :
178- file_paths .append (elem .file_path )
179- fn_list .append (elem .qualified_name )
180- """Adds a decorator to a function in a Python file."""
181- for file_path , fn_name in zip (file_paths , fn_list ):
212+ file_paths [elem .file_path ].append (elem .qualified_name )
213+ for file_path ,fns_present in file_paths .items ():
182214 #open file
183215 with open (file_path , "r" , encoding = "utf-8" ) as file :
184216 file_contents = file .read ()
185217 # parse to cst
186218 module_node = cst .parse_module (file_contents )
187- # add decorator
188- module_node = add_decorator_to_qualified_function (module_node , fn_name , 'profile' )
219+ for fn_name in fns_present :
220+ # add decorator
221+ module_node = add_decorator_to_qualified_function (module_node , fn_name , 'codeflash_line_profile' )
189222 # add imports
190223 # Create a transformer to add the import
191- transformer = ImportAdder ("from line_profiler import profile" )
224+ transformer = ImportAdder ("from line_profiler import profile as codeflash_line_profile " )
192225 # Apply the transformer to add the import
193226 module_node = module_node .visit (transformer )
194227 modified_code = isort .code (module_node .code , float_to_top = True )
195228 # write to file
196229 with open (file_path , "w" , encoding = "utf-8" ) as file :
197230 file .write (modified_code )
198- #Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files, can use libcst seems like an overkill, will go just with some simple string manipulation
199- with open (file_paths [ 0 ] ,'r' ) as f :
231+ #Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
232+ with open (function_to_optimize . file_path ,'r' ) as f :
200233 file_contents = f .read ()
201- modified_code = add_profile_enable (file_contents ,db_file )
202- with open (file_paths [ 0 ] ,'w' ) as f :
234+ modified_code = add_profile_enable (file_contents ,str ( line_profile_output_file ) )
235+ with open (function_to_optimize . file_path ,'w' ) as f :
203236 f .write (modified_code )
204- return db_file
205-
206-
207- def prepare_lprofiler_files (prefix : str = "" ) -> tuple [Path ]:
208- """Prepare line profiler output file."""
209- lprofiler_database_file = get_run_tmp_file (Path (prefix ))
210- return lprofiler_database_file
237+ return line_profile_output_file
0 commit comments