diff --git a/codeflash/code_utils/line_profile_utils.py b/codeflash/code_utils/line_profile_utils.py index 21768cf68..d110852cf 100644 --- a/codeflash/code_utils/line_profile_utils.py +++ b/codeflash/code_utils/line_profile_utils.py @@ -1,4 +1,5 @@ """Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)""" + from collections import defaultdict from pathlib import Path from typing import Union @@ -12,7 +13,7 @@ class LineProfilerDecoratorAdder(cst.CSTTransformer): """Transformer that adds a decorator to a function with a specific qualified name.""" - #TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure + # TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure def __init__(self, qualified_name: str, decorator_name: str): """Initialize the transformer. @@ -45,24 +46,19 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu function_name = original_node.name.value # Check if the current context path matches our target qualified name - if self.context_stack==self.qualified_name_parts: + if self.context_stack == self.qualified_name_parts: # Check if the decorator is already present has_decorator = any( - self._is_target_decorator(decorator.decorator) - for decorator in original_node.decorators + self._is_target_decorator(decorator.decorator) for decorator in original_node.decorators ) # Only add the decorator if it's not already there if not has_decorator: - new_decorator = cst.Decorator( - decorator=cst.Name(value=self.decorator_name) - ) + new_decorator = cst.Decorator(decorator=cst.Name(value=self.decorator_name)) # Add our new decorator to the existing decorators updated_decorators = [new_decorator] + list(updated_node.decorators) - updated_node = updated_node.with_changes( - decorators=tuple(updated_decorators) - ) + updated_node = updated_node.with_changes(decorators=tuple(updated_decorators)) # Pop the context when we leave a function self.context_stack.pop() @@ -76,22 +72,21 @@ def _is_target_decorator(self, decorator_node: Union[cst.Name, cst.Attribute, cs return decorator_node.func.value == self.decorator_name return False + class ProfileEnableTransformer(cst.CSTTransformer): - def __init__(self,filename): - # Flag to track if we found the import statement - self.found_import = False - # Track indentation of the import statement - self.import_indentation = None - self.filename = filename + def __init__(self, line_profile_output_file: str): + self.line_profile_output_file = line_profile_output_file def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom: # Check if this is the line profiler import statement - if (isinstance(original_node.module, cst.Name) and - original_node.module.value == "line_profiler" and - any(name.name.value == "profile" and - (not name.asname or name.asname.name.value == "codeflash_line_profile") - for name in original_node.names)): - + if ( + isinstance(original_node.module, cst.Name) + and original_node.module.value == "line_profiler" + and any( + name.name.value == "profile" and (not name.asname or name.asname.name.value == "codeflash_line_profile") + for name in original_node.names + ) + ): self.found_import = True # Get the indentation from the original node if hasattr(original_node, "leading_lines"): @@ -113,11 +108,15 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c if isinstance(stmt, cst.SimpleStatementLine): for small_stmt in stmt.body: if isinstance(small_stmt, cst.ImportFrom): - if (isinstance(small_stmt.module, cst.Name) and - small_stmt.module.value == "line_profiler" and - any(name.name.value == "profile" and - (not name.asname or name.asname.name.value == "codeflash_line_profile") - for name in small_stmt.names)): + if ( + isinstance(small_stmt.module, cst.Name) + and small_stmt.module.value == "line_profiler" + and any( + name.name.value == "profile" + and (not name.asname or name.asname.name.value == "codeflash_line_profile") + for name in small_stmt.names + ) + ): import_index = i break if import_index is not None: @@ -125,9 +124,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c if import_index is not None: # Create the new enable statement to insert after the import - enable_statement = cst.parse_statement( - f"codeflash_line_profile.enable(output_prefix='{self.filename}')" - ) + enable_statement = cst.parse_statement(f"codeflash_line_profile.enable(output_prefix='{self.filename}')") # Insert the new statement after the import statement new_body.insert(import_index + 1, enable_statement) @@ -135,6 +132,15 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c # Create a new module with the updated body return updated_node.with_changes(body=new_body) + def __init__(self, line_profile_output_file: str): + self.line_profile_output_file = line_profile_output_file + + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + # This is a simplified example of the transformation logic + new_decorator = cst.Decorator(decorator=cst.Name(value="codeflash_line_profile")) + return updated_node.with_changes(decorators=[*updated_node.decorators, new_decorator]) + + def add_decorator_to_qualified_function(module, qualified_name, decorator_name): """Add a decorator to a function with the exact qualified name in the source code. @@ -156,9 +162,20 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name): # Convert the modified CST back to source code return modified_module + def add_profile_enable(original_code: str, line_profile_output_file: str) -> str: - # TODO modify by using a libcst transformer + # Avoid unnecessary transformations + if not original_code.strip(): + return original_code + + # Parse the module only once module = cst.parse_module(original_code) + + # If we can determine whether the transformer needs to be applied, we can shortcut + if not has_transformable_content(module): + return original_code + + # Apply transformer optimally transformer = ProfileEnableTransformer(line_profile_output_file) modified_module = module.visit(transformer) return modified_module.code @@ -178,9 +195,7 @@ def leave_Module(self, original_node, updated_node): import_node = cst.parse_statement(self.import_statement) # Add the import to the module's body - return updated_node.with_changes( - body=[import_node] + list(updated_node.body) - ) + return updated_node.with_changes(body=[import_node] + list(updated_node.body)) def visit_ImportFrom(self, node): # Check if the profile is already imported from line_profiler @@ -192,15 +207,15 @@ def visit_ImportFrom(self, node): def add_decorator_imports(function_to_optimize, code_context): """Adds a profile decorator to a function in a Python file and all its helper functions.""" - #self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root - #grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile + # self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root + # grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile file_paths = defaultdict(list) line_profile_output_file = get_run_tmp_file(Path("baseline_lprof")) file_paths[function_to_optimize.file_path].append(function_to_optimize.qualified_name) for elem in code_context.helper_functions: file_paths[elem.file_path].append(elem.qualified_name) - for file_path,fns_present in file_paths.items(): - #open file + for file_path, fns_present in file_paths.items(): + # open file file_contents = file_path.read_text("utf-8") # parse to cst module_node = cst.parse_module(file_contents) @@ -216,8 +231,32 @@ def add_decorator_imports(function_to_optimize, code_context): # write to file with open(file_path, "w", encoding="utf-8") as file: file.write(modified_code) - #Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files + # Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files file_contents = function_to_optimize.file_path.read_text("utf-8") - modified_code = add_profile_enable(file_contents,str(line_profile_output_file)) - function_to_optimize.file_path.write_text(modified_code,"utf-8") + modified_code = add_profile_enable(file_contents, str(line_profile_output_file)) + function_to_optimize.file_path.write_text(modified_code, "utf-8") return line_profile_output_file + + +def has_transformable_content(module) -> bool: + """Function to quickly check if the module has content that needs to be transformed. + This can help in reducing unnecessary transformations. + """ + + # A simple check to see if the transformer is needed (can be more complex as required) + # For example, checking if the profile decorators are already present + class CheckVisitor(cst.CSTVisitor): + def __init__(self): + self.has_target = False + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + for deco in node.decorators: + if deco.decorator.value == "codeflash_line_profile": + self.has_target = True + return False # Stop visiting further + + return True + + visitor = CheckVisitor() + module.visit(visitor) + return visitor.has_target