Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 29 additions & 21 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,12 @@
from argparse import SUPPRESS, ArgumentParser, Namespace
from pathlib import Path

import git

from codeflash.cli_cmds import logging_config
from codeflash.cli_cmds.cli_common import apologize_and_exit
from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
from codeflash.cli_cmds.console import logger
from codeflash.code_utils import env_utils
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.git_utils import (
check_and_push_branch,
check_running_in_git_repo,
confirm_proceeding_with_no_git_repo,
get_repo_owner_and_name,
)
from codeflash.code_utils.github_utils import get_github_secrets_page_url, require_github_app_or_exit
from codeflash.version import __version__ as version


Expand Down Expand Up @@ -75,6 +66,13 @@ def parse_args() -> Namespace:


def process_and_validate_cmd_args(args: Namespace) -> Namespace:
from codeflash.code_utils.git_utils import (
check_running_in_git_repo,
confirm_proceeding_with_no_git_repo,
get_repo_owner_and_name,
)
from codeflash.code_utils.github_utils import require_github_app_or_exit

is_init: bool = args.command.startswith("init") if args.command else False
if args.verbose:
logging_config.set_level(logging.DEBUG, echo_setting=not is_init)
Expand Down Expand Up @@ -144,21 +142,26 @@ def process_pyproject_config(args: Namespace) -> Namespace:
assert Path(args.benchmarks_root).resolve().is_relative_to(Path(args.tests_root).resolve()), (
f"--benchmarks-root {args.benchmarks_root} must be a subdirectory of --tests-root {args.tests_root}"
)
if env_utils.get_pr_number() is not None:
assert env_utils.ensure_codeflash_api_key(), (
"Codeflash API key not found. When running in a Github Actions Context, provide the "
"'CODEFLASH_API_KEY' environment variable as a secret.\n"
"You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n"
"Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n"
f"Here's a direct link: {get_github_secrets_page_url()}\n"
"Exiting..."
)
if env_utils.get_pr_number() is not None:
import git

from codeflash.code_utils.git_utils import get_repo_owner_and_name
from codeflash.code_utils.github_utils import get_github_secrets_page_url, require_github_app_or_exit

assert env_utils.ensure_codeflash_api_key(), (
"Codeflash API key not found. When running in a Github Actions Context, provide the "
"'CODEFLASH_API_KEY' environment variable as a secret.\n"
"You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n"
"Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n"
f"Here's a direct link: {get_github_secrets_page_url()}\n"
"Exiting..."
)

repo = git.Repo(search_parent_directories=True)
repo = git.Repo(search_parent_directories=True)

owner, repo_name = get_repo_owner_and_name(repo)
owner, repo_name = get_repo_owner_and_name(repo)

require_github_app_or_exit(owner, repo_name)
require_github_app_or_exit(owner, repo_name)

if hasattr(args, "ignore_paths") and args.ignore_paths is not None:
normalized_ignore_paths = []
Expand Down Expand Up @@ -187,6 +190,11 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path)

def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
if hasattr(args, "all"):
import git

from codeflash.code_utils.git_utils import check_and_push_branch, get_repo_owner_and_name
from codeflash.code_utils.github_utils import require_github_app_or_exit

# Ensure that the user can actually open PRs on the repo.
try:
git_repo = git.Repo(search_parent_directories=True)
Expand Down
3 changes: 2 additions & 1 deletion codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import TYPE_CHECKING, Optional

import libcst as cst
import libcst.matchers as m
from libcst.codemod import CodemodContext
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
from libcst.helpers import calculate_module_and_package
Expand Down Expand Up @@ -248,6 +247,8 @@ class FutureAliasedImportTransformer(cst.CSTTransformer):
def leave_ImportFrom(
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
) -> cst.BaseSmallStatement | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel:
import libcst.matchers as m

