Skip to content

Commit 7afd4f2

Browse files
authored
Merge branch 'main' into debug-cf
2 parents ce95abb + ec531ba commit 7afd4f2

File tree

14 files changed

+93
-68
lines changed

14 files changed

+93
-68
lines changed

.github/workflows/unit-tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ jobs:
3030
run: uv sync
3131

3232
- name: Unit tests
33-
run: uv run pytest tests/ --benchmark-skip -m "not ci_skip"
33+
run: uv run pytest tests/

codeflash/api/cfapi.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,16 @@ def mark_optimization_success(trace_id: str, *, is_optimization_found: bool) ->
261261
"""
262262
payload = {"trace_id": trace_id, "is_optimization_found": is_optimization_found}
263263
return make_cfapi_request(endpoint="/mark-as-success", method="POST", payload=payload)
264+
265+
266+
def send_completion_email() -> Response:
267+
"""Send an email notification when codeflash --all completes."""
268+
try:
269+
owner, repo = get_repo_owner_and_name()
270+
except Exception as e:
271+
sentry_sdk.capture_exception(e)
272+
response = requests.Response()
273+
response.status_code = 500
274+
return response
275+
payload = {"owner": owner, "repo": repo}
276+
return make_cfapi_request(endpoint="/send-completion-email", method="POST", payload=payload)

codeflash/benchmarking/plugin/plugin.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
from __future__ import annotations
22

3+
import importlib.util
34
import os
45
import sqlite3
56
import sys
67
import time
78
from pathlib import Path
9+
from typing import TYPE_CHECKING
810

911
import pytest
1012

1113
from codeflash.benchmarking.codeflash_trace import codeflash_trace
1214
from codeflash.code_utils.code_utils import module_name_from_file_path
13-
from codeflash.models.models import BenchmarkKey
15+
16+
if TYPE_CHECKING:
17+
from codeflash.models.models import BenchmarkKey
18+
19+
IS_PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
1420

1521

1622
# hello
@@ -72,6 +78,8 @@ def close(self) -> None:
7278

7379
@staticmethod
7480
def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[BenchmarkKey, int]]:
81+
from codeflash.models.models import BenchmarkKey
82+
7583
"""Process the trace file and extract timing data for all functions.
7684
7785
Args:
@@ -132,6 +140,8 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
132140

