Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
7278ab3
Update replay_test.py
KRRT7 Jul 30, 2025
16d21e7
reinsert
KRRT7 Jul 30, 2025
cc84732
normalize
KRRT7 Jul 30, 2025
0e90f7a
Update test_instrument_tests.py
KRRT7 Jul 30, 2025
bb5996c
Update test_instrument_tests.py
KRRT7 Jul 30, 2025
3c4573e
Merge branch 'main' into part-1-windows-fixes
KRRT7 Sep 28, 2025
a826252
Update test_code_context_extractor.py
KRRT7 Sep 28, 2025
09e9d12
update context extractor
KRRT7 Sep 28, 2025
be1ae93
normalize paths and tmp_dir here too
KRRT7 Sep 28, 2025
328ccb2
windows shenanigans
KRRT7 Sep 28, 2025
e8547f9
Update test_instrument_tests.py
KRRT7 Sep 28, 2025
84b4054
Revert "Update test_instrument_tests.py"
KRRT7 Sep 28, 2025
c9cfaac
macos symlink shenanigans
KRRT7 Sep 28, 2025
d737e32
Update test_instrument_tests.py
KRRT7 Sep 28, 2025
b168638
test runner too
KRRT7 Sep 28, 2025
dca0f40
update to include oses
KRRT7 Sep 28, 2025
ceec0ed
normalize line prof
KRRT7 Sep 28, 2025
841f55b
normalize for trace and replay tests too
KRRT7 Sep 28, 2025
c7369e9
no timeout_decorator windows
KRRT7 Sep 28, 2025
4f84fb6
Update test_instrument_tests.py
KRRT7 Sep 28, 2025
02662c6
run unit tests on windows too
KRRT7 Sep 28, 2025
a1d5381
add E2E test on windows too
KRRT7 Sep 28, 2025
59a2b65
utf-8 encoding
KRRT7 Sep 28, 2025
6c26ad1
resolve paths in test
KRRT7 Sep 28, 2025
d953a1c
fix formatting
KRRT7 Sep 28, 2025
2c504ee
path_belongs_to_site_packages code review adjust
KRRT7 Sep 28, 2025
f978a40
Merge branch 'main' into part-1-windows-fixes
KRRT7 Sep 29, 2025
d305de8
Optimize generate_candidates
codeflash-ai[bot] Sep 29, 2025
d7b67a0
Merge pull request #783 from codeflash-ai/codeflash/optimize-pr363-20…
KRRT7 Sep 29, 2025
95ccec3
windows needs the env
KRRT7 Sep 30, 2025
fe90acf
no windows for now
KRRT7 Sep 30, 2025
824d8f6
ignore LP on windows for now?
KRRT7 Sep 30, 2025
0f51ff7
Optimize InitDecorator.visit_ClassDef
codeflash-ai[bot] Sep 30, 2025
91870c0
Merge pull request #784 from codeflash-ai/codeflash/optimize-pr363-20…
KRRT7 Sep 30, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/unit-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ jobs:
uses: astral-sh/setup-uv@v5
with:
python-version: ${{ matrix.python-version }}
version: "0.5.30"

- name: install dependencies
run: uv sync
Expand Down
30 changes: 30 additions & 0 deletions .github/workflows/windows-unit-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: windows-unit-tests

on:
push:
branches: [main]
pull_request:
workflow_dispatch:

jobs:
windows-unit-tests:
continue-on-error: true
runs-on: windows-latest
env:
PYTHONIOENCODING: utf-8
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}

- name: Install uv
uses: astral-sh/setup-uv@v5
with:
python-version: "3.13"

- name: install dependencies
run: uv sync

- name: Unit tests
run: uv run pytest tests/
10 changes: 7 additions & 3 deletions codeflash/benchmarking/codeflash_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any, Callable

