11"""Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)"""
2+
23from collections import defaultdict
34from pathlib import Path
45from typing import Union
1213class LineProfilerDecoratorAdder (cst .CSTTransformer ):
1314 """Transformer that adds a decorator to a function with a specific qualified name."""
1415
15- #TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure
16+ # TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure
1617 def __init__ (self , qualified_name : str , decorator_name : str ):
1718 """Initialize the transformer.
1819
@@ -45,24 +46,19 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
4546 function_name = original_node .name .value
4647
4748 # Check if the current context path matches our target qualified name
48- if self .context_stack == self .qualified_name_parts :
49+ if self .context_stack == self .qualified_name_parts :
4950 # Check if the decorator is already present
5051 has_decorator = any (
51- self ._is_target_decorator (decorator .decorator )
52- for decorator in original_node .decorators
52+ self ._is_target_decorator (decorator .decorator ) for decorator in original_node .decorators
5353 )
5454
5555 # Only add the decorator if it's not already there
5656 if not has_decorator :
57- new_decorator = cst .Decorator (
58- decorator = cst .Name (value = self .decorator_name )
59- )
57+ new_decorator = cst .Decorator (decorator = cst .Name (value = self .decorator_name ))
6058
6159 # Add our new decorator to the existing decorators
6260 updated_decorators = [new_decorator ] + list (updated_node .decorators )
63- updated_node = updated_node .with_changes (
64- decorators = tuple (updated_decorators )
65- )
61+ updated_node = updated_node .with_changes (decorators = tuple (updated_decorators ))
6662
6763 # Pop the context when we leave a function
6864 self .context_stack .pop ()
@@ -76,8 +72,9 @@ def _is_target_decorator(self, decorator_node: Union[cst.Name, cst.Attribute, cs
7672 return decorator_node .func .value == self .decorator_name
7773 return False
7874
75+
7976class ProfileEnableTransformer (cst .CSTTransformer ):
80- def __init__ (self ,filename ):
77+ def __init__ (self , filename ):
8178 # Flag to track if we found the import statement
8279 self .found_import = False
8380 # Track indentation of the import statement
@@ -86,12 +83,14 @@ def __init__(self,filename):
8683
8784 def leave_ImportFrom (self , original_node : cst .ImportFrom , updated_node : cst .ImportFrom ) -> cst .ImportFrom :
8885 # Check if this is the line profiler import statement
89- if (isinstance (original_node .module , cst .Name ) and
90- original_node .module .value == "line_profiler" and
91- any (name .name .value == "profile" and
92- (not name .asname or name .asname .name .value == "codeflash_line_profile" )
93- for name in original_node .names )):
94-
86+ if (
87+ isinstance (original_node .module , cst .Name )
88+ and original_node .module .value == "line_profiler"
89+ and any (
90+ name .name .value == "profile" and (not name .asname or name .asname .name .value == "codeflash_line_profile" )
91+ for name in original_node .names
92+ )
93+ ):
9594 self .found_import = True
9695 # Get the indentation from the original node
9796 if hasattr (original_node , "leading_lines" ):
@@ -113,28 +112,31 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
113112 if isinstance (stmt , cst .SimpleStatementLine ):
114113 for small_stmt in stmt .body :
115114 if isinstance (small_stmt , cst .ImportFrom ):
116- if (isinstance (small_stmt .module , cst .Name ) and
117- small_stmt .module .value == "line_profiler" and
118- any (name .name .value == "profile" and
119- (not name .asname or name .asname .name .value == "codeflash_line_profile" )
120- for name in small_stmt .names )):
115+ if (
116+ isinstance (small_stmt .module , cst .Name )
117+ and small_stmt .module .value == "line_profiler"
118+ and any (
119+ name .name .value == "profile"
120+ and (not name .asname or name .asname .name .value == "codeflash_line_profile" )
121+ for name in small_stmt .names
122+ )
123+ ):
121124 import_index = i
122125 break
123126 if import_index is not None :
124127 break
125128
126129 if import_index is not None :
127130 # Create the new enable statement to insert after the import
128- enable_statement = cst .parse_statement (
129- f"codeflash_line_profile.enable(output_prefix='{ self .filename } ')"
130- )
131+ enable_statement = cst .parse_statement (f"codeflash_line_profile.enable(output_prefix='{ self .filename } ')" )
131132
132133 # Insert the new statement after the import statement
133134 new_body .insert (import_index + 1 , enable_statement )
134135
135136 # Create a new module with the updated body
136137 return updated_node .with_changes (body = new_body )
137138
139+
138140def add_decorator_to_qualified_function (module , qualified_name , decorator_name ):
139141 """Add a decorator to a function with the exact qualified name in the source code.
140142
@@ -156,6 +158,7 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
156158 # Convert the modified CST back to source code
157159 return modified_module
158160
161+
159162def add_profile_enable (original_code : str , line_profile_output_file : str ) -> str :
160163 # TODO modify by using a libcst transformer
161164 module = cst .parse_module (original_code )
@@ -178,9 +181,7 @@ def leave_Module(self, original_node, updated_node):
178181 import_node = cst .parse_statement (self .import_statement )
179182
180183 # Add the import to the module's body
181- return updated_node .with_changes (
182- body = [import_node ] + list (updated_node .body )
183- )
184+ return updated_node .with_changes (body = [import_node ] + list (updated_node .body ))
184185
185186 def visit_ImportFrom (self , node ):
186187 # Check if the profile is already imported from line_profiler
@@ -192,15 +193,15 @@ def visit_ImportFrom(self, node):
192193
193194def add_decorator_imports (function_to_optimize , code_context ):
194195 """Adds a profile decorator to a function in a Python file and all its helper functions."""
195- #self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
196- #grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile
196+ # self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
197+ # grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile
197198 file_paths = defaultdict (list )
198199 line_profile_output_file = get_run_tmp_file (Path ("baseline_lprof" ))
199200 file_paths [function_to_optimize .file_path ].append (function_to_optimize .qualified_name )
200201 for elem in code_context .helper_functions :
201202 file_paths [elem .file_path ].append (elem .qualified_name )
202- for file_path ,fns_present in file_paths .items ():
203- #open file
203+ for file_path , fns_present in file_paths .items ():
204+ # open file
204205 file_contents = file_path .read_text ("utf-8" )
205206 # parse to cst
206207 module_node = cst .parse_module (file_contents )
@@ -216,8 +217,8 @@ def add_decorator_imports(function_to_optimize, code_context):
216217 # write to file
217218 with open (file_path , "w" , encoding = "utf-8" ) as file :
218219 file .write (modified_code )
219- #Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
220+ # Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
220221 file_contents = function_to_optimize .file_path .read_text ("utf-8" )
221- modified_code = add_profile_enable (file_contents ,str (line_profile_output_file ))
222- function_to_optimize .file_path .write_text (modified_code ,"utf-8" )
222+ modified_code = add_profile_enable (file_contents , str (line_profile_output_file ))
223+ function_to_optimize .file_path .write_text (modified_code , "utf-8" )
223224 return line_profile_output_file
0 commit comments