Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions .github/workflows/unit-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,4 @@ jobs:
run: uvx poetry install --with dev

- name: Unit tests
run: uvx poetry run pytest tests/ --cov --cov-report=xml --benchmark-skip -m "not ci_skip"

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
if: matrix.python-version == '3.12.1'
with:
token: ${{ secrets.CODECOV_TOKEN }}
run: uvx poetry run pytest tests/ --benchmark-skip -m "not ci_skip"
9 changes: 8 additions & 1 deletion codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,14 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list


def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
    """Report whether *function_node* contains a ``return`` statement anywhere in its body.

    Performs a depth-first search over every descendant node (including nodes
    inside nested function definitions, matching ``ast.walk`` semantics) and
    short-circuits as soon as the first ``ast.Return`` node is encountered.
    """

    def _contains_return(node: ast.AST) -> bool:
        # Short-circuit on the first Return; otherwise recurse into children.
        if isinstance(node, ast.Return):
            return True
        return any(_contains_return(child) for child in ast.iter_child_nodes(node))

    return _contains_return(function_node)


def function_is_a_property(function_node: FunctionDef | AsyncFunctionDef) -> bool:
Expand Down
214 changes: 137 additions & 77 deletions codeflash/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#
from __future__ import annotations

import contextlib
import importlib.machinery
import io
import json
Expand Down Expand Up @@ -99,6 +100,7 @@ def __init__(
)
disable = True
self.disable = disable
self._db_lock: threading.Lock | None = None
if self.disable:
return
if sys.getprofile() is not None or sys.gettrace() is not None:
Expand All @@ -108,6 +110,9 @@ def __init__(
)
self.disable = True
return

self._db_lock = threading.Lock()

self.con = None
self.output_file = Path(output).resolve()
self.functions = functions
Expand All @@ -130,6 +135,7 @@ def __init__(
self.timeout = timeout
self.next_insert = 1000
self.trace_count = 0
self.path_cache = {} # Cache for resolved file paths

# Profiler variables
self.bias = 0 # calibration constant
Expand Down Expand Up @@ -178,34 +184,55 @@ def __enter__(self) -> None:
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
if self.disable:
if self.disable or self._db_lock is None:
return
sys.setprofile(None)
self.con.commit()
console.rule("Codeflash: Traced Program Output End", style="bold blue")
self.create_stats()
threading.setprofile(None)

cur = self.con.cursor()
cur.execute(
"CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
"call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
"cumulative_time_ns INTEGER, callers BLOB)"
)
for func, (cc, nc, tt, ct, callers) in self.stats.items():
remapped_callers = [{"key": k, "value": v} for k, v in callers.items()]
with self._db_lock:
if self.con is None:
return

self.con.commit() # Commit any pending from tracer_logic
console.rule("Codeflash: Traced Program Output End", style="bold blue")
self.create_stats() # This calls snapshot_stats which uses self.timings

cur = self.con.cursor()
cur.execute(
"INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(str(Path(func[0]).resolve()), func[1], func[2], func[3], cc, nc, tt, ct, json.dumps(remapped_callers)),
"CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
"call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
"cumulative_time_ns INTEGER, callers BLOB)"
)
self.con.commit()
# self.stats is populated by snapshot_stats() called within create_stats()
# Ensure self.stats is accessed after create_stats() and within the lock if it involves DB data
# For now, assuming self.stats is primarily in-memory after create_stats()
for func, (cc, nc, tt, ct, callers) in self.stats.items():
remapped_callers = [{"key": k, "value": v} for k, v in callers.items()]
cur.execute(
"INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
str(Path(func[0]).resolve()),
func[1],
func[2],
func[3],
cc,
nc,
tt,
ct,
json.dumps(remapped_callers),
),
)
self.con.commit()

self.make_pstats_compatible()
self.print_stats("tottime")
cur = self.con.cursor()
cur.execute("CREATE TABLE total_time (time_ns INTEGER)")
cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,))
self.con.commit()
self.con.close()
self.make_pstats_compatible() # Modifies self.stats and self.timings in-memory
self.print_stats("tottime") # Uses self.stats, prints to console

cur = self.con.cursor() # New cursor
cur.execute("CREATE TABLE total_time (time_ns INTEGER)")
cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,))
self.con.commit()
self.con.close()
self.con = None # Mark connection as closed