from codeflash.picklepatch.pickle_patcher import PicklePatcher
Expand Down Expand Up @@ -143,12 +144,13 @@ def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003, ANN401
print("Pickle limit reached")
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
normalized_file_path = Path(func.__code__.co_filename).as_posix()
self.function_calls_data.append(
(
func.__name__,
class_name,
func.__module__,
func.__code__.co_filename,
normalized_file_path,
benchmark_function_name,
benchmark_module_path,
benchmark_line_number,
Expand All @@ -169,12 +171,13 @@ def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003, ANN401
# Add to the list of function calls without pickled args. Used for timing info only
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
normalized_file_path = Path(func.__code__.co_filename).as_posix()
self.function_calls_data.append(
(
func.__name__,
class_name,
func.__module__,
func.__code__.co_filename,
normalized_file_path,
benchmark_function_name,
benchmark_module_path,
benchmark_line_number,
Expand All @@ -192,12 +195,13 @@ def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003, ANN401
# Add to the list of function calls with pickled args, to be used for replay tests
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
normalized_file_path = Path(func.__code__.co_filename).as_posix()
self.function_calls_data.append(
(
func.__name__,
class_name,
func.__module__,
func.__code__.co_filename,
normalized_file_path,
benchmark_function_name,
benchmark_module_path,
benchmark_line_number,
Expand Down
15 changes: 10 additions & 5 deletions codeflash/benchmarking/replay_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,24 @@ def get_next_arg_and_return(
cur = db.cursor()
limit = num_to_get

normalized_file_path = Path(file_path).as_posix()

if class_name is not None:
cursor = cur.execute(
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?",
(benchmark_function_name, function_name, file_path, class_name, limit),
(benchmark_function_name, function_name, normalized_file_path, class_name, limit),
)
else:
cursor = cur.execute(
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",
(benchmark_function_name, function_name, file_path, limit),
(benchmark_function_name, function_name, normalized_file_path, limit),
)

while (val := cursor.fetchone()) is not None:
yield val[9], val[10] # pickled_args, pickled_kwargs
try:
while (val := cursor.fetchone()) is not None:
yield val[9], val[10] # pickled_args, pickled_kwargs
finally:
db.close()


def get_function_alias(module: str, function_name: str) -> str:
Expand Down Expand Up @@ -166,7 +171,7 @@ def create_trace_replay_test_code(
module_name = func.get("module_name")
function_name = func.get("function_name")
class_name = func.get("class_name")
file_path = func.get("file_path")
file_path = Path(func.get("file_path")).as_posix()
benchmark_function_name = func.get("benchmark_function_name")
function_properties = func.get("function_properties")
if not class_name:
Expand Down
22 changes: 12 additions & 10 deletions codeflash/code_utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import datetime
import json
import sys
import time
import uuid
from pathlib import Path
Expand All @@ -11,13 +10,16 @@
from rich.prompt import Confirm

from codeflash.cli_cmds.console import console
from codeflash.code_utils.compat import codeflash_temp_dir

if TYPE_CHECKING:
import argparse


class CodeflashRunCheckpoint:
def __init__(self, module_root: Path, checkpoint_dir: Path = Path("/tmp")) -> None: # noqa: S108
def __init__(self, module_root: Path, checkpoint_dir: Path | None = None) -> None:
if checkpoint_dir is None:
checkpoint_dir = codeflash_temp_dir
self.module_root = module_root
self.checkpoint_dir = Path(checkpoint_dir)
# Create a unique checkpoint file name
Expand All @@ -37,7 +39,7 @@ def _initialize_checkpoint_file(self) -> None:
"last_updated": time.time(),
}

with self.checkpoint_path.open("w") as f:
with self.checkpoint_path.open("w", encoding="utf-8") as f:
f.write(json.dumps(metadata) + "\n")

def add_function_to_checkpoint(
Expand Down Expand Up @@ -66,7 +68,7 @@ def add_function_to_checkpoint(
**additional_info,
}

with self.checkpoint_path.open("a") as f:
with self.checkpoint_path.open("a", encoding="utf-8") as f:
f.write(json.dumps(function_data) + "\n")

# Update the metadata last_updated timestamp
Expand All @@ -75,7 +77,7 @@ def add_function_to_checkpoint(
def _update_metadata_timestamp(self) -> None:
"""Update the last_updated timestamp in the metadata."""
# Read the first line (metadata)
with self.checkpoint_path.open() as f:
with self.checkpoint_path.open(encoding="utf-8") as f:
metadata = json.loads(f.readline())
rest_content = f.read()

Expand All @@ -84,7 +86,7 @@ def _update_metadata_timestamp(self) -> None:

# Write all lines to a temporary file

with self.checkpoint_path.open("w") as f:
with self.checkpoint_path.open("w", encoding="utf-8") as f:
f.write(json.dumps(metadata) + "\n")
f.write(rest_content)

Expand All @@ -94,7 +96,7 @@ def cleanup(self) -> None:
self.checkpoint_path.unlink(missing_ok=True)

for file in self.checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"):
with file.open() as f:
with file.open(encoding="utf-8") as f:
# Skip the first line (metadata)
first_line = next(f)
metadata = json.loads(first_line)
Expand All @@ -116,7 +118,7 @@ def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dic
to_delete = []

for file in checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"):
with file.open() as f:
with file.open(encoding="utf-8") as f:
# Skip the first line (metadata)
first_line = next(f)
metadata = json.loads(first_line)
Expand All @@ -139,8 +141,8 @@ def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dic

def ask_should_use_checkpoint_get_functions(args: argparse.Namespace) -> Optional[dict[str, dict[str, str]]]:
previous_checkpoint_functions = None
if args.all and (sys.platform == "linux" or sys.platform == "darwin") and Path("/tmp").is_dir(): # noqa: S108 #TODO: use the temp dir from codeutils-compat.py
previous_checkpoint_functions = get_all_historical_functions(args.module_root, Path("/tmp")) # noqa: S108
if args.all and codeflash_temp_dir.is_dir():
previous_checkpoint_functions = get_all_historical_functions(args.module_root, codeflash_temp_dir)
if previous_checkpoint_functions and Confirm.ask(
"Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point?",
default=True,
Expand Down
7 changes: 4 additions & 3 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def module_name_from_file_path(file_path: Path, project_root_path: Path, *, trav
parent = file_path.parent
while parent not in (project_root_path, parent.parent):
try:
relative_path = file_path.relative_to(parent)
relative_path = file_path.resolve().relative_to(parent.resolve())
return relative_path.with_suffix("").as_posix().replace("/", ".")
except ValueError:
parent = parent.parent
Expand Down Expand Up @@ -245,8 +245,9 @@ def get_run_tmp_file(file_path: Path) -> Path:


def path_belongs_to_site_packages(file_path: Path) -> bool:
site_packages = [Path(p) for p in site.getsitepackages()]
return any(file_path.resolve().is_relative_to(site_package_path) for site_package_path in site_packages)
file_path_resolved = file_path.resolve()
site_packages = [Path(p).resolve() for p in site.getsitepackages()]
return any(file_path_resolved.is_relative_to(site_package_path) for site_package_path in site_packages)


def is_class_defined_in_file(class_name: str, file_path: Path) -> bool:
Expand Down
24 changes: 16 additions & 8 deletions codeflash/code_utils/coverage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,25 @@ def build_fully_qualified_name(function_name: str, code_context: CodeOptimizatio
def generate_candidates(source_code_path: Path) -> set[str]:
"""Generate all the possible candidates for coverage data based on the source code path."""
candidates = set()
candidates.add(source_code_path.name)
current_path = source_code_path.parent

last_added = source_code_path.name
while current_path != current_path.parent:
candidate_path = str(Path(current_path.name) / last_added)
# Add the filename as a candidate
name = source_code_path.name
candidates.add(name)

# Precompute parts for efficient candidate path construction
parts = source_code_path.parts
n = len(parts)

# Walk up the directory structure without creating Path objects or repeatedly converting to posix
last_added = name
# Start from the last parent and move up to the root, exclusive (skip the root itself)
for i in range(n - 2, 0, -1):
# Combine the ith part with the accumulated path (last_added)
candidate_path = f"{parts[i]}/{last_added}"
candidates.add(candidate_path)
last_added = candidate_path
current_path = current_path.parent

candidates.add(str(source_code_path))
# Add the absolute posix path as a candidate
candidates.add(source_code_path.as_posix())
return candidates


Expand Down
11 changes: 5 additions & 6 deletions codeflash/code_utils/env_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,17 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
if formatter_cmds[0] == "disabled":
return return_code
tmp_code = """print("hello world")"""
with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", suffix=".py") as f:
f.write(tmp_code)
f.flush()
tmp_file = Path(f.name)
with tempfile.TemporaryDirectory() as tmpdir:
tmp_file = Path(tmpdir) / "test_codeflash_formatter.py"
tmp_file.write_text(tmp_code, encoding="utf-8")
try:
format_code(formatter_cmds, tmp_file, print_status=False, exit_on_failure=exit_on_failure)
except Exception:
exit_with_message(
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again.",
error_on_exit=True,
)
return return_code
return return_code


@lru_cache(maxsize=1)
Expand Down Expand Up @@ -121,7 +120,7 @@ def get_cached_gh_event_data() -> dict[str, Any]:
event_path = os.getenv("GITHUB_EVENT_PATH")
if not event_path:
return {}
with Path(event_path).open() as f:
with Path(event_path).open(encoding="utf-8") as f:
return json.load(f) # type: ignore # noqa


Expand Down
12 changes: 9 additions & 3 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import ast
import platform
from pathlib import Path
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -143,7 +144,10 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
if node.name.startswith("test_"):
did_update = False
if self.test_framework == "unittest":
if self.test_framework == "unittest" and platform.system() != "Windows":
# Only add timeout decorator on non-Windows platforms
# Windows doesn't support SIGALRM signal required by timeout_decorator

node.decorator_list.append(
ast.Call(
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
Expand Down Expand Up @@ -220,7 +224,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None =
args=[
ast.JoinedStr(
values=[
ast.Constant(value=f"{get_run_tmp_file(Path('test_return_values_'))}"),
ast.Constant(
value=f"{get_run_tmp_file(Path('test_return_values_')).as_posix()}"
),
ast.FormattedValue(
value=ast.Name(id="codeflash_iteration", ctx=ast.Load()),
conversion=-1,
Expand Down Expand Up @@ -588,7 +594,7 @@ def inject_profiling_into_existing_test(
new_imports.extend(
[ast.Import(names=[ast.alias(name="sqlite3")]), ast.Import(names=[ast.alias(name="dill", asname="pickle")])]
)
if test_framework == "unittest":
if test_framework == "unittest" and platform.system() != "Windows":
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
additional_functions = [create_wrapper_function(mode)]

Expand Down
2 changes: 1 addition & 1 deletion codeflash/code_utils/line_profile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,6 @@ def add_decorator_imports(function_to_optimize: FunctionToOptimize, code_context
file.write(modified_code)
# Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
file_contents = function_to_optimize.file_path.read_text("utf-8")
modified_code = add_profile_enable(file_contents, str(line_profile_output_file))
modified_code = add_profile_enable(file_contents, line_profile_output_file.as_posix())
function_to_optimize.file_path.write_text(modified_code, "utf-8")
return line_profile_output_file
2 changes: 1 addition & 1 deletion codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def markdown(self) -> str:
"""
return "\n".join(
[
f"```python{':' + str(code_string.file_path) if code_string.file_path else ''}\n{code_string.code.strip()}\n```"
f"```python{':' + code_string.file_path.as_posix() if code_string.file_path else ''}\n{code_string.code.strip()}\n```"
for code_string in self.code_strings
]
)
Expand Down
10 changes: 5 additions & 5 deletions codeflash/result/create_pr.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def existing_tests_source_for(
):
print_optimized_runtime = format_time(optimized_tests_to_runtimes[filename][qualified_name])
print_original_runtime = format_time(original_tests_to_runtimes[filename][qualified_name])
print_filename = filename.relative_to(tests_root)
print_filename = filename.resolve().relative_to(tests_root.resolve()).as_posix()
greater = (
optimized_tests_to_runtimes[filename][qualified_name]
> original_tests_to_runtimes[filename][qualified_name]
Expand Down Expand Up @@ -192,9 +192,9 @@ def check_create_pr(
if pr_number is not None:
logger.info(f"Suggesting changes to PR #{pr_number} ...")
owner, repo = get_repo_owner_and_name(git_repo)
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
relative_path = explanation.file_path.resolve().relative_to(root_dir.resolve()).as_posix()
build_file_changes = {
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(
Path(p).resolve().relative_to(root_dir.resolve()).as_posix(): FileDiffContent(
oldContent=original_code[p], newContent=new_code[p]
)
for p in original_code
Expand Down Expand Up @@ -243,10 +243,10 @@ def check_create_pr(
if not check_and_push_branch(git_repo, git_remote, wait_for_push=True):
logger.warning("⏭️ Branch is not pushed, skipping PR creation...")
return
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
relative_path = explanation.file_path.resolve().relative_to(root_dir.resolve()).as_posix()
base_branch = get_current_branch()
build_file_changes = {
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(
Path(p).resolve().relative_to(root_dir.resolve()).as_posix(): FileDiffContent(
oldContent=original_code[p], newContent=new_code[p]
)
for p in original_code
Expand Down
Loading
Loading