diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index cce7208da..3a6f7dba2 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -3,21 +3,12 @@ from argparse import SUPPRESS, ArgumentParser, Namespace from pathlib import Path -import git - from codeflash.cli_cmds import logging_config from codeflash.cli_cmds.cli_common import apologize_and_exit from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions from codeflash.cli_cmds.console import logger from codeflash.code_utils import env_utils from codeflash.code_utils.config_parser import parse_config_file -from codeflash.code_utils.git_utils import ( - check_and_push_branch, - check_running_in_git_repo, - confirm_proceeding_with_no_git_repo, - get_repo_owner_and_name, -) -from codeflash.code_utils.github_utils import get_github_secrets_page_url, require_github_app_or_exit from codeflash.version import __version__ as version @@ -75,6 +66,13 @@ def parse_args() -> Namespace: def process_and_validate_cmd_args(args: Namespace) -> Namespace: + from codeflash.code_utils.git_utils import ( + check_running_in_git_repo, + confirm_proceeding_with_no_git_repo, + get_repo_owner_and_name, + ) + from codeflash.code_utils.github_utils import require_github_app_or_exit + is_init: bool = args.command.startswith("init") if args.command else False if args.verbose: logging_config.set_level(logging.DEBUG, echo_setting=not is_init) @@ -144,21 +142,26 @@ def process_pyproject_config(args: Namespace) -> Namespace: assert Path(args.benchmarks_root).resolve().is_relative_to(Path(args.tests_root).resolve()), ( f"--benchmarks-root {args.benchmarks_root} must be a subdirectory of --tests-root {args.tests_root}" ) - if env_utils.get_pr_number() is not None: - assert env_utils.ensure_codeflash_api_key(), ( - "Codeflash API key not found. When running in a Github Actions Context, provide the " - "'CODEFLASH_API_KEY' environment variable as a secret.\n" - "You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n" - "Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n" - f"Here's a direct link: {get_github_secrets_page_url()}\n" - "Exiting..." - ) + if env_utils.get_pr_number() is not None: + import git + + from codeflash.code_utils.git_utils import get_repo_owner_and_name + from codeflash.code_utils.github_utils import get_github_secrets_page_url, require_github_app_or_exit + + assert env_utils.ensure_codeflash_api_key(), ( + "Codeflash API key not found. When running in a Github Actions Context, provide the " + "'CODEFLASH_API_KEY' environment variable as a secret.\n" + "You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n" + "Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n" + f"Here's a direct link: {get_github_secrets_page_url()}\n" + "Exiting..." + ) - repo = git.Repo(search_parent_directories=True) + repo = git.Repo(search_parent_directories=True) - owner, repo_name = get_repo_owner_and_name(repo) + owner, repo_name = get_repo_owner_and_name(repo) - require_github_app_or_exit(owner, repo_name) + require_github_app_or_exit(owner, repo_name) if hasattr(args, "ignore_paths") and args.ignore_paths is not None: normalized_ignore_paths = [] @@ -187,6 +190,11 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path) def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace: if hasattr(args, "all"): + import git + + from codeflash.code_utils.git_utils import check_and_push_branch, get_repo_owner_and_name + from codeflash.code_utils.github_utils import require_github_app_or_exit + # Ensure that the user can actually open PRs on the repo. try: git_repo = git.Repo(search_parent_directories=True) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 8a2a89e95..0dcc2357f 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Optional import libcst as cst -import libcst.matchers as m from libcst.codemod import CodemodContext from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor from libcst.helpers import calculate_module_and_package @@ -248,6 +247,8 @@ class FutureAliasedImportTransformer(cst.CSTTransformer): def leave_ImportFrom( self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom ) -> cst.BaseSmallStatement | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel: + import libcst.matchers as m + if ( (updated_node_module := updated_node.module) and updated_node_module.value == "__future__" diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 0f97a983c..ba8929343 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -4,10 +4,9 @@ from collections import defaultdict from itertools import chain from pathlib import Path # noqa: TC003 +from typing import TYPE_CHECKING -import jedi import libcst as cst -from jedi.api.classes import Name # noqa: TC002 from libcst import CSTNode # noqa: TC002 from codeflash.cli_cmds.console import logger @@ -24,6 +23,9 @@ ) from codeflash.optimization.function_context import belongs_to_function_qualified +if TYPE_CHECKING: + from jedi.api.classes import Name + def get_code_optimization_context( function_to_optimize: FunctionToOptimize, @@ -354,6 +356,8 @@ def extract_code_markdown_context_from_files( def get_function_to_optimize_as_function_source( function_to_optimize: FunctionToOptimize, project_root_path: Path ) -> FunctionSource: + import jedi + # Use jedi to find function to optimize script = jedi.Script(path=function_to_optimize.file_path, project=jedi.Project(path=project_root_path)) @@ -389,6 +393,8 @@ def get_function_to_optimize_as_function_source( def get_function_sources_from_jedi( file_path_to_qualified_function_names: dict[Path, set[str]], project_root_path: Path ) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]: + import jedi + file_path_to_function_source = defaultdict(set) function_source_list: list[FunctionSource] = [] for file_path, qualified_function_names in file_path_to_qualified_function_names.items(): diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 77a357225..b76e63a91 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -12,7 +12,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Callable, Optional -import jedi import pytest from pydantic.dataclasses import dataclass @@ -281,6 +280,8 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N def process_test_files( file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig ) -> dict[str, list[FunctionCalledInTest]]: + import jedi + project_root_path = cfg.project_root_path test_framework = cfg.test_framework diff --git a/codeflash/main.py b/codeflash/main.py index 9eb22dde1..650bdbd63 100644 --- a/codeflash/main.py +++ b/codeflash/main.py @@ -11,7 +11,6 @@ from codeflash.cli_cmds.console import paneled_text from codeflash.code_utils.checkpoint import ask_should_use_checkpoint_get_functions from codeflash.code_utils.config_parser import parse_config_file -from codeflash.optimization import optimizer from codeflash.telemetry import posthog_cf from codeflash.telemetry.sentry import init_sentry @@ -41,6 +40,9 @@ def main() -> None: args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args) init_sentry(not args.disable_telemetry, exclude_errors=True) posthog_cf.initialize_posthog(not args.disable_telemetry) + + from codeflash.optimization import optimizer + optimizer.run_with_args(args) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 90c088ec6..9e5715e2a 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -9,30 +9,20 @@ from typing import TYPE_CHECKING from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient -from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator -from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin -from codeflash.benchmarking.replay_test import generate_replay_test -from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest -from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table from codeflash.cli_cmds.console import console, logger, progress_bar from codeflash.code_utils import env_utils -from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint -from codeflash.code_utils.code_replacer import normalize_code, normalize_node -from codeflash.code_utils.code_utils import cleanup_paths -from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast -from codeflash.discovery.discover_unit_tests import discover_unit_tests -from codeflash.discovery.functions_to_optimize import get_functions_to_optimize from codeflash.either import is_successful from codeflash.models.models import ValidCode -from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.telemetry.posthog_cf import ph from codeflash.verification.verification_utils import TestConfig if TYPE_CHECKING: from argparse import Namespace + from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import BenchmarkKey, FunctionCalledInTest + from codeflash.optimization.function_optimizer import FunctionOptimizer class Optimizer: @@ -63,6 +53,8 @@ def create_function_optimizer( function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None, total_benchmark_timings: dict[BenchmarkKey, float] | None = None, ) -> FunctionOptimizer: + from codeflash.optimization.function_optimizer import FunctionOptimizer + return FunctionOptimizer( function_to_optimize=function_to_optimize, test_cfg=self.test_cfg, @@ -77,6 +69,16 @@ def create_function_optimizer( ) def run(self) -> None: + from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint + from codeflash.code_utils.code_replacer import normalize_code, normalize_node + from codeflash.code_utils.code_utils import cleanup_paths + from codeflash.code_utils.static_analysis import ( + analyze_imported_modules, + get_first_top_level_function_or_method_ast, + ) + from codeflash.discovery.discover_unit_tests import discover_unit_tests + from codeflash.discovery.functions_to_optimize import get_functions_to_optimize + ph("cli-optimize-run-start") logger.info("Running optimizer.") console.rule() @@ -102,6 +104,12 @@ def run(self) -> None: function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {} total_benchmark_timings: dict[BenchmarkKey, int] = {} if self.args.benchmark and num_optimizable_functions > 0: + from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator + from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin + from codeflash.benchmarking.replay_test import generate_replay_test + from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest + from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table + with progress_bar(f"Running benchmarks in {self.args.benchmarks_root}", transient=True): # Insert decorator file_path_to_source_code = defaultdict(str)