Skip to content

Commit c643183

Browse files
authored
Merge branch 'main' into part-1-windows-fixes
2 parents 02dc316 + ec531ba commit c643183

File tree

11 files changed

+85
-80
lines changed

11 files changed

+85
-80
lines changed

.github/workflows/unit-tests.yaml

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,4 @@ jobs:
3030
run: uv sync
3131

3232
- name: Unit tests
33-
run: uv run pytest tests/ --benchmark-skip -m "not ci_skip"
34-
35-
# unit-tests-windows:
36-
# runs-on: windows-latest
37-
# continue-on-error: true
38-
# steps:
39-
# - uses: actions/checkout@v4
40-
# with:
41-
# fetch-depth: 0
42-
# token: ${{ secrets.GITHUB_TOKEN }}
43-
44-
# - name: Install uv
45-
# uses: astral-sh/setup-uv@v5
46-
# with:
47-
# python-version: "3.11"
48-
49-
# - name: install dependencies
50-
# run: uv sync
51-
52-
# - name: Unit tests
53-
# run: uv run pytest tests/ --benchmark-skip -m "not ci_skip"
33+
run: uv run pytest tests/

codeflash/api/cfapi.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,16 @@ def mark_optimization_success(trace_id: str, *, is_optimization_found: bool) ->
261261
"""
262262
payload = {"trace_id": trace_id, "is_optimization_found": is_optimization_found}
263263
return make_cfapi_request(endpoint="/mark-as-success", method="POST", payload=payload)
264+
265+
266+
def send_completion_email() -> Response:
267+
"""Send an email notification when codeflash --all completes."""
268+
try:
269+
owner, repo = get_repo_owner_and_name()
270+
except Exception as e:
271+
sentry_sdk.capture_exception(e)
272+
response = requests.Response()
273+
response.status_code = 500
274+
return response
275+
payload = {"owner": owner, "repo": repo}
276+
return make_cfapi_request(endpoint="/send-completion-email", method="POST", payload=payload)

codeflash/benchmarking/plugin/plugin.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
from __future__ import annotations
22

3+
import importlib.util
34
import os
45
import sqlite3
56
import sys
67
import time
78
from pathlib import Path
9+
from typing import TYPE_CHECKING
810

911
import pytest
1012

1113
from codeflash.benchmarking.codeflash_trace import codeflash_trace
1214
from codeflash.code_utils.code_utils import module_name_from_file_path
13-
from codeflash.models.models import BenchmarkKey
15+
16+
if TYPE_CHECKING:
17+
from codeflash.models.models import BenchmarkKey
18+
19+
IS_PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
1420

1521

1622
class CodeFlashBenchmarkPlugin:
@@ -71,6 +77,8 @@ def close(self) -> None:
7177

7278
@staticmethod
7379
def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[BenchmarkKey, int]]:
80+
from codeflash.models.models import BenchmarkKey
81+
7482
"""Process the trace file and extract timing data for all functions.
7583
7684
Args:
@@ -131,6 +139,8 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
131139

