@@ -13,57 +13,46 @@ def __init__(self, target_functions: set[tuple[str, str]]) -> None:
1313 self .added_codeflash_trace = False
1414 self .class_name = ""
1515 self .function_name = ""
16- self .decorator = cst .Decorator (
17- decorator = cst .Name (value = "codeflash_trace" )
18- )
16+ self .decorator = cst .Decorator (decorator = cst .Name (value = "codeflash_trace" ))
1917
20- def leave_ClassDef (self , original_node , updated_node ):
18+ def leave_ClassDef (self , original_node , updated_node ): # noqa: ANN001, ANN201, N802
2119 if self .class_name == original_node .name .value :
22- self .class_name = "" # Even if nested classes are not visited, this function is still called on them
20+ self .class_name = "" # Even if nested classes are not visited, this function is still called on them
2321 return updated_node
2422
25- def visit_ClassDef (self , node ):
26- if self .class_name : # Don't go into nested class
23+ def visit_ClassDef (self , node ): # noqa: ANN001, ANN201, N802
24+ if self .class_name : # Don't go into nested class
2725 return False
28- self .class_name = node .name .value
26+ self .class_name = node .name .value # noqa: RET503
2927
30- def visit_FunctionDef (self , node ):
31- if self .function_name : # Don't go into nested function
28+ def visit_FunctionDef (self , node ): # noqa: ANN001, ANN201, N802
29+ if self .function_name : # Don't go into nested function
3230 return False
33- self .function_name = node .name .value
31+ self .function_name = node .name .value # noqa: RET503
3432
35- def leave_FunctionDef (self , original_node , updated_node ):
33+ def leave_FunctionDef (self , original_node , updated_node ): # noqa: ANN001, ANN201, N802
3634 if self .function_name == original_node .name .value :
3735 self .function_name = ""
3836 if (self .class_name , original_node .name .value ) in self .target_functions :
3937 # Add the new decorator after any existing decorators, so it gets executed first
40- updated_decorators = list (updated_node .decorators ) + [ self .decorator ]
38+ updated_decorators = [ * list (updated_node .decorators ), self .decorator ]
4139 self .added_codeflash_trace = True
42- return updated_node .with_changes (
43- decorators = updated_decorators
44- )
40+ return updated_node .with_changes (decorators = updated_decorators )
4541
4642 return updated_node
4743
48- def leave_Module (self , original_node : cst .Module , updated_node : cst .Module ) -> cst .Module :
44+ def leave_Module (self , original_node : cst .Module , updated_node : cst .Module ) -> cst .Module : # noqa: ARG002, N802
4945 # Create import statement for codeflash_trace
5046 if not self .added_codeflash_trace :
5147 return updated_node
5248 import_stmt = cst .SimpleStatementLine (
5349 body = [
5450 cst .ImportFrom (
5551 module = cst .Attribute (
56- value = cst .Attribute (
57- value = cst .Name (value = "codeflash" ),
58- attr = cst .Name (value = "benchmarking" )
59- ),
60- attr = cst .Name (value = "codeflash_trace" )
52+ value = cst .Attribute (value = cst .Name (value = "codeflash" ), attr = cst .Name (value = "benchmarking" )),
53+ attr = cst .Name (value = "codeflash_trace" ),
6154 ),
62- names = [
63- cst .ImportAlias (
64- name = cst .Name (value = "codeflash_trace" )
65- )
66- ]
55+ names = [cst .ImportAlias (name = cst .Name (value = "codeflash_trace" ))],
6756 )
6857 ]
6958 )
@@ -73,12 +62,13 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
7362
7463 return updated_node .with_changes (body = new_body )
7564
65+
7666def add_codeflash_decorator_to_code (code : str , functions_to_optimize : list [FunctionToOptimize ]) -> str :
7767 """Add codeflash_trace to a function.
7868
7969 Args:
8070 code: The source code as a string
81- function_to_optimize: The FunctionToOptimize instance containing function details
71+ functions_to_optimize: List of FunctionToOptimize instances containing function details
8272
8373 Returns:
8474 The modified source code as a string
@@ -91,25 +81,18 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
9181 class_name = function_to_optimize .parents [0 ].name
9282 target_functions .add ((class_name , function_to_optimize .function_name ))
9383
94- transformer = AddDecoratorTransformer (
95- target_functions = target_functions ,
96- )
84+ transformer = AddDecoratorTransformer (target_functions = target_functions )
9785
9886 module = cst .parse_module (code )
9987 modified_module = module .visit (transformer )
10088 return modified_module .code
10189
10290
103- def instrument_codeflash_trace_decorator (
104- file_to_funcs_to_optimize : dict [Path , list [FunctionToOptimize ]]
105- ) -> None :
91+ def instrument_codeflash_trace_decorator (file_to_funcs_to_optimize : dict [Path , list [FunctionToOptimize ]]) -> None :
10692 """Instrument codeflash_trace decorator to functions to optimize."""
10793 for file_path , functions_to_optimize in file_to_funcs_to_optimize .items ():
10894 original_code = file_path .read_text (encoding = "utf-8" )
109- new_code = add_codeflash_decorator_to_code (
110- original_code ,
111- functions_to_optimize
112- )
95+ new_code = add_codeflash_decorator_to_code (original_code , functions_to_optimize )
11396 # Modify the code
11497 modified_code = isort .code (code = new_code , float_to_top = True )
11598
0 commit comments