diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index c515db8b0..c6aaebfe3 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -62,6 +62,8 @@ def parse_args() -> Namespace: type=str, help="Path to the directory of the project, where all the pytest-benchmark tests are located.", ) + parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs") + args: Namespace = parser.parse_args() return process_and_validate_cmd_args(args) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index ec0c5c7d4..f1c22a9dd 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -1,6 +1,7 @@ from __future__ import annotations import ast +import json import os import tempfile import time @@ -86,6 +87,11 @@ def run(self) -> None: return if not env_utils.check_formatter_installed(self.args.formatter_cmds): return + + if self.args.no_draft and is_pr_draft(): + logger.warning("PR is in draft mode, skipping optimization") + return + function_optimizer = None file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] num_optimizable_functions: int @@ -301,3 +307,18 @@ def run_with_args(args: Namespace) -> None: optimizer.cleanup_temporary_paths() raise SystemExit from None + + +def is_pr_draft() -> bool: + """Check if the PR is draft. in the github action context.""" + try: + event_path = os.getenv("GITHUB_EVENT_PATH") + pr_number = get_pr_number() + if pr_number is not None and event_path: + with Path(event_path).open() as f: + event_data = json.load(f) + return bool(event_data["pull_request"]["draft"]) + return False # noqa + except Exception as e: + logger.warning(f"Error checking if PR is draft: {e}") + return False