# filter any functions where we did not capture the return
self.function_modules = [
Expand Down Expand Up @@ -245,18 +272,29 @@ def __exit__(
def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
if event != "call":
return
if self.timeout is not None and (time.time() - self.start_time) > self.timeout:
if None is not self.timeout and (time.time() - self.start_time) > self.timeout:
sys.setprofile(None)
threading.setprofile(None)
console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.")
return
code = frame.f_code
if self.disable or self._db_lock is None or self.con is None:
return

file_name = Path(code.co_filename).resolve()
# TODO : It currently doesn't log the last return call from the first function
code = frame.f_code

# Check function name first before resolving path
if code.co_name in self.ignored_functions:
return

# Now resolve file path only if we need it
co_filename = code.co_filename
if co_filename in self.path_cache:
file_name = self.path_cache[co_filename]
else:
file_name = Path(co_filename).resolve()
self.path_cache[co_filename] = file_name
# TODO : It currently doesn't log the last return call from the first function

if not file_name.is_relative_to(self.project_root):
return
if not file_name.exists():
Expand All @@ -266,18 +304,29 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
class_name = None
arguments = frame.f_locals
try:
if (
"self" in arguments
and hasattr(arguments["self"], "__class__")
and hasattr(arguments["self"].__class__, "__name__")
):
class_name = arguments["self"].__class__.__name__
elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
class_name = arguments["cls"].__name__
self_arg = arguments.get("self")
if self_arg is not None:
try:
class_name = self_arg.__class__.__name__
except AttributeError:
cls_arg = arguments.get("cls")
if cls_arg is not None:
with contextlib.suppress(AttributeError):
class_name = cls_arg.__name__
else:
cls_arg = arguments.get("cls")
if cls_arg is not None:
with contextlib.suppress(AttributeError):
class_name = cls_arg.__name__
except: # noqa: E722
# someone can override the getattr method and raise an exception. I'm looking at you wrapt
return
function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"

try:
function_qualified_name = f"{file_name}:{code.co_qualname}"
except AttributeError:
function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"

if function_qualified_name in self.ignored_qualified_functions:
return
if function_qualified_name not in self.function_count:
Expand Down Expand Up @@ -310,49 +359,59 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911

# TODO: Also check if this function arguments are unique from the values logged earlier

cur = self.con.cursor()
with self._db_lock:
# Check connection again inside lock, in case __exit__ closed it.
if self.con is None:
return

t_ns = time.perf_counter_ns()
original_recursion_limit = sys.getrecursionlimit()
try:
# pickling can be a recursive operator, so we need to increase the recursion limit
sys.setrecursionlimit(10000)
# We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
# directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
# leaks, bad references or side effects when unpickling.
arguments = dict(arguments.items())
if class_name and code.co_name == "__init__":
del arguments["self"]
local_vars = pickle.dumps(arguments, protocol=pickle.HIGHEST_PROTOCOL)
sys.setrecursionlimit(original_recursion_limit)
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
# we retry with dill if pickle fails. It's slower but more comprehensive
cur = self.con.cursor()

t_ns = time.perf_counter_ns()
original_recursion_limit = sys.getrecursionlimit()
try:
local_vars = dill.dumps(arguments, protocol=dill.HIGHEST_PROTOCOL)
# pickling can be a recursive operator, so we need to increase the recursion limit
sys.setrecursionlimit(10000)
# We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
# directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
# leaks, bad references or side effects when unpickling.
arguments_copy = dict(arguments.items()) # Use the local 'arguments' from frame.f_locals
if class_name and code.co_name == "__init__" and "self" in arguments_copy:
del arguments_copy["self"]
local_vars = pickle.dumps(arguments_copy, protocol=pickle.HIGHEST_PROTOCOL)
sys.setrecursionlimit(original_recursion_limit)
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
# we retry with dill if pickle fails. It's slower but more comprehensive
try:
sys.setrecursionlimit(10000) # Ensure limit is high for dill too
# arguments_copy should be used here as well if defined above
local_vars = dill.dumps(
arguments_copy if "arguments_copy" in locals() else dict(arguments.items()),
protocol=dill.HIGHEST_PROTOCOL,
)
sys.setrecursionlimit(original_recursion_limit)

except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError):
self.function_count[function_qualified_name] -= 1
return

except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError):
# give up
self.function_count[function_qualified_name] -= 1
return
cur.execute(
"INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
(
event,
code.co_name,
class_name,
str(file_name),
frame.f_lineno,
frame.f_back.__hash__(),
t_ns,
local_vars,
),
)
self.trace_count += 1
self.next_insert -= 1
if self.next_insert == 0:
self.next_insert = 1000
self.con.commit()
cur.execute(
"INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
(
event,
code.co_name,
class_name,
str(file_name),
frame.f_lineno,
frame.f_back.__hash__(),
t_ns,
local_vars,
),
)
self.trace_count += 1
self.next_insert -= 1
if self.next_insert == 0:
self.next_insert = 1000
self.con.commit()

def trace_callback(self, frame: FrameType, event: str, arg: str | None) -> None:
# profiler section
Expand Down Expand Up @@ -475,8 +534,9 @@ def trace_dispatch_return(self, frame: FrameType, t: int) -> int:
cc = cc + 1

if pfn in callers:
callers[pfn] = callers[pfn] + 1 # TODO: gather more
# stats such as the amount of time added to ct courtesy
# Increment call count between these functions
callers[pfn] = callers[pfn] + 1
# Note: This tracks stats such as the amount of time added to ct
# of this specific call, and the contribution to cc
# courtesy of this call.
else:
Expand Down Expand Up @@ -703,7 +763,7 @@ def create_stats(self) -> None:

def snapshot_stats(self) -> None:
self.stats = {}
for func, (cc, _ns, tt, ct, caller_dict) in self.timings.items():
for func, (cc, _ns, tt, ct, caller_dict) in list(self.timings.items()):
callers = caller_dict.copy()
nc = 0
for callcnt in callers.values():
Expand Down
21 changes: 1 addition & 20 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ types-cffi = ">=1.16.0.20240331"
types-openpyxl = ">=3.1.5.20241020"
types-regex = ">=2024.9.11.20240912"
types-python-dateutil = ">=2.9.0.20241003"
pytest-cov = "^6.0.0"
pytest-benchmark = ">=5.1.0"
types-gevent = "^24.11.0.20241230"
types-greenlet = "^3.1.0.20241221"
Expand Down
Loading
Loading