@@ -23,38 +23,38 @@ def leave_FunctionDef(self, original_node, updated_node):
2323 updated_module = module_node .visit (transformer )
2424 return updated_module
2525
26- def add_decorator_imports (file_paths , fn_list , db_file ):
27- """Adds a decorator to a function in a Python file."""
28- for file_path , fn_name in zip (file_paths , fn_list ):
29- #open file
30- with open (file_path , "r" , encoding = "utf-8" ) as file :
31- file_contents = file .read ()
26+ def add_profile_enable (original_code : str , db_file : str ) -> str :
27+ module = cst .parse_module (original_code )
28+ found_index = - 1
29+
30+ for idx , statement in enumerate (module .body ):
31+ if isinstance (statement , cst .SimpleStatementLine ):
32+ for stmt in statement .body :
33+ if isinstance (stmt , cst .ImportFrom ):
34+ if stmt .module and stmt .module .value == 'line_profiler' :
35+ for name in stmt .names :
36+ if isinstance (name , cst .ImportAlias ):
37+ if name .name .value == 'profile' and name .asname is None :
38+ found_index = idx
39+ break
40+ if found_index != - 1 :
41+ break
42+ if found_index != - 1 :
43+ break
3244
33- # parse to cst
34- module_node = cst .parse_module (file_contents )
35- # add decorator
36- module_node = add_decorator_cst (module_node , fn_name , 'profile' )
37- # add imports
38- # Create a transformer to add the import
39- transformer = ImportAdder ("from line_profiler import profile" )
45+ if found_index == - 1 :
46+ return original_code # or raise an exception if the import is not found
4047
41- # Apply the transformer to add the import
42- module_node = module_node .visit (transformer )
43- modified_code = isort .code (module_node .code , float_to_top = True )
44- # write to file
45- with open (file_path , "w" , encoding = "utf-8" ) as file :
46- file .write (modified_code )
47- #Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files, can use libcst seems like an overkill, will go just with some simple string manipulation
48- with open (file_paths [0 ],'r' ) as f :
49- file_contents = f .readlines ()
50- for idx , line in enumerate (file_contents ):
51- if 'from line_profiler import profile' in line :
52- file_contents .insert (idx + 1 , f"profile.enable(output_prefix='{ db_file } ')\n " )
53- break
54- with open (file_paths [0 ],'w' ) as f :
55- f .writelines (file_contents )
48+ # Create the new line to insert
49+ new_line = f"profile.enable(output_prefix='{ db_file } ')\n "
50+ new_statement = cst .parse_statement (new_line )
5651
52+ # Insert the new statement into the module's body
53+ new_body = list (module .body )
54+ new_body .insert (found_index + 1 , new_statement )
55+ modified_module = module .with_changes (body = new_body )
5756
57+ return modified_module .code
5858
5959
6060class ImportAdder (cst .CSTTransformer ):
@@ -82,6 +82,36 @@ def visit_ImportFrom(self, node):
8282 if import_alias .name .value == "profile" :
8383 self .has_import = True
8484
85+
86+ def add_decorator_imports (file_paths , fn_list , db_file ):
87+ """Adds a decorator to a function in a Python file."""
88+ for file_path , fn_name in zip (file_paths , fn_list ):
89+ #open file
90+ with open (file_path , "r" , encoding = "utf-8" ) as file :
91+ file_contents = file .read ()
92+
93+ # parse to cst
94+ module_node = cst .parse_module (file_contents )
95+ # add decorator
96+ module_node = add_decorator_cst (module_node , fn_name , 'profile' )
97+ # add imports
98+ # Create a transformer to add the import
99+ transformer = ImportAdder ("from line_profiler import profile" )
100+
101+ # Apply the transformer to add the import
102+ module_node = module_node .visit (transformer )
103+ modified_code = isort .code (module_node .code , float_to_top = True )
104+ # write to file
105+ with open (file_path , "w" , encoding = "utf-8" ) as file :
106+ file .write (modified_code )
107+ #Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files, can use libcst seems like an overkill, will go just with some simple string manipulation
108+ with open (file_paths [0 ],'r' ) as f :
109+ file_contents = f .read ()
110+ modified_code = add_profile_enable (file_contents ,db_file )
111+ with open (file_paths [0 ],'w' ) as f :
112+ f .write (modified_code )
113+
114+
85115def prepare_lprofiler_files (prefix : str = "" ) -> tuple [Path ]:
86116 """Prepare line profiler output file."""
87117 lprofiler_database_file = get_run_tmp_file (Path (prefix ))
0 commit comments