Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions .github/workflows/pre-commit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# GitHub Actions workflow that runs the repository's pre-commit hooks as a lint check.
name: Lint
# Trigger events listed here: pull_request, and push with a branches filter naming main.
on:
pull_request:
push:
branches:
- main

# Runs are grouped by workflow name + git ref; a newer run in the same group
# cancels the one that is still in progress.
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
lint:
name: Run pre-commit hooks
runs-on: ubuntu-latest
steps:
# Check out the repository, set up Python, then run the pre-commit action.
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
# NOTE(review): the text after "pre-commit/" looks like an e-mail-protection
# placeholder that replaced the real "action@<version>" ref in this capture — confirm.
- uses: pre-commit/[email protected]
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# pre-commit configuration: both hooks come from the astral-sh/ruff-pre-commit
# repository, pinned to the revision given in `rev`.
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.11.0"
hooks:
# Lint hook: applies fixes (--fix), exits non-zero when a fix was applied
# (--exit-non-zero-on-fix), and reads its settings from pyproject.toml.
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --config=pyproject.toml]
# Formatting hook, run with its default arguments.
- id: ruff-format
6 changes: 2 additions & 4 deletions codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations

import time

import json
import os
import platform
import time
from typing import TYPE_CHECKING, Any

import requests
Expand Down Expand Up @@ -177,7 +176,7 @@ def optimize_python_code_line_profiler(

logger.info("Generating optimized candidates…")
console.rule()
if line_profiler_results=="":
if line_profiler_results == "":
logger.info("No LineProfiler results were provided, Skipping optimization.")
console.rule()
return []
Expand Down Expand Up @@ -209,7 +208,6 @@ def optimize_python_code_line_profiler(
console.rule()
return []


def log_results(
self,
function_trace_id: str,
Expand Down
55 changes: 44 additions & 11 deletions codeflash/benchmarking/codeflash_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def write_function_timings(self) -> None:
"(function_name, class_name, module_name, file_path, benchmark_function_name, "
"benchmark_module_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
self.function_calls_data
self.function_calls_data,
)
self._connection.commit()
self.function_calls_data = []
Expand Down Expand Up @@ -100,7 +100,8 @@ def __call__(self, func: Callable) -> Callable:
The wrapped function

"""
func_id = (func.__module__,func.__name__)
func_id = (func.__module__, func.__name__)

@functools.wraps(func)
def wrapper(*args, **kwargs):
# Initialize thread-local active functions set if it doesn't exist
Expand Down Expand Up @@ -139,9 +140,19 @@ def wrapper(*args, **kwargs):
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
overhead_time, None, None)
(
func.__name__,
class_name,
func.__module__,
func.__code__.co_filename,
benchmark_function_name,
benchmark_module_path,
benchmark_line_number,
execution_time,
overhead_time,
None,
None,
)
)
return result

Expand All @@ -155,9 +166,19 @@ def wrapper(*args, **kwargs):
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
overhead_time, None, None)
(
func.__name__,
class_name,
func.__module__,
func.__code__.co_filename,
benchmark_function_name,
benchmark_module_path,
benchmark_line_number,
execution_time,
overhead_time,
None,
None,
)
)
return result
# Flush to database every 100 calls
Expand All @@ -168,12 +189,24 @@ def wrapper(*args, **kwargs):
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
overhead_time, pickled_args, pickled_kwargs)
(
func.__name__,
class_name,
func.__module__,
func.__code__.co_filename,
benchmark_function_name,
benchmark_module_path,
benchmark_line_number,
execution_time,
overhead_time,
pickled_args,
pickled_kwargs,
)
)
return result

return wrapper


# Create a singleton instance
codeflash_trace = CodeflashTrace()
41 changes: 12 additions & 29 deletions codeflash/benchmarking/instrument_codeflash_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,20 @@ def __init__(self, target_functions: set[tuple[str, str]]) -> None:
self.added_codeflash_trace = False
self.class_name = ""
self.function_name = ""
self.decorator = cst.Decorator(
decorator=cst.Name(value="codeflash_trace")
)
self.decorator = cst.Decorator(decorator=cst.Name(value="codeflash_trace"))

def leave_ClassDef(self, original_node, updated_node):
if self.class_name == original_node.name.value:
self.class_name = "" # Even if nested classes are not visited, this function is still called on them
self.class_name = "" # Even if nested classes are not visited, this function is still called on them
return updated_node

def visit_ClassDef(self, node):
if self.class_name: # Don't go into nested class
if self.class_name: # Don't go into nested class
return False
self.class_name = node.name.value

def visit_FunctionDef(self, node):
if self.function_name: # Don't go into nested function
if self.function_name: # Don't go into nested function
return False
self.function_name = node.name.value

Expand All @@ -39,9 +37,7 @@ def leave_FunctionDef(self, original_node, updated_node):
# Add the new decorator after any existing decorators, so it gets executed first
updated_decorators = list(updated_node.decorators) + [self.decorator]
self.added_codeflash_trace = True
return updated_node.with_changes(
decorators=updated_decorators
)
return updated_node.with_changes(decorators=updated_decorators)

return updated_node

Expand All @@ -53,17 +49,10 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
body=[
cst.ImportFrom(
module=cst.Attribute(
value=cst.Attribute(
value=cst.Name(value="codeflash"),
attr=cst.Name(value="benchmarking")
),
attr=cst.Name(value="codeflash_trace")
value=cst.Attribute(value=cst.Name(value="codeflash"), attr=cst.Name(value="benchmarking")),
attr=cst.Name(value="codeflash_trace"),
),
names=[
cst.ImportAlias(
name=cst.Name(value="codeflash_trace")
)
]
names=[cst.ImportAlias(name=cst.Name(value="codeflash_trace"))],
)
]
)
Expand All @@ -73,6 +62,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c

return updated_node.with_changes(body=new_body)


def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[FunctionToOptimize]) -> str:
"""Add codeflash_trace to a function.

Expand All @@ -91,25 +81,18 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
class_name = function_to_optimize.parents[0].name
target_functions.add((class_name, function_to_optimize.function_name))

transformer = AddDecoratorTransformer(
target_functions = target_functions,
)
transformer = AddDecoratorTransformer(target_functions=target_functions)

module = cst.parse_module(code)
modified_module = module.visit(transformer)
return modified_module.code


def instrument_codeflash_trace_decorator(
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
) -> None:
def instrument_codeflash_trace_decorator(file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]) -> None:
"""Instrument codeflash_trace decorator to functions to optimize."""
for file_path, functions_to_optimize in file_to_funcs_to_optimize.items():
original_code = file_path.read_text(encoding="utf-8")
new_code = add_codeflash_decorator_to_code(
original_code,
functions_to_optimize
)
new_code = add_codeflash_decorator_to_code(original_code, functions_to_optimize)
# Modify the code
modified_code = isort.code(code=new_code, float_to_top=True)

Expand Down
49 changes: 30 additions & 19 deletions codeflash/benchmarking/plugin/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ class CodeFlashBenchmarkPlugin:
def __init__(self) -> None:
self._trace_path = None
self._connection = None
self._cursor = None
self.project_root = None
self.benchmark_timings = []

def setup(self, trace_path:str, project_root:str) -> None:
def setup(self, trace_path: str, project_root: str) -> None:
try:
# Open connection
self.project_root = project_root
Expand All @@ -35,7 +36,7 @@ def setup(self, trace_path:str, project_root:str) -> None:
"benchmark_time_ns INTEGER)"
)
self._connection.commit()
self.close() # Reopen only at the end of pytest session
self.close() # Reopen only at the end of pytest session
except Exception as e:
print(f"Database setup error: {e}")
if self._connection:
Expand All @@ -47,22 +48,21 @@ def write_benchmark_timings(self) -> None:
if not self.benchmark_timings:
return # No data to write

