1010from codeflash .cli_cmds .console import logger
1111from codeflash .code_utils .code_extractor import add_needed_imports_from_module
1212from codeflash .models .models import FunctionParent
13- import isort
1413
1514if TYPE_CHECKING :
1615 from pathlib import Path
@@ -336,87 +335,4 @@ def function_to_optimize_original_worktree_fqn(
336335 str (worktrees [0 ].name / function_to_optimize .file_path .relative_to (git_root ).with_suffix ("" )).replace ("/" , "." )
337336 + "."
338337 + function_to_optimize .qualified_name
339- )
340-
341-
342- def add_decorator_cst (module_node , function_name , decorator_name ):
343- """Adds a decorator to a function definition in a LibCST module node."""
344-
345- class AddDecoratorTransformer (cst .CSTTransformer ):
346- def leave_FunctionDef (self , original_node , updated_node ):
347- if original_node .name .value == function_name :
348- new_decorator = cst .Decorator (
349- decorator = cst .Name (value = decorator_name )
350- )
351-
352- updated_decorators = list (updated_node .decorators )
353- updated_decorators .insert (0 , new_decorator )
354-
355- return updated_node .with_changes (decorators = updated_decorators )
356- return updated_node
357-
358- transformer = AddDecoratorTransformer ()
359- updated_module = module_node .visit (transformer )
360- return updated_module
361-
362- def add_decorator_imports (file_paths , fn_list , db_file ):
363- """Adds a decorator to a function in a Python file."""
364- for file_path , fn_name in zip (file_paths , fn_list ):
365- #open file
366- with open (file_path , "r" , encoding = "utf-8" ) as file :
367- file_contents = file .read ()
368-
369- # parse to cst
370- module_node = cst .parse_module (file_contents )
371- # add decorator
372- module_node = add_decorator_cst (module_node , fn_name , 'profile' )
373- # add imports
374- # Create a transformer to add the import
375- transformer = ImportAdder ("from line_profiler import profile" )
376-
377- # Apply the transformer to add the import
378- module_node = module_node .visit (transformer )
379- modified_code = isort .code (module_node .code , float_to_top = True )
380- # write to file
381- with open (file_path , "w" , encoding = "utf-8" ) as file :
382- file .write (modified_code )
383- #do this only for the main file and not the helper files, can use libcst but will go just with some simple string manipulation
384- with open (file_paths [0 ],'r' ) as f :
385- file_contents = f .readlines ()
386- for idx , line in enumerate (file_contents ):
387- if 'from line_profiler import profile' in line :
388- file_contents .insert (idx + 1 , f"profile.enable(output_prefix='{ db_file } ')\n " )
389- break
390- with open (file_paths [0 ],'w' ) as f :
391- f .writelines (file_contents )
392-
393-
394-
395-
396- class ImportAdder (cst .CSTTransformer ):
397- def __init__ (self , import_statement = 'from line_profiler import profile' ):
398- self .import_statement = import_statement
399- self .has_import = False
400-
401- def leave_Module (self , original_node , updated_node ):
402- # If the import is already there, don't add it again
403- if self .has_import :
404- return updated_node
405-
406- # Parse the import statement into a CST node
407- import_node = cst .parse_statement (self .import_statement )
408-
409- # Add the import to the module's body
410- return updated_node .with_changes (
411- body = [import_node ] + list (updated_node .body )
412- )
413-
414- def visit_Import (self , node ):
415- pass
416-
417- def visit_ImportFrom (self , node ):
418- # Check if the profile is already imported from line_profiler
419- if node .module and node .module .value == "line_profiler" :
420- for import_alias in node .names :
421- if import_alias .name .value == "profile" :
422- self .has_import = True
338+ )
0 commit comments