133141
@staticmethod
134142
def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
143+
from codeflash.models.models import BenchmarkKey
144+
135145
"""Extract total benchmark timings from trace files.
136146
137147
Args:
@@ -200,23 +210,6 @@ def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001, AR
200210
# Close the database connection
201211
self.close()
202212

203-
@staticmethod
204-
def pytest_addoption(parser: pytest.Parser) -> None:
205-
parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")
206-
207-
@staticmethod
208-
def pytest_plugin_registered(plugin, manager) -> None: # noqa: ANN001
209-
# Not necessary since run with -p no:benchmark, but just in case
210-
if hasattr(plugin, "name") and plugin.name == "pytest-benchmark":
211-
manager.unregister(plugin)
212-
213-
@staticmethod
214-
def pytest_configure(config: pytest.Config) -> None:
215-
"""Register the benchmark marker."""
216-
config.addinivalue_line(
217-
"markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing"
218-
)
219-
220213
@staticmethod
221214
def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None:
222215
# Skip tests that don't have the benchmark fixture
@@ -260,8 +253,9 @@ def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN003, ANN202
260253
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN202
261254
"""Actual benchmark implementation."""
262255
benchmark_module_path = module_name_from_file_path(
263-
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
256+
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root), traverse_up=True
264257
)
258+
265259
benchmark_function_name = self.request.node.name
266260
line_number = int(str(sys._getframe(2).f_lineno)) # 2 frames up in the call stack # noqa: SLF001
267261
# Set env vars
@@ -287,13 +281,28 @@ def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003
287281

288282
return result
289283

290-
@staticmethod
291-
@pytest.fixture
292-
def benchmark(request: pytest.FixtureRequest) -> object:
293-
if not request.config.getoption("--codeflash-trace"):
294-
return None
295284

296-
return CodeFlashBenchmarkPlugin.Benchmark(request)
285+
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
297286

298287

299-
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
288+
def pytest_configure(config: pytest.Config) -> None:
289+
"""Register the benchmark marker and disable conflicting plugins."""
290+
config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing")
291+
292+
if config.getoption("--codeflash-trace") and IS_PYTEST_BENCHMARK_INSTALLED:
293+
config.option.benchmark_disable = True
294+
config.pluginmanager.set_blocked("pytest_benchmark")
295+
config.pluginmanager.set_blocked("pytest-benchmark")
296+
297+
298+
def pytest_addoption(parser: pytest.Parser) -> None:
299+
parser.addoption(
300+
"--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks"
301+
)
302+
303+
304+
@pytest.fixture
305+
def benchmark(request: pytest.FixtureRequest) -> object:
306+
if not request.config.getoption("--codeflash-trace"):
307+
return lambda func, *args, **kwargs: func(*args, **kwargs)
308+
return codeflash_benchmark_plugin.Benchmark(request)

codeflash/cli_cmds/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
234234
"I need a git repository to run --all and open PRs for optimizations. Exiting..."
235235
)
236236
apologize_and_exit()
237-
if not args.no_pr and not check_and_push_branch(git_repo):
237+
if not args.no_pr and not check_and_push_branch(git_repo, git_remote=args.git_remote):
238238
exit_with_message("Branch is not pushed...", error_on_exit=True)
239239
owner, repo = get_repo_owner_and_name(git_repo)
240240
if not args.no_pr:

codeflash/code_utils/code_utils.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,21 @@ def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
109109
return full_qualified_name[len(module_name) + 1 :]
110110

111111

112-
def module_name_from_file_path(file_path: Path, project_root_path: Path) -> str:
113-
relative_path = file_path.relative_to(project_root_path)
114-
return relative_path.with_suffix("").as_posix().replace("/", ".")
112+
def module_name_from_file_path(file_path: Path, project_root_path: Path, *, traverse_up: bool = False) -> str:
113+
try:
114+
relative_path = file_path.relative_to(project_root_path)
115+
return relative_path.with_suffix("").as_posix().replace("/", ".")
116+
except ValueError:
117+
if traverse_up:
118+
parent = file_path.parent
119+
while parent not in (project_root_path, parent.parent):
120+
try:
121+
relative_path = file_path.relative_to(parent)
122+
return relative_path.with_suffix("").as_posix().replace("/", ".")
123+
except ValueError:
124+
parent = parent.parent
125+
msg = f"File {file_path} is not within the project root {project_root_path}."
126+
raise ValueError(msg) # noqa: B904
115127

116128

117129
def file_path_from_module_name(module_name: str, project_root_path: Path) -> Path:

codeflash/code_utils/env_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def get_cached_gh_event_data() -> dict[str, Any]:
111111

112112
def is_repo_a_fork() -> bool:
113113
event = get_cached_gh_event_data()
114-
return bool(event.get("repository", {}).get("fork", False))
114+
return bool(event.get("pull_request", {}).get("head", {}).get("repo", {}).get("fork", False))
115115

116116

117117
@lru_cache(maxsize=1)

codeflash/code_utils/git_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,12 @@ def confirm_proceeding_with_no_git_repo() -> str | bool:
117117
return True
118118

119119

120-
def check_and_push_branch(repo: git.Repo, wait_for_push: bool = False) -> bool: # noqa: FBT001, FBT002
120+
def check_and_push_branch(repo: git.Repo, git_remote: str | None = "origin", wait_for_push: bool = False) -> bool: # noqa: FBT001, FBT002
121121
current_branch = repo.active_branch.name
122-
origin = repo.remote(name="origin")
122+
remote = repo.remote(name=git_remote)
123123

124124
# Check if the branch is pushed
125-
if f"origin/{current_branch}" not in repo.refs:
125+
if f"{git_remote}/{current_branch}" not in repo.refs:
126126
logger.warning(f"⚠️ The branch '{current_branch}' is not pushed to the remote repository.")
127127
if not sys.__stdin__.isatty():
128128
logger.warning("Non-interactive shell detected. Branch will not be pushed.")
@@ -132,13 +132,13 @@ def check_and_push_branch(repo: git.Repo, wait_for_push: bool = False) -> bool:
132132
f"the branch '{current_branch}' to the remote repository?",
133133
default=False,
134134
):
135-
origin.push(current_branch)
136-
logger.info(f"⬆️ Branch '{current_branch}' has been pushed to origin.")
135+
remote.push(current_branch)
136+
logger.info(f"⬆️ Branch '{current_branch}' has been pushed to {git_remote}.")
137137
if wait_for_push:
138138
time.sleep(3) # adding this to give time for the push to register with GitHub,
139139
# so that our modifications to it are not rejected
140140
return True
141-
logger.info(f"🔘 Branch '{current_branch}' has not been pushed to origin.")
141+
logger.info(f"🔘 Branch '{current_branch}' has not been pushed to {git_remote}.")
142142
return False
143143
logger.debug(f"The branch '{current_branch}' is present in the remote repository.")
144144
return True

codeflash/optimization/optimizer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import TYPE_CHECKING
1010

1111
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
12+
from codeflash.api.cfapi import send_completion_email
1213
from codeflash.cli_cmds.console import console, logger, progress_bar
1314
from codeflash.code_utils import env_utils
1415
from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
@@ -64,6 +65,7 @@ def run_benchmarks(
6465
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
6566
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
6667

68+
console.rule()
6769
with progress_bar(
6870
f"Running benchmarks in {self.args.benchmarks_root}", transient=True, revert_to_print=bool(get_pr_number())
6971
):
@@ -342,6 +344,11 @@ def run(self) -> None:
342344
logger.info("❌ No optimizations found.")
343345
elif self.args.all:
344346
logger.info("✨ All functions have been optimized! ✨")
347+
response = send_completion_email() # TODO: Include more details in the email
348+
if response.ok:
349+
logger.info("✅ Completion email sent successfully.")
350+
else:
351+
logger.warning("⚠️ Failed to send completion email. Status")
345352
finally:
346353
if function_optimizer:
347354
function_optimizer.cleanup_generated_files()

codeflash/result/create_pr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def check_create_pr(
185185
owner, repo = get_repo_owner_and_name(git_repo, git_remote)
186186
logger.info(f"Pushing to {git_remote} - Owner: {owner}, Repo: {repo}")
187187
console.rule()
188-
if not check_and_push_branch(git_repo, wait_for_push=True):
188+
if not check_and_push_branch(git_repo, git_remote, wait_for_push=True):
189189
logger.warning("⏭️ Branch is not pushed, skipping PR creation...")
190190
return
191191
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()

codeflash/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# These version placeholders will be replaced by uv-dynamic-versioning during build.
2-
__version__ = "0.15.5"
2+
__version__ = "0.15.6"

0 commit comments

Comments
 (0)