if self._connection is None:
self._connection = sqlite3.connect(self._trace_path)
self._ensure_connection()

try:
cur = self._connection.cursor()
# Insert data into the benchmark_timings table
cur.executemany(
self._cursor.executemany(
"INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
self.benchmark_timings
self.benchmark_timings,
)
self._connection.commit()
self.benchmark_timings = [] # Clear the benchmark timings list
self.benchmark_timings.clear() # Clear the benchmark timings list (reuses the list object)
except Exception as e:
print(f"Error writing to benchmark timings database: {e}")
self._connection.rollback()
raise

def close(self) -> None:
if self._connection:
self._connection.close()
Expand Down Expand Up @@ -196,12 +196,7 @@ def pytest_sessionfinish(self, session, exitstatus):

@staticmethod
def pytest_addoption(parser):
parser.addoption(
"--codeflash-trace",
action="store_true",
default=False,
help="Enable CodeFlash tracing"
)
parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")

@staticmethod
def pytest_plugin_registered(plugin, manager):
Expand All @@ -213,9 +208,9 @@ def pytest_plugin_registered(plugin, manager):
def pytest_configure(config):
"""Register the benchmark marker."""
config.addinivalue_line(
"markers",
"benchmark: mark test as a benchmark that should be run with codeflash tracing"
"markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing"
)

@staticmethod
def pytest_collection_modifyitems(config, items):
# Skip tests that don't have the benchmark fixture
Expand Down Expand Up @@ -248,16 +243,19 @@ def __call__(self, func, *args, **kwargs):
if args or kwargs:
# Used as benchmark(func, *args, **kwargs)
return self._run_benchmark(func, *args, **kwargs)

# Used as @benchmark decorator
def wrapped_func(*args, **kwargs):
return func(*args, **kwargs)

result = self._run_benchmark(func)
return wrapped_func

def _run_benchmark(self, func, *args, **kwargs):
"""Actual benchmark implementation."""
benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)),
Path(codeflash_benchmark_plugin.project_root))
benchmark_module_path = module_name_from_file_path(
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
)
benchmark_function_name = self.request.node.name
line_number = int(str(sys._getframe(2).f_lineno)) # 2 frames up in the call stack
# Set env vars
Expand All @@ -278,7 +276,8 @@ def _run_benchmark(self, func, *args, **kwargs):
codeflash_trace.function_call_count = 0
# Add to the benchmark timings buffer
codeflash_benchmark_plugin.benchmark_timings.append(
(benchmark_module_path, benchmark_function_name, line_number, end - start))
(benchmark_module_path, benchmark_function_name, line_number, end - start)
)

return result

Expand All @@ -290,4 +289,16 @@ def benchmark(request):

return CodeFlashBenchmarkPlugin.Benchmark(request)

# Lazily prepare the SQLite connection and cursor stored on this object.
# - No connection yet: connect to self._trace_path, create a cursor, apply the PRAGMAs below.
# - Connection present but no cursor: create a cursor on the existing connection.
# - Both present: nothing is done.
def _ensure_connection(self) -> None:
# Establish DB connection and optimize settings for faster inserts, if not already done
if self._connection is None:
self._connection = sqlite3.connect(self._trace_path)
self._cursor = self._connection.cursor()
# Speed up inserts by relaxing durability
# (the PRAGMAs are issued only in this branch, i.e. once per newly opened connection)
self._cursor.execute("PRAGMA synchronous = OFF")
self._cursor.execute("PRAGMA journal_mode = MEMORY")
elif self._cursor is None:
# Connection is already open; only the cursor needs to be created.
self._cursor = self._connection.cursor()


codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
Loading
Loading