if (
(updated_node_module := updated_node.module)
and updated_node_module.value == "__future__"
Expand Down
10 changes: 8 additions & 2 deletions codeflash/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
from collections import defaultdict
from itertools import chain
from pathlib import Path # noqa: TC003
from typing import TYPE_CHECKING

import jedi
import libcst as cst
from jedi.api.classes import Name # noqa: TC002
from libcst import CSTNode # noqa: TC002

from codeflash.cli_cmds.console import logger
Expand All @@ -24,6 +23,9 @@
)
from codeflash.optimization.function_context import belongs_to_function_qualified

if TYPE_CHECKING:
from jedi.api.classes import Name


def get_code_optimization_context(
function_to_optimize: FunctionToOptimize,
Expand Down Expand Up @@ -354,6 +356,8 @@ def extract_code_markdown_context_from_files(
def get_function_to_optimize_as_function_source(
function_to_optimize: FunctionToOptimize, project_root_path: Path
) -> FunctionSource:
import jedi

# Use jedi to find function to optimize
script = jedi.Script(path=function_to_optimize.file_path, project=jedi.Project(path=project_root_path))

Expand Down Expand Up @@ -389,6 +393,8 @@ def get_function_to_optimize_as_function_source(
def get_function_sources_from_jedi(
file_path_to_qualified_function_names: dict[Path, set[str]], project_root_path: Path
) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]:
import jedi

file_path_to_function_source = defaultdict(set)
function_source_list: list[FunctionSource] = []
for file_path, qualified_function_names in file_path_to_qualified_function_names.items():
Expand Down
3 changes: 2 additions & 1 deletion codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional

import jedi
import pytest
from pydantic.dataclasses import dataclass

Expand Down Expand Up @@ -281,6 +280,8 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
def process_test_files(
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
) -> dict[str, list[FunctionCalledInTest]]:
import jedi

project_root_path = cfg.project_root_path
test_framework = cfg.test_framework

Expand Down
4 changes: 3 additions & 1 deletion codeflash/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from codeflash.cli_cmds.console import paneled_text
from codeflash.code_utils.checkpoint import ask_should_use_checkpoint_get_functions
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.optimization import optimizer
from codeflash.telemetry import posthog_cf
from codeflash.telemetry.sentry import init_sentry

Expand Down Expand Up @@ -41,6 +40,9 @@ def main() -> None:
args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args)
init_sentry(not args.disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(not args.disable_telemetry)

from codeflash.optimization import optimizer

optimizer.run_with_args(args)


Expand Down
32 changes: 20 additions & 12 deletions codeflash/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,20 @@
from typing import TYPE_CHECKING

from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator
from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin
from codeflash.benchmarking.replay_test import generate_replay_test
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
from codeflash.cli_cmds.console import console, logger, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
from codeflash.code_utils.code_utils import cleanup_paths
from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
from codeflash.either import is_successful
from codeflash.models.models import ValidCode
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.verification_utils import TestConfig

if TYPE_CHECKING:
from argparse import Namespace

from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import BenchmarkKey, FunctionCalledInTest
from codeflash.optimization.function_optimizer import FunctionOptimizer


class Optimizer:
Expand Down Expand Up @@ -63,6 +53,8 @@ def create_function_optimizer(
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
total_benchmark_timings: dict[BenchmarkKey, float] | None = None,
) -> FunctionOptimizer:
from codeflash.optimization.function_optimizer import FunctionOptimizer

return FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=self.test_cfg,
Expand All @@ -77,6 +69,16 @@ def create_function_optimizer(
)

def run(self) -> None:
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
from codeflash.code_utils.code_utils import cleanup_paths
from codeflash.code_utils.static_analysis import (
analyze_imported_modules,
get_first_top_level_function_or_method_ast,
)
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize

ph("cli-optimize-run-start")
logger.info("Running optimizer.")
console.rule()
Expand All @@ -102,6 +104,12 @@ def run(self) -> None:
function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {}
total_benchmark_timings: dict[BenchmarkKey, int] = {}
if self.args.benchmark and num_optimizable_functions > 0:
from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator
from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin
from codeflash.benchmarking.replay_test import generate_replay_test
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table

with progress_bar(f"Running benchmarks in {self.args.benchmarks_root}", transient=True):
# Insert decorator
file_path_to_source_code = defaultdict(str)
Expand Down
Loading