132140
@staticmethod
133141
def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
142+
from codeflash.models.models import BenchmarkKey
143+
134144
"""Extract total benchmark timings from trace files.
135145
136146
Args:
@@ -199,23 +209,6 @@ def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001, AR
199209
# Close the database connection
200210
self.close()
201211

202-
@staticmethod
203-
def pytest_addoption(parser: pytest.Parser) -> None:
204-
parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")
205-
206-
@staticmethod
207-
def pytest_plugin_registered(plugin, manager) -> None: # noqa: ANN001
208-
# Not necessary since run with -p no:benchmark, but just in case
209-
if hasattr(plugin, "name") and plugin.name == "pytest-benchmark":
210-
manager.unregister(plugin)
211-
212-
@staticmethod
213-
def pytest_configure(config: pytest.Config) -> None:
214-
"""Register the benchmark marker."""
215-
config.addinivalue_line(
216-
"markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing"
217-
)
218-
219212
@staticmethod
220213
def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None:
221214
# Skip tests that don't have the benchmark fixture
@@ -259,8 +252,9 @@ def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN003, ANN202
259252
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN202
260253
"""Actual benchmark implementation."""
261254
benchmark_module_path = module_name_from_file_path(
262-
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
255+
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root), traverse_up=True
263256
)
257+
264258
benchmark_function_name = self.request.node.name
265259
line_number = int(str(sys._getframe(2).f_lineno)) # 2 frames up in the call stack # noqa: SLF001
266260
# Set env vars
@@ -286,13 +280,28 @@ def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003
286280

287281
return result
288282

289-
@staticmethod
290-
@pytest.fixture
291-
def benchmark(request: pytest.FixtureRequest) -> object:
292-
if not request.config.getoption("--codeflash-trace"):
293-
return None
294283

295-
return CodeFlashBenchmarkPlugin.Benchmark(request)
284+
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
296285

297286

298-
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
287+
def pytest_configure(config: pytest.Config) -> None:
288+
"""Register the benchmark marker and disable conflicting plugins."""
289+
config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing")
290+
291+
if config.getoption("--codeflash-trace") and IS_PYTEST_BENCHMARK_INSTALLED:
292+
config.option.benchmark_disable = True
293+
config.pluginmanager.set_blocked("pytest_benchmark")
294+
config.pluginmanager.set_blocked("pytest-benchmark")
295+
296+
297+
def pytest_addoption(parser: pytest.Parser) -> None:
298+
parser.addoption(
299+
"--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks"
300+
)
301+
302+
303+
@pytest.fixture
304+
def benchmark(request: pytest.FixtureRequest) -> object:
305+
if not request.config.getoption("--codeflash-trace"):
306+
return lambda func, *args, **kwargs: func(*args, **kwargs)
307+
return codeflash_benchmark_plugin.Benchmark(request)

codeflash/code_utils/code_utils.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,21 @@ def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
109109
return full_qualified_name[len(module_name) + 1 :]
110110

111111

112-
def module_name_from_file_path(file_path: Path, project_root_path: Path) -> str:
113-
relative_path = file_path.relative_to(project_root_path)
114-
return relative_path.with_suffix("").as_posix().replace("/", ".")
112+
def module_name_from_file_path(file_path: Path, project_root_path: Path, *, traverse_up: bool = False) -> str:
113+
try:
114+
relative_path = file_path.relative_to(project_root_path)
115+
return relative_path.with_suffix("").as_posix().replace("/", ".")
116+
except ValueError:
117+
if traverse_up:
118+
parent = file_path.parent
119+
while parent not in (project_root_path, parent.parent):
120+
try:
121+
relative_path = file_path.relative_to(parent)
122+
return relative_path.with_suffix("").as_posix().replace("/", ".")
123+
except ValueError:
124+
parent = parent.parent
125+
msg = f"File {file_path} is not within the project root {project_root_path}."
126+
raise ValueError(msg) # noqa: B904
115127

116128

117129
def file_path_from_module_name(module_name: str, project_root_path: Path) -> Path:

codeflash/code_utils/env_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def get_cached_gh_event_data() -> dict[str, Any]:
110110

111111
def is_repo_a_fork() -> bool:
112112
event = get_cached_gh_event_data()
113-
return bool(event.get("repository", {}).get("fork", False))
113+
return bool(event.get("pull_request", {}).get("head", {}).get("repo", {}).get("fork", False))
114114

115115

116116
@lru_cache(maxsize=1)

codeflash/optimization/optimizer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import TYPE_CHECKING
1010

1111
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
12+
from codeflash.api.cfapi import send_completion_email
1213
from codeflash.cli_cmds.console import console, logger, progress_bar
1314
from codeflash.code_utils import env_utils
1415
from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
@@ -64,6 +65,7 @@ def run_benchmarks(
6465
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
6566
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
6667

68+
console.rule()
6769
with progress_bar(
6870
f"Running benchmarks in {self.args.benchmarks_root}", transient=True, revert_to_print=bool(get_pr_number())
6971
):
@@ -342,6 +344,11 @@ def run(self) -> None:
342344
logger.info("❌ No optimizations found.")
343345
elif self.args.all:
344346
logger.info("✨ All functions have been optimized! ✨")
347+
response = send_completion_email() # TODO: Include more details in the email
348+
if response.ok:
349+
logger.info("✅ Completion email sent successfully.")
350+
else:
351+
logger.warning("⚠️ Failed to send completion email. Status")
345352
finally:
346353
if function_optimizer:
347354
function_optimizer.cleanup_generated_files()

codeflash/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# These version placeholders will be replaced by uv-dynamic-versioning during build.
2-
__version__ = "0.15.5"
2+
__version__ = "0.15.6"

docs/docs/optimizing-with-codeflash/trace-and-optimize.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ To optimize code called by pytest tests that you could normally run like `python
1717
codeflash optimize -m pytest tests/
1818
```
1919

20-
This powerful command creates high-quality optimizations, making it ideal when you need to optimize a workflow or script. The initial tracing process can be slow, so try to limit your script's runtime to under 1 minute for best results. If your workflow is longer, consider tracing it into smaller sections by using the Codeflash tracer as a context manager (point 3 below).
20+
This powerful command creates high-quality optimizations, making it ideal when you need to optimize a workflow or script. The initial tracing process can be slow, so try to limit your script's runtime to under 1 minute for best results. If your workflow is longer, consider tracing it into smaller sections by using the Codeflash tracer as a context manager (point 3 below).
21+
22+
The generated replay tests and the trace file are for the immediate optimization use, don't add them to git.
2123

2224
## What is the codeflash optimize command?
2325

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ dev = [
6767
"types-openpyxl>=3.1.5.20241020",
6868
"types-regex>=2024.9.11.20240912",
6969
"types-python-dateutil>=2.9.0.20241003",
70-
"pytest-benchmark>=5.1.0",
7170
"types-gevent>=24.11.0.20241230,<25",
7271
"types-greenlet>=3.1.0.20241221,<4",
7372
"types-pexpect>=4.9.0.20241208,<5",
@@ -300,3 +299,6 @@ markers = [
300299
[build-system]
301300
requires = ["hatchling", "uv-dynamic-versioning"]
302301
build-backend = "hatchling.build"
302+
303+
[project.entry-points.pytest11]
304+
codeflash = "codeflash.benchmarking.plugin.plugin"

tests/test_trace_benchmarks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_trace_benchmarks() -> None:
3333
function_calls = cursor.fetchall()
3434

3535
# Assert the length of function calls
36-
assert len(function_calls) == 8, f"Expected 8 function calls, but got {len(function_calls)}"
36+
assert len(function_calls) == 7, f"Expected 7 function calls, but got {len(function_calls)}"
3737

3838
bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
3939
process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()

0 commit comments

Comments
 (0)