Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 52 additions & 58 deletions codeflash/code_utils/line_profile_utils.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,29 @@
"""Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)"""

from collections import defaultdict
from pathlib import Path
from typing import Union

import isort
import libcst as cst
from pathlib import Path
from typing import Union, List
from libcst import ImportFrom, ImportAlias, Name

from codeflash.code_utils.code_utils import get_run_tmp_file


class LineProfilerDecoratorAdder(cst.CSTTransformer):
"""Transformer that adds a decorator to a function with a specific qualified name."""
#Todo we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure

def __init__(self, qualified_name: str, decorator_name: str):
"""
Initialize the transformer.
"""Initialize the transformer.

Args:
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.target_func").
decorator_name: The name of the decorator to add.

"""
super().__init__()
self.qualified_name_parts = qualified_name.split(".")
self.decorator_name = decorator_name

# Track our current context path, only add when we encounter a class
self.context_stack = []

def visit_ClassDef(self, node: cst.ClassDef) -> None:
Expand All @@ -48,47 +46,36 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
if self._matches_qualified_path():
# Check if the decorator is already present
has_decorator = any(
self._is_target_decorator(decorator.decorator)
for decorator in original_node.decorators
self._is_target_decorator(decorator.decorator) for decorator in original_node.decorators
)

# Only add the decorator if it's not already there
if not has_decorator:
new_decorator = cst.Decorator(
decorator=cst.Name(value=self.decorator_name)
)
new_decorator = cst.Decorator(decorator=cst.Name(value=self.decorator_name))

# Add our new decorator to the existing decorators
updated_decorators = [new_decorator] + list(updated_node.decorators)
updated_node = updated_node.with_changes(
decorators=tuple(updated_decorators)
)
updated_node = updated_node.with_changes(decorators=tuple(updated_decorators))

# Pop the context when we leave a function
self.context_stack.pop()
return updated_node

def _matches_qualified_path(self) -> bool:
"""Check if the current context stack matches the qualified name."""
if len(self.context_stack) != len(self.qualified_name_parts):
return False

for i, name in enumerate(self.qualified_name_parts):
if self.context_stack[i] != name:
return False

return True
return self.context_stack == self.qualified_name_parts

def _is_target_decorator(self, decorator_node: Union[cst.Name, cst.Attribute, cst.Call]) -> bool:
"""Check if a decorator matches our target decorator name."""
if isinstance(decorator_node, cst.Name):
return decorator_node.value == self.decorator_name
elif isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
return decorator_node.func.value == self.decorator_name
return False


class ProfileEnableTransformer(cst.CSTTransformer):
def __init__(self,filename):
def __init__(self, filename):
# Flag to track if we found the import statement
self.found_import = False
# Track indentation of the import statement
Expand All @@ -97,12 +84,14 @@ def __init__(self,filename):

def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
# Check if this is the line profiler import statement
if (isinstance(original_node.module, cst.Name) and
original_node.module.value == "line_profiler" and
any(name.name.value == "profile" and
(not name.asname or name.asname.name.value == "codeflash_line_profile")
for name in original_node.names)):

if (
isinstance(original_node.module, cst.Name)
and original_node.module.value == "line_profiler"
and any(
name.name.value == "profile" and (not name.asname or name.asname.name.value == "codeflash_line_profile")
for name in original_node.names
)
):
self.found_import = True
# Get the indentation from the original node
if hasattr(original_node, "leading_lines"):
Expand All @@ -124,31 +113,33 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
if isinstance(stmt, cst.SimpleStatementLine):
for small_stmt in stmt.body:
if isinstance(small_stmt, cst.ImportFrom):
if (isinstance(small_stmt.module, cst.Name) and
small_stmt.module.value == "line_profiler" and
any(name.name.value == "profile" and
(not name.asname or name.asname.name.value == "codeflash_line_profile")
for name in small_stmt.names)):
if (
isinstance(small_stmt.module, cst.Name)
and small_stmt.module.value == "line_profiler"
and any(
name.name.value == "profile"
and (not name.asname or name.asname.name.value == "codeflash_line_profile")
for name in small_stmt.names
)
):
import_index = i
break
if import_index is not None:
break

if import_index is not None:
# Create the new enable statement to insert after the import
enable_statement = cst.parse_statement(
f"codeflash_line_profile.enable(output_prefix='{self.filename}')"
)
enable_statement = cst.parse_statement(f"codeflash_line_profile.enable(output_prefix='{self.filename}')")

# Insert the new statement after the import statement
new_body.insert(import_index + 1, enable_statement)

# Create a new module with the updated body
return updated_node.with_changes(body=new_body)


def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
"""
Add a decorator to a function with the exact qualified name in the source code.
"""Add a decorator to a function with the exact qualified name in the source code.

Args:
module: The Python source code as a string.
Expand All @@ -157,6 +148,7 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name):

Returns:
The modified source code as a string.

"""
# Parse the source code into a CST

Expand All @@ -167,8 +159,9 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
# Convert the modified CST back to source code
return modified_module


def add_profile_enable(original_code: str, line_profile_output_file: str) -> str:
# todo modify by using a libcst transformer
# TODO modify by using a libcst transformer
module = cst.parse_module(original_code)
transformer = ProfileEnableTransformer(line_profile_output_file)
modified_module = module.visit(transformer)
Expand All @@ -189,9 +182,7 @@ def leave_Module(self, original_node, updated_node):
import_node = cst.parse_statement(self.import_statement)

# Add the import to the module's body
return updated_node.with_changes(
body=[import_node] + list(updated_node.body)
)
return updated_node.with_changes(body=[import_node] + list(updated_node.body))

def visit_ImportFrom(self, node):
# Check if the profile is already imported from line_profiler
Expand All @@ -203,21 +194,22 @@ def visit_ImportFrom(self, node):

def add_decorator_imports(function_to_optimize, code_context):
"""Adds a profile decorator to a function in a Python file and all its helper functions."""
#self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
#grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile
# self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
# grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile
file_paths = defaultdict(list)
line_profile_output_file = get_run_tmp_file(Path("baseline_lprof"))
file_paths[function_to_optimize.file_path].append(function_to_optimize.qualified_name)
for elem in code_context.helper_functions:
file_paths[elem.file_path].append(elem.qualified_name)
for file_path,fns_present in file_paths.items():
#open file
file_contents = file_path.read_text("utf-8")
for file_path, fns_present in file_paths.items():
# open file
with open(file_path, encoding="utf-8") as file:
file_contents = file.read()
# parse to cst
module_node = cst.parse_module(file_contents)
for fn_name in fns_present:
# add decorator
module_node = add_decorator_to_qualified_function(module_node, fn_name, 'codeflash_line_profile')
module_node = add_decorator_to_qualified_function(module_node, fn_name, "codeflash_line_profile")
# add imports
# Create a transformer to add the import
transformer = ImportAdder("from line_profiler import profile as codeflash_line_profile")
Expand All @@ -227,8 +219,10 @@ def add_decorator_imports(function_to_optimize, code_context):
# write to file
with open(file_path, "w", encoding="utf-8") as file:
file.write(modified_code)
#Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
file_contents = function_to_optimize.file_path.read_text("utf-8")
modified_code = add_profile_enable(file_contents,str(line_profile_output_file))
function_to_optimize.file_path.write_text(modified_code,"utf-8")
return line_profile_output_file
# Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
with open(function_to_optimize.file_path) as f:
file_contents = f.read()
modified_code = add_profile_enable(file_contents, str(line_profile_output_file))
with open(function_to_optimize.file_path, "w") as f:
f.write(modified_code)
return line_profile_output_file
Loading