diff --git a/scripts/profiling/apis.py b/scripts/profiling/apis.py deleted file mode 100644 index 24c6ff782..000000000 --- a/scripts/profiling/apis.py +++ /dev/null @@ -1,68 +0,0 @@ -import base64 -import logging -import pickle -from pathlib import Path - -import networkx as nx -from tabulate import tabulate - -from codegen.sdk.codebase.factory.get_dev_customer_codebase import get_codebase_codegen -from codegen.shared.enums.programming_language import ProgrammingLanguage - -logging.basicConfig(level=logging.INFO) -codegen = get_codebase_codegen("../codegen", ".") -res = [] - -# Create a directed graph -G = nx.DiGraph() - -# Iterate through all files in the codebase -for file in codegen.files: - if "test" not in file.filepath: - if file.ctx.repo_name == "codegen": - color = "yellow" - elif file.ctx.repo_name == "codegen-sdk": - color = "red" - if file.ctx.base_path == "codegen-git": - color = "green" - elif file.ctx.base_path == "codegen-utils": - color = "purple" - else: - color = "yellow" - if file.ctx.programming_language == ProgrammingLanguage.TYPESCRIPT: - color = "blue" - - # Add the file as a node - G.add_node(file.filepath, color=color) - - # Iterate through all imports in the file - for imp in file.imports: - if imp.from_file: - # print("way") - # Add an edge from the current file to the imported file - G.add_edge(file.filepath, imp.from_file.filepath) - -for url, func in codegen.global_context.multigraph.api_definitions.items(): - usages = codegen.global_context.multigraph.usages.get(url, None) - if usages: - res.append((url, func.name, func.filepath, [usage.filepath + "::" + usage.name for usage in usages])) - for usage in usages: - G.add_edge(usage.filepath, func.filepath) -print(tabulate(res, headers=["URL", "Function Name", "Filepath", "Usages"])) -# Visualize the graph -# codebase.visualize(G) - -# Print some basic statistics -print(f"Number of files: {G.number_of_nodes()}") -print(f"Number of import relationships: {G.number_of_edges()}") - -# 
Serialize the graph to a pickle -graph_pickle = pickle.dumps(G) - -# Convert to base64 -graph_base64 = base64.b64encode(graph_pickle).decode("utf-8") - -print(f"Base64 encoded graph size: {len(graph_base64)} bytes 🌈") -print(f"Base64 string: {graph_base64}") -Path("out.txt").write_text(graph_base64) -print() diff --git a/src/codegen/cli/auth/session.py b/src/codegen/cli/auth/session.py index d06454330..8dfa2ebf3 100644 --- a/src/codegen/cli/auth/session.py +++ b/src/codegen/cli/auth/session.py @@ -7,7 +7,7 @@ from codegen.cli.git.repo import get_git_repo from codegen.cli.rich.codeblocks import format_command -from codegen.configs.constants import CODEGEN_DIR_NAME, ENV_FILENAME +from codegen.configs.constants import CODEGEN_DIR_NAME from codegen.configs.session_manager import session_manager from codegen.configs.user_config import UserConfig from codegen.git.repo_operator.local_git_repo import LocalGitRepo @@ -30,7 +30,7 @@ def __init__(self, repo_path: Path, git_token: str | None = None) -> None: self.repo_path = repo_path self.local_git = LocalGitRepo(repo_path=repo_path) self.codegen_dir = repo_path / CODEGEN_DIR_NAME - self.config = UserConfig(env_filepath=repo_path / ENV_FILENAME) + self.config = UserConfig(root_path=repo_path) self.config.secrets.github_token = git_token or self.config.secrets.github_token self.existing = session_manager.get_session(repo_path) is not None diff --git a/src/codegen/cli/commands/config/main.py b/src/codegen/cli/commands/config/main.py index b4ec3f3d7..9fce6f9b4 100644 --- a/src/codegen/cli/commands/config/main.py +++ b/src/codegen/cli/commands/config/main.py @@ -4,7 +4,7 @@ import rich_click as click from rich.table import Table -from codegen.configs.constants import ENV_FILENAME, GLOBAL_ENV_FILE +from codegen.configs.constants import ENV_FILENAME, GLOBAL_CONFIG_DIR from codegen.configs.user_config import UserConfig from codegen.shared.path import get_git_root_path @@ -117,8 +117,8 @@ def set_command(key: str, value: str): def 
_get_user_config() -> UserConfig: if (project_root := get_git_root_path()) is None: - env_filepath = GLOBAL_ENV_FILE + root_path = GLOBAL_CONFIG_DIR else: - env_filepath = project_root / ENV_FILENAME + root_path = project_root - return UserConfig(env_filepath) + return UserConfig(root_path) diff --git a/src/codegen/cli/commands/run/run_local.py b/src/codegen/cli/commands/run/run_local.py index 4ca737dd1..0ae1a47cf 100644 --- a/src/codegen/cli/commands/run/run_local.py +++ b/src/codegen/cli/commands/run/run_local.py @@ -6,8 +6,8 @@ from codegen.cli.auth.session import CodegenSession from codegen.cli.utils.function_finder import DecoratedFunction +from codegen.configs.models.repository import RepositoryConfig from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.git.schemas.repo_config import RepoConfig from codegen.git.utils.language import determine_project_language from codegen.sdk.codebase.config import ProjectConfig from codegen.sdk.core.codebase import Codebase @@ -30,7 +30,7 @@ def parse_codebase( codebase = Codebase( projects=[ ProjectConfig( - repo_operator=RepoOperator(repo_config=RepoConfig.from_repo_path(repo_path=repo_path)), + repo_operator=RepoOperator(repo_config=RepositoryConfig.from_path(path=repo_path)), subdirectories=subdirectories, programming_language=language or determine_project_language(repo_path), ) diff --git a/src/codegen/cli/commands/start/main.py b/src/codegen/cli/commands/start/main.py index 652e400e1..af5989630 100644 --- a/src/codegen/cli/commands/start/main.py +++ b/src/codegen/cli/commands/start/main.py @@ -10,9 +10,9 @@ from codegen.cli.commands.start.docker_container import DockerContainer from codegen.cli.commands.start.docker_fleet import CODEGEN_RUNNER_IMAGE +from codegen.configs.models.repository import RepositoryConfig from codegen.configs.models.secrets import SecretsConfig from codegen.git.repo_operator.local_git_repo import LocalGitRepo -from codegen.git.schemas.repo_config import RepoConfig from 
codegen.shared.network.port import get_free_port _default_host = "0.0.0.0" @@ -26,7 +26,7 @@ def start_command(port: int | None, detached: bool = False, skip_build: bool = False, force: bool = False) -> None: """Starts a local codegen server""" repo_path = Path.cwd().resolve() - repo_config = RepoConfig.from_repo_path(str(repo_path)) + repo_config = LocalGitRepo(repo_path=repo_path).get_repo_config() if (container := DockerContainer.get(repo_config.name)) is not None: if force: rich.print(f"[yellow]Removing existing runner {repo_config.name} to force restart[/yellow]") @@ -50,7 +50,7 @@ def start_command(port: int | None, detached: bool = False, skip_build: bool = F raise click.Abort() -def _handle_existing_container(repo_config: RepoConfig, container: DockerContainer) -> None: +def _handle_existing_container(repo_config: RepositoryConfig, container: DockerContainer) -> None: if container.is_running(): rich.print( Panel( @@ -122,20 +122,21 @@ def _get_platform() -> str: return "linux/amd64" -def _run_docker_container(repo_config: RepoConfig, port: int, detached: bool) -> None: +def _run_docker_container(repo_config: RepositoryConfig, port: int, detached: bool) -> None: rich.print("[bold blue]Starting Docker container...[/bold blue]") container_repo_path = f"/app/git/{repo_config.name}" name_args = ["--name", f"{repo_config.name}"] + repo_path = Path(repo_config.path) envvars = { - "REPOSITORY_LANGUAGE": repo_config.language.value, - "REPOSITORY_OWNER": LocalGitRepo(repo_config.repo_path).owner, + "REPOSITORY_LANGUAGE": repo_config.language, + "REPOSITORY_OWNER": LocalGitRepo(repo_path=repo_path).owner, "REPOSITORY_PATH": container_repo_path, - "GITHUB_TOKEN": SecretsConfig().github_token, + "GITHUB_TOKEN": SecretsConfig(root_path=repo_path).github_token, "PYTHONUNBUFFERED": "1", # Ensure Python output is unbuffered "CODEBASE_SYNC_ENABLED": "True", } envvars_args = [arg for k, v in envvars.items() for arg in ("--env", f"{k}={v}")] - mount_args = ["-v", 
f"{repo_config.repo_path}:{container_repo_path}"] + mount_args = ["-v", f"{repo_config.path}:{container_repo_path}"] entry_point = f"uv run --frozen uvicorn codegen.runner.servers.local_daemon:app --host {_default_host} --port {port}" port_args = ["-p", f"{port}:{port}"] detached_args = ["-d"] if detached else [] diff --git a/src/codegen/cli/mcp/resources/system_prompt.py b/src/codegen/cli/mcp/resources/system_prompt.py index 9c7e23c6b..7da246c9a 100644 --- a/src/codegen/cli/mcp/resources/system_prompt.py +++ b/src/codegen/cli/mcp/resources/system_prompt.py @@ -1417,7 +1417,6 @@ def baz(): ```python from codegen import Codebase from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.git.schemas.repo_config import RepoConfig from codegen.sdk.codebase.config import ProjectConfig from codegen.shared.enums.programming_language import ProgrammingLanguage @@ -1425,7 +1424,7 @@ def baz(): projects = [ ProjectConfig( repo_operator=RepoOperator( - repo_config=RepoConfig(name="codegen-sdk"), + repo_path="/tmp/codegen-sdk", bot_commit=True ), programming_language=ProgrammingLanguage.TYPESCRIPT, diff --git a/src/codegen/configs/models/base_config.py b/src/codegen/configs/models/base_config.py index 3a82223de..2b85a4efb 100644 --- a/src/codegen/configs/models/base_config.py +++ b/src/codegen/configs/models/base_config.py @@ -16,7 +16,8 @@ class BaseConfig(BaseSettings, ABC): model_config = SettingsConfigDict(extra="ignore", case_sensitive=False) - def __init__(self, prefix: str, env_filepath: Path | None = None, *args, **kwargs) -> None: + def __init__(self, prefix: str, root_path: Path | None = None, *args, **kwargs) -> None: + env_filepath = root_path / ENV_FILENAME if root_path else None if env_filepath is None: root_path = get_git_root_path() if root_path is not None: diff --git a/src/codegen/configs/models/repository.py b/src/codegen/configs/models/repository.py index d4960c503..1aaeb170c 100644 --- a/src/codegen/configs/models/repository.py +++ 
b/src/codegen/configs/models/repository.py @@ -1,4 +1,6 @@ import os +from pathlib import Path +from typing import Self from codegen.configs.models.base_config import BaseConfig @@ -13,16 +15,15 @@ class RepositoryConfig(BaseConfig): language: str | None = None user_name: str | None = None user_email: str | None = None + subdirectories: list[str] | None = None + base_path: str | None = None # root module of the parsed codebase def __init__(self, prefix: str = "REPOSITORY", *args, **kwargs) -> None: super().__init__(prefix=prefix, *args, **kwargs) - def _initialize( - self, - ) -> None: - """Initialize the repository config""" - if self.path is None: - self.path = os.getcwd() + @classmethod + def from_path(cls, path: str) -> Self: + return cls(root_path=Path(path), path=str(path)) @property def base_dir(self) -> str: diff --git a/src/codegen/configs/user_config.py b/src/codegen/configs/user_config.py index ecebec4d1..cc3219dca 100644 --- a/src/codegen/configs/user_config.py +++ b/src/codegen/configs/user_config.py @@ -3,6 +3,7 @@ from pydantic import Field +from codegen.configs.constants import ENV_FILENAME from codegen.configs.models.codebase import CodebaseConfig from codegen.configs.models.repository import RepositoryConfig from codegen.configs.models.secrets import SecretsConfig @@ -14,11 +15,11 @@ class UserConfig: codebase: CodebaseConfig = Field(default_factory=CodebaseConfig) secrets: SecretsConfig = Field(default_factory=SecretsConfig) - def __init__(self, env_filepath: Path) -> None: - self.env_filepath = env_filepath - self.secrets = SecretsConfig(env_filepath=env_filepath) - self.repository = RepositoryConfig(env_filepath=env_filepath) - self.codebase = CodebaseConfig(env_filepath=env_filepath) + def __init__(self, root_path: Path) -> None: + self.env_filepath = root_path / ENV_FILENAME + self.secrets = SecretsConfig(root_path=root_path) + self.repository = RepositoryConfig(root_path=root_path) + self.codebase = CodebaseConfig(root_path=root_path) def 
save(self) -> None: """Save configuration to the config file.""" diff --git a/src/codegen/extensions/events/modal/base.py b/src/codegen/extensions/events/modal/base.py index 64bdf5b28..c7bc19457 100644 --- a/src/codegen/extensions/events/modal/base.py +++ b/src/codegen/extensions/events/modal/base.py @@ -8,7 +8,6 @@ from codegen.extensions.events.codegen_app import CodegenApp from codegen.extensions.events.modal.request_util import fastapi_request_adapter from codegen.git.clients.git_repo_client import GitRepoClient -from codegen.git.schemas.repo_config import RepoConfig logging.basicConfig(level=logging.INFO, force=True) logger = logging.getLogger(__name__) @@ -36,17 +35,11 @@ def get_event_handler_cls(self) -> modal.Cls: raise NotImplementedError(msg) async def handle_event(self, org: str, repo: str, provider: Literal["slack", "github", "linear"], request: Request): - repo_config = RepoConfig( - name=repo, - full_name=f"{org}/{repo}", - ) - repo_snapshotdict = modal.Dict.from_name(self.snapshot_index_id, {}, create_if_missing=True) - last_snapshot_commit = repo_snapshotdict.get(f"{org}/{repo}", None) if last_snapshot_commit is None: - git_client = GitRepoClient(repo_config=repo_config, access_token=os.environ["GITHUB_ACCESS_TOKEN"]) + git_client = GitRepoClient(repo_full_name=f"{org}/{repo}", access_token=os.environ["GITHUB_ACCESS_TOKEN"]) branch = git_client.get_branch_safe(git_client.default_branch) last_snapshot_commit = branch.commit.sha if branch and branch.commit else None @@ -76,15 +69,8 @@ def refresh_repository_snapshots(self, snapshot_index_id: str): try: # Parse the repository full name to get org and repo org, repo = repo_full_name.split("/") - - # Create a RepoConfig for the repository - repo_config = RepoConfig( - name=repo, - full_name=repo_full_name, - ) - # Initialize the GitRepoClient to fetch the latest commit - git_client = GitRepoClient(repo_config=repo_config, access_token=os.environ["GITHUB_ACCESS_TOKEN"]) + git_client = 
GitRepoClient(repo_full_name=repo_full_name, access_token=os.environ["GITHUB_ACCESS_TOKEN"]) # Get the default branch and its latest commit branch = git_client.get_branch_safe(git_client.default_branch) diff --git a/src/codegen/git/clients/git_repo_client.py b/src/codegen/git/clients/git_repo_client.py index 79c8df2a4..c1441e4c6 100644 --- a/src/codegen/git/clients/git_repo_client.py +++ b/src/codegen/git/clients/git_repo_client.py @@ -16,7 +16,6 @@ from codegen.configs.models.secrets import SecretsConfig from codegen.git.clients.github_client import GithubClient -from codegen.git.schemas.repo_config import RepoConfig from codegen.git.utils.format import format_comparison from codegen.shared.logging.get_logger import get_logger @@ -26,12 +25,12 @@ class GitRepoClient: """Wrapper around PyGithub's Remote Repository.""" - repo_config: RepoConfig + repo_full_name: str gh_client: GithubClient _repo: Repository - def __init__(self, repo_config: RepoConfig, access_token: str | None = None) -> None: - self.repo_config = repo_config + def __init__(self, repo_full_name: str, access_token: str | None = None) -> None: + self.repo_full_name = repo_full_name self.gh_client = self._create_github_client(token=access_token or SecretsConfig().github_token) self._repo = self._create_client() @@ -39,9 +38,9 @@ def _create_github_client(self, token: str) -> GithubClient: return GithubClient(token=token) def _create_client(self) -> Repository: - client = self.gh_client.get_repo_by_full_name(self.repo_config.full_name) + client = self.gh_client.get_repo_by_full_name(self.repo_full_name) if not client: - msg = f"Repo {self.repo_config.full_name} not found!" + msg = f"Repo {self.repo_full_name} not found!" 
raise ValueError(msg) return client diff --git a/src/codegen/git/repo_operator/local_git_repo.py b/src/codegen/git/repo_operator/local_git_repo.py index a5c4acea3..dc3c6a443 100644 --- a/src/codegen/git/repo_operator/local_git_repo.py +++ b/src/codegen/git/repo_operator/local_git_repo.py @@ -6,8 +6,8 @@ from git import Repo from git.remote import Remote +from codegen.configs.models.repository import RepositoryConfig from codegen.git.clients.git_repo_client import GitRepoClient -from codegen.git.schemas.repo_config import RepoConfig from codegen.git.utils.language import determine_project_language @@ -20,7 +20,11 @@ def __init__(self, repo_path: Path): @cached_property def git_cli(self) -> Repo: - return Repo(self.repo_path) + if not os.path.exists(self.repo_path): + os.makedirs(self.repo_path) + return Repo.init(self.repo_path) + else: + return Repo(self.repo_path) @cached_property def name(self) -> str: @@ -72,9 +76,7 @@ def user_email(self) -> str | None: def get_language(self, access_token: str | None = None) -> str: """Returns the majority language of the repository""" if access_token is not None: - repo_config = RepoConfig.from_repo_path(repo_path=str(self.repo_path)) - repo_config.full_name = self.full_name - remote_git = GitRepoClient(repo_config=repo_config, access_token=access_token) + remote_git = GitRepoClient(repo_full_name=self.full_name, access_token=access_token) if (language := remote_git.repo.language) is not None: return language.upper() @@ -82,3 +84,12 @@ def get_language(self, access_token: str | None = None) -> str: def has_remote(self) -> bool: return bool(self.git_cli.remotes) + + def get_repo_config(self, access_token: str | None = None, repo_config: RepositoryConfig | None = None) -> RepositoryConfig: + config = repo_config or RepositoryConfig() + config.path = config.path or str(self.repo_path) + config.owner = config.owner or self.owner + config.user_name = config.user_name or self.user_name + config.user_email = config.user_email or 
self.user_email + config.language = config.language or self.get_language(access_token=access_token).upper() + return config diff --git a/src/codegen/git/repo_operator/repo_operator.py b/src/codegen/git/repo_operator/repo_operator.py index 9f1f58007..dad10e62c 100644 --- a/src/codegen/git/repo_operator/repo_operator.py +++ b/src/codegen/git/repo_operator/repo_operator.py @@ -5,6 +5,7 @@ from collections.abc import Generator from datetime import UTC, datetime from functools import cached_property +from pathlib import Path from time import perf_counter from typing import Self @@ -16,17 +17,18 @@ from github.IssueComment import IssueComment from github.PullRequest import PullRequest +from codegen.configs.models.repository import RepositoryConfig from codegen.configs.models.secrets import SecretsConfig from codegen.git.clients.git_repo_client import GitRepoClient from codegen.git.configs.constants import CODEGEN_BOT_EMAIL, CODEGEN_BOT_NAME from codegen.git.repo_operator.local_git_repo import LocalGitRepo from codegen.git.schemas.enums import CheckoutResult, FetchResult, SetupOption -from codegen.git.schemas.repo_config import RepoConfig from codegen.git.utils.clone import clone_or_pull_repo, clone_repo, pull_repo -from codegen.git.utils.clone_url import add_access_token_to_url, get_authenticated_clone_url_for_repo_config, get_clone_url_for_repo_config, url_to_github +from codegen.git.utils.clone_url import get_authenticated_clone_url_for_repo_config, get_clone_url_for_repo_config, url_to_github from codegen.git.utils.codeowner_utils import create_codeowners_parser_for_repo from codegen.git.utils.file_utils import create_files from codegen.git.utils.remote_progress import CustomRemoteProgress +from codegen.shared.enums.programming_language import ProgrammingLanguage from codegen.shared.logging.get_logger import get_logger from codegen.shared.performance.stopwatch_utils import stopwatch from codegen.shared.performance.time_utils import humanize_duration @@ -37,11 +39,10 
@@ class RepoOperator: """A wrapper around GitPython to make it easier to interact with a repo.""" - repo_config: RepoConfig - base_dir: str - bot_commit: bool = True - access_token: str | None = None - + repo_config: RepositoryConfig + bot_commit: bool + access_token: str | None + respect_gitignore: bool # lazy attributes _codeowners_parser: CodeOwnersParser | None = None _default_branch: str | None = None @@ -50,17 +51,18 @@ class RepoOperator: def __init__( self, - repo_config: RepoConfig, + repo_config: RepositoryConfig, access_token: str | None = None, bot_commit: bool = False, + respect_gitignore: bool = True, setup_option: SetupOption | None = None, shallow: bool | None = None, ) -> None: - assert repo_config is not None - self.repo_config = repo_config - self.access_token = access_token or SecretsConfig().github_token - self.base_dir = repo_config.base_dir + self.access_token = access_token or SecretsConfig(root_path=Path(repo_config.path)).github_token + self._local_git_repo = LocalGitRepo(repo_path=Path(repo_config.path)) + self.repo_config = self._local_git_repo.get_repo_config(self.access_token, repo_config) self.bot_commit = bot_commit + self.respect_gitignore = respect_gitignore if setup_option: if shallow is not None: @@ -68,13 +70,6 @@ def __init__( else: self.setup_repo_dir(setup_option=setup_option) - else: - os.makedirs(self.repo_path, exist_ok=True) - GitCLI.init(self.repo_path) - self._local_git_repo = LocalGitRepo(repo_path=repo_config.repo_path) - if self.repo_config.full_name is None: - self.repo_config.full_name = self._local_git_repo.full_name - #################################################################################################################### # PROPERTIES #################################################################################################################### @@ -85,7 +80,11 @@ def repo_name(self) -> str: @property def repo_path(self) -> str: - return os.path.join(self.base_dir, self.repo_name) + return 
self.repo_config.path + + @property + def repo_full_name(self) -> str: + return self.repo_config.full_name or f"{self._local_git_repo.owner}/{self.repo_name}" @property def remote_git_repo(self) -> GitRepoClient: @@ -94,18 +93,18 @@ def remote_git_repo(self) -> GitRepoClient: raise ValueError(msg) if not self._remote_git_repo: - self._remote_git_repo = GitRepoClient(self.repo_config, access_token=self.access_token) + self._remote_git_repo = GitRepoClient(self.repo_full_name, access_token=self.access_token) return self._remote_git_repo @property def clone_url(self) -> str: if self.access_token: return get_authenticated_clone_url_for_repo_config(repo=self.repo_config, token=self.access_token) - return f"https://github.com/{self.repo_config.full_name}.git" + return get_clone_url_for_repo_config(repo=self.repo_config) @property def viz_path(self) -> str: - return os.path.join(self.base_dir, "codegen-graphviz") + return os.path.join(self.repo_config.base_dir, "codegen-graphviz") @property def viz_file_path(self) -> str: @@ -203,8 +202,7 @@ def codeowners_parser(self) -> CodeOwnersParser | None: # SET UP #################################################################################################################### def setup_repo_dir(self, setup_option: SetupOption = SetupOption.PULL_OR_CLONE, shallow: bool = True) -> None: - os.makedirs(self.base_dir, exist_ok=True) - os.chdir(self.base_dir) + os.chdir(self.repo_config.base_dir) if setup_option is SetupOption.CLONE: # if repo exists delete, then clone, else clone clone_repo(shallow=shallow, repo_path=self.repo_path, clone_url=self.clone_url) @@ -278,8 +276,6 @@ def clone_repo(self, shallow: bool = True) -> None: def clone_or_pull_repo(self, shallow: bool = True) -> None: """If repo exists, pulls changes. otherwise, clones the repo.""" # TODO(CG-7804): if repo is not valid we should delete it and re-clone. 
maybe we can create a pull_repo util + use the existing clone_repo util - if self.repo_exists(): - self.clean_repo() clone_or_pull_repo(repo_path=self.repo_path, clone_url=self.clone_url, shallow=shallow) #################################################################################################################### @@ -577,7 +573,7 @@ def delete_file(self, path: str) -> None: def get_filepaths_for_repo(self, ignore_list): # Get list of files to iterate over based on gitignore setting - if self.repo_config.respect_gitignore: + if self.respect_gitignore: # ls-file flags: # -c: show cached files # -o: show other / untracked files @@ -808,7 +804,7 @@ def get_pull_request(self, pr_number: int) -> PullRequest | None: # CLASS METHODS #################################################################################################################### @classmethod - def create_from_files(cls, repo_path: str, files: dict[str, str], bot_commit: bool = True) -> Self: + def create_from_files(cls, repo_path: str, files: dict[str, str], programming_language: ProgrammingLanguage, bot_commit: bool = True) -> Self: """Used when you want to create a directory from a set of files and then create a RepoOperator that points to that directory. 
Use cases: - Unit testing @@ -824,7 +820,9 @@ def create_from_files(cls, repo_path: str, files: dict[str, str], bot_commit: bo create_files(base_dir=repo_path, files=files) # Step 2: Init git repo - op = cls(repo_config=RepoConfig.from_repo_path(repo_path), bot_commit=bot_commit) + repo_config = RepositoryConfig.from_path(path=repo_path) + repo_config.language = programming_language + op = cls(repo_config=repo_config, bot_commit=bot_commit) if op.stage_and_commit_all_changes("[Codegen] initial commit"): op.checkout_branch(None, create_if_missing=True) return op @@ -839,60 +837,14 @@ def create_from_commit(cls, repo_path: str, commit: str, url: str, access_token: url (str): Git URL of the repository access_token (str | None): Optional GitHub API key for operations that need GitHub access """ - op = cls(repo_config=RepoConfig.from_repo_path(repo_path, full_name=full_name), bot_commit=False, access_token=access_token) + repo_config = RepositoryConfig.from_path(path=repo_path) + if full_name: + repo_config.owner = full_name.split("/")[0] + op = cls(repo_config=repo_config, bot_commit=False, access_token=access_token) op.discard_changes() if op.get_active_branch_or_commit() != commit: op.create_remote("origin", url) op.git_cli.remotes["origin"].fetch(commit, depth=1) op.checkout_commit(commit) return op - - @classmethod - def create_from_repo(cls, repo_path: str, url: str, access_token: str | None = None) -> Self | None: - """Create a fresh clone of a repository or use existing one if up to date. 
- - Args: - repo_path (str): Path where the repo should be cloned - url (str): Git URL of the repository - access_token (str | None): Optional GitHub API key for operations that need GitHub access - """ - access_token = access_token or SecretsConfig().github_token - if access_token: - url = add_access_token_to_url(url=url, token=access_token) - - # Check if repo already exists - if os.path.exists(repo_path): - try: - # Try to initialize git repo from existing path - git_cli = GitCLI(repo_path) - # Check if it has our remote URL - if any(remote.url == url for remote in git_cli.remotes): - # Fetch to check for updates - git_cli.remotes.origin.fetch() - # Get current and remote HEADs - local_head = git_cli.head.commit - remote_head = git_cli.remotes.origin.refs[git_cli.active_branch.name].commit - # If up to date, use existing repo - if local_head.hexsha == remote_head.hexsha: - return cls(repo_config=RepoConfig.from_repo_path(repo_path), bot_commit=False, access_token=access_token) - except Exception: - # If any git operations fail, fallback to fresh clone - pass - - # If we get here, repo exists but is not up to date or valid - # Remove the existing directory to do a fresh clone - import shutil - - shutil.rmtree(repo_path) - try: - # Clone the repository - GitCLI.clone_from(url=url, to_path=repo_path, depth=1) - - # Initialize with the cloned repo - git_cli = GitCLI(repo_path) - except (GitCommandError, ValueError) as e: - logger.exception("Failed to initialize Git repository:") - logger.exception("Please authenticate with a valid token and ensure the repository is properly initialized.") - return None - return cls(repo_config=RepoConfig.from_repo_path(repo_path), bot_commit=False, access_token=access_token) diff --git a/src/codegen/git/schemas/repo_config.py b/src/codegen/git/schemas/repo_config.py deleted file mode 100644 index 0e54b8362..000000000 --- a/src/codegen/git/schemas/repo_config.py +++ /dev/null @@ -1,53 +0,0 @@ -import os.path -from pathlib import Path 
- -from pydantic import BaseModel - -from codegen.configs.models.repository import RepositoryConfig -from codegen.git.schemas.enums import RepoVisibility -from codegen.shared.enums.programming_language import ProgrammingLanguage -from codegen.shared.logging.get_logger import get_logger - -logger = get_logger(__name__) - - -class RepoConfig(BaseModel): - """All the information about the repo needed to build a codebase""" - - name: str - full_name: str | None = None - visibility: RepoVisibility | None = None - - # Codebase fields - base_dir: str = "/tmp" # parent directory of the git repo - language: ProgrammingLanguage = ProgrammingLanguage.PYTHON - respect_gitignore: bool = True - base_path: str | None = None # root directory of the codebase within the repo - subdirectories: list[str] | None = None - - @classmethod - def from_envs(cls) -> "RepoConfig": - default_repo_config = RepositoryConfig() - return RepoConfig( - name=default_repo_config.name, - full_name=default_repo_config.full_name, - base_dir=os.path.dirname(default_repo_config.path), - language=ProgrammingLanguage(default_repo_config.language.upper()), - ) - - @classmethod - def from_repo_path(cls, repo_path: str, full_name: str | None = None) -> "RepoConfig": - name = os.path.basename(repo_path) - base_dir = os.path.dirname(repo_path) - return cls(name=name, base_dir=base_dir, full_name=full_name) - - @property - def repo_path(self) -> Path: - return Path(f"{self.base_dir}/{self.name}") - - @property - def organization_name(self) -> str | None: - if self.full_name is not None: - return self.full_name.split("/")[0] - - return None diff --git a/src/codegen/git/utils/clone_url.py b/src/codegen/git/utils/clone_url.py index 21cd80cfb..1d3636b83 100644 --- a/src/codegen/git/utils/clone_url.py +++ b/src/codegen/git/utils/clone_url.py @@ -1,6 +1,6 @@ from urllib.parse import urlparse -from codegen.git.schemas.repo_config import RepoConfig +from codegen.configs.models.repository import RepositoryConfig # TODO: 
move out doesn't belong here @@ -9,11 +9,11 @@ def url_to_github(url: str, branch: str) -> str: return f"{clone_url}/blob/{branch}" -def get_clone_url_for_repo_config(repo_config: RepoConfig) -> str: +def get_clone_url_for_repo_config(repo_config: RepositoryConfig) -> str: return f"https://github.com/{repo_config.full_name}.git" -def get_authenticated_clone_url_for_repo_config(repo: RepoConfig, token: str) -> str: +def get_authenticated_clone_url_for_repo_config(repo: RepositoryConfig, token: str) -> str: git_url = get_clone_url_for_repo_config(repo) return add_access_token_to_url(git_url, token) diff --git a/src/codegen/git/utils/language.py b/src/codegen/git/utils/language.py index 551ac4212..c0191040a 100644 --- a/src/codegen/git/utils/language.py +++ b/src/codegen/git/utils/language.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Literal +from codegen.configs.models.repository import RepositoryConfig from codegen.git.utils.file_utils import split_git_path from codegen.shared.enums.programming_language import ProgrammingLanguage from codegen.shared.logging.get_logger import get_logger @@ -108,7 +109,6 @@ def _determine_language_by_git_file_count(folder_path: str) -> ProgrammingLangua or if less than MIN_LANGUAGE_RATIO of files match the dominant language """ from codegen.git.repo_operator.repo_operator import RepoOperator - from codegen.git.schemas.repo_config import RepoConfig from codegen.sdk.codebase.codebase_context import GLOBAL_FILE_IGNORE_LIST from codegen.sdk.python import PyFile from codegen.sdk.typescript.file import TSFile @@ -129,7 +129,7 @@ def _determine_language_by_git_file_count(folder_path: str) -> ProgrammingLangua # Initiate RepoOperator git_root, base_path = split_git_path(folder_path) - repo_config = RepoConfig.from_repo_path(repo_path=git_root) + repo_config = RepositoryConfig.from_path(path=git_root) repo_operator = RepoOperator(repo_config=repo_config) # Walk through the directory diff --git 
a/src/codegen/runner/clients/codebase_client.py b/src/codegen/runner/clients/codebase_client.py index 7b4bf16ce..c059c5690 100644 --- a/src/codegen/runner/clients/codebase_client.py +++ b/src/codegen/runner/clients/codebase_client.py @@ -3,9 +3,10 @@ import os import subprocess import time +from pathlib import Path +from codegen.configs.models.repository import RepositoryConfig from codegen.configs.models.secrets import SecretsConfig -from codegen.git.schemas.repo_config import RepoConfig from codegen.runner.clients.client import Client from codegen.runner.models.apis import SANDBOX_SERVER_PORT from codegen.shared.logging.get_logger import get_logger @@ -21,9 +22,9 @@ class CodebaseClient(Client): """Client for interacting with the locally hosted sandbox server.""" - repo_config: RepoConfig + repo_config: RepositoryConfig - def __init__(self, repo_config: RepoConfig, host: str = "127.0.0.1", port: int = SANDBOX_SERVER_PORT, server_path: str = RUNNER_SERVER_PATH): + def __init__(self, repo_config: RepositoryConfig, host: str = "127.0.0.1", port: int = SANDBOX_SERVER_PORT, server_path: str = RUNNER_SERVER_PATH): super().__init__(host=host, port=port) self.repo_config = repo_config self._process = None @@ -39,6 +40,8 @@ def _start_server(self, server_path: str) -> None: """Start the FastAPI server in a subprocess""" envs = self._get_envs() logger.info(f"Starting local server on {self.base_url} with envvars: {envs}") + for key, value in envs.items(): + logger.info(f"{key}={value}") self._process = subprocess.Popen( [ @@ -66,10 +69,10 @@ def _wait_for_server(self, timeout: int = 30, interval: float = 0.3) -> None: def _get_envs(self) -> dict: envs = os.environ.copy() codebase_envs = { - "REPOSITORY_PATH": str(self.repo_config.repo_path), - "REPOSITORY_OWNER": self.repo_config.organization_name, - "REPOSITORY_LANGUAGE": self.repo_config.language.value, - "GITHUB_TOKEN": SecretsConfig().github_token, + "REPOSITORY_PATH": self.repo_config.path, + "REPOSITORY_OWNER": 
self.repo_config.owner, + "REPOSITORY_LANGUAGE": self.repo_config.language, + "GITHUB_TOKEN": SecretsConfig(root_path=Path(self.repo_config.path)).github_token, } envs.update(codebase_envs) @@ -77,7 +80,5 @@ def _get_envs(self) -> dict: if __name__ == "__main__": - test_config = RepoConfig.from_repo_path("/Users/caroljung/git/codegen/codegen-agi") - test_config.full_name = "codegen-sh/codegen-agi" - client = CodebaseClient(test_config) + client = CodebaseClient(repo_config=RepositoryConfig.from_path("/Users/caroljung/git/codegen/codegen-agi")) print(client.is_running()) diff --git a/src/codegen/runner/sandbox/runner.py b/src/codegen/runner/sandbox/runner.py index 4a86bc618..fc0438f8f 100644 --- a/src/codegen/runner/sandbox/runner.py +++ b/src/codegen/runner/sandbox/runner.py @@ -1,15 +1,16 @@ import sys from codegen.configs.models.codebase import CodebaseConfig +from codegen.configs.models.repository import RepositoryConfig from codegen.git.repo_operator.repo_operator import RepoOperator from codegen.git.schemas.enums import SetupOption -from codegen.git.schemas.repo_config import RepoConfig from codegen.runner.models.apis import CreateBranchRequest, CreateBranchResponse, GetDiffRequest, GetDiffResponse from codegen.runner.sandbox.executor import SandboxExecutor from codegen.sdk.codebase.config import ProjectConfig, SessionOptions from codegen.sdk.codebase.factory.codebase_factory import CodebaseType from codegen.sdk.core.codebase import Codebase from codegen.shared.compilation.string_to_code import create_execute_function_from_codeblock +from codegen.shared.enums.programming_language import ProgrammingLanguage from codegen.shared.logging.get_logger import get_logger logger = get_logger(__name__) @@ -19,16 +20,16 @@ class SandboxRunner: """Responsible for orchestrating the lifecycle of a warmed sandbox""" # =====[ __init__ instance attributes ]===== - repo: RepoConfig + repo: RepositoryConfig op: RepoOperator | None # =====[ computed instance attributes ]===== 
codebase: CodebaseType executor: SandboxExecutor - def __init__(self, repo_config: RepoConfig, op: RepoOperator | None = None) -> None: - self.repo = repo_config - self.op = op or RepoOperator(repo_config=self.repo, setup_option=SetupOption.PULL_OR_CLONE, bot_commit=True) + def __init__(self, repo_config: RepositoryConfig, op: RepoOperator | None = None) -> None: + self.op = op or RepoOperator(repo_config=repo_config, setup_option=SetupOption.PULL_OR_CLONE, bot_commit=True) + self.repo = self.op.repo_config async def warmup(self, codebase_config: CodebaseConfig | None = None) -> None: """Warms up this runner by cloning the repo and parsing the graph.""" @@ -40,7 +41,8 @@ async def warmup(self, codebase_config: CodebaseConfig | None = None) -> None: async def _build_graph(self, codebase_config: CodebaseConfig | None = None) -> Codebase: logger.info("> Building graph...") - projects = [ProjectConfig(programming_language=self.repo.language, repo_operator=self.op, base_path=self.repo.base_path, subdirectories=self.repo.subdirectories)] + programming_language = ProgrammingLanguage(self.repo.language.upper()) + projects = [ProjectConfig(programming_language=programming_language, repo_operator=self.op, base_path=self.repo.base_path, subdirectories=self.repo.subdirectories)] return Codebase(projects=projects, config=codebase_config) async def get_diff(self, request: GetDiffRequest) -> GetDiffResponse: diff --git a/src/codegen/runner/sandbox/server.py b/src/codegen/runner/sandbox/server.py index a6a346fcf..3f57f9718 100644 --- a/src/codegen/runner/sandbox/server.py +++ b/src/codegen/runner/sandbox/server.py @@ -1,10 +1,8 @@ -import os from contextlib import asynccontextmanager from fastapi import FastAPI from codegen.configs.models.repository import RepositoryConfig -from codegen.git.schemas.repo_config import RepoConfig from codegen.runner.enums.warmup_state import WarmupState from codegen.runner.models.apis import ( BRANCH_ENDPOINT, @@ -17,7 +15,6 @@ ) from 
codegen.runner.sandbox.middlewares import CodemodRunMiddleware from codegen.runner.sandbox.runner import SandboxRunner -from codegen.shared.enums.programming_language import ProgrammingLanguage from codegen.shared.logging.get_logger import get_logger logger = get_logger(__name__) @@ -31,18 +28,11 @@ async def lifespan(server: FastAPI): global server_info global runner - default_repo_config = RepositoryConfig() - repo_name = default_repo_config.full_name or default_repo_config.name - server_info = ServerInfo(repo_name=repo_name) try: - logger.info(f"Starting up sandbox fastapi server for repo_name={repo_name}") - repo_config = RepoConfig( - name=default_repo_config.name, - full_name=default_repo_config.full_name, - base_dir=os.path.dirname(default_repo_config.path), - language=ProgrammingLanguage(default_repo_config.language.upper()), - ) + repo_config = RepositoryConfig() runner = SandboxRunner(repo_config=repo_config) + server_info = ServerInfo(repo_name=runner.repo.full_name or runner.repo.name) + logger.info(f"Starting up sandbox fastapi server for repo_name={server_info.repo_name}") server_info.warmup_state = WarmupState.PENDING await runner.warmup() server_info.synced_commit = runner.op.git_cli.head.commit.hexsha diff --git a/src/codegen/runner/servers/local_daemon.py b/src/codegen/runner/servers/local_daemon.py index 1d24006ae..00d8ec12e 100644 --- a/src/codegen/runner/servers/local_daemon.py +++ b/src/codegen/runner/servers/local_daemon.py @@ -4,10 +4,10 @@ from fastapi import FastAPI from codegen.configs.models.codebase import DefaultCodebaseConfig +from codegen.configs.models.repository import RepositoryConfig from codegen.git.configs.constants import CODEGEN_BOT_EMAIL, CODEGEN_BOT_NAME from codegen.git.repo_operator.repo_operator import RepoOperator from codegen.git.schemas.enums import SetupOption -from codegen.git.schemas.repo_config import RepoConfig from codegen.runner.enums.warmup_state import WarmupState from codegen.runner.models.apis import ( 
RUN_FUNCTION_ENDPOINT, @@ -37,18 +37,17 @@ async def lifespan(server: FastAPI): global runner try: - repo_config = RepoConfig.from_envs() - server_info = ServerInfo(repo_name=repo_config.full_name or repo_config.name) - - # Set the bot email and username + repo_config = RepositoryConfig() op = RepoOperator(repo_config=repo_config, setup_option=SetupOption.SKIP, bot_commit=True) runner = SandboxRunner(repo_config=repo_config, op=op) + server_info = ServerInfo(repo_name=runner.repo.full_name or runner.repo.name) + logger.info(f"Configuring git user config to {CODEGEN_BOT_EMAIL} and {CODEGEN_BOT_NAME}") runner.op.git_cli.git.config("user.email", CODEGEN_BOT_EMAIL) runner.op.git_cli.git.config("user.name", CODEGEN_BOT_NAME) # Parse the codebase with sync enabled - logger.info(f"Starting up fastapi server for repo_name={repo_config.name}") + logger.info(f"Starting up fastapi server for repo_name={server_info.repo_name}") server_info.warmup_state = WarmupState.PENDING codebase_config = DefaultCodebaseConfig.model_copy(update={"sync_enabled": True}) await runner.warmup(codebase_config=codebase_config) diff --git a/src/codegen/sdk/code_generation/current_code_codebase.py b/src/codegen/sdk/code_generation/current_code_codebase.py index bfcee2232..9b418861f 100644 --- a/src/codegen/sdk/code_generation/current_code_codebase.py +++ b/src/codegen/sdk/code_generation/current_code_codebase.py @@ -5,9 +5,9 @@ from typing import TypedDict from codegen.configs.models.codebase import CodebaseConfig +from codegen.configs.models.repository import RepositoryConfig from codegen.configs.models.secrets import SecretsConfig from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.git.schemas.repo_config import RepoConfig from codegen.sdk.codebase.config import ProjectConfig from codegen.sdk.core.codebase import Codebase, CodebaseType from codegen.shared.decorators.docs import DocumentedObject, apidoc_objects, no_apidoc_objects, py_apidoc_objects, ts_apidoc_objects @@ 
-41,11 +41,10 @@ def get_current_code_codebase(config: CodebaseConfig | None = None, secrets: Sec base_dir = get_codegen_codebase_base_path() logger.info(f"Creating codebase from repo at: {codegen_repo_path} with base_path {base_dir}") - repo_config = RepoConfig.from_repo_path(codegen_repo_path) - repo_config.respect_gitignore = False - op = RepoOperator(repo_config=repo_config, bot_commit=False) + repo_config = RepositoryConfig.from_path(path=codegen_repo_path) + op = RepoOperator(repo_config=repo_config, bot_commit=False, respect_gitignore=False) - config = (config or CodebaseConfig()).model_copy(update={"base_path": base_dir}) + config = config or CodebaseConfig() projects = [ProjectConfig(repo_operator=op, programming_language=ProgrammingLanguage.PYTHON, subdirectories=subdirectories, base_path=base_dir)] codebase = Codebase(projects=projects, config=config, secrets=secrets) return codebase diff --git a/src/codegen/sdk/codebase/config.py b/src/codegen/sdk/codebase/config.py index 25f3c2e0e..e9898c7df 100644 --- a/src/codegen/sdk/codebase/config.py +++ b/src/codegen/sdk/codebase/config.py @@ -6,8 +6,8 @@ from pydantic.fields import Field from codegen.configs.models.codebase import DefaultCodebaseConfig +from codegen.configs.models.repository import RepositoryConfig from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.git.schemas.repo_config import RepoConfig from codegen.git.utils.file_utils import split_git_path from codegen.git.utils.language import determine_project_language from codegen.shared.enums.programming_language import ProgrammingLanguage @@ -35,7 +35,7 @@ class ProjectConfig(BaseModel): model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) repo_operator: RepoOperator - # TODO: clean up these fields. Duplicated across RepoConfig and CodebaseContext + # TODO: clean up these fields. 
Duplicated across RepositoryConfig and CodebaseContext base_path: str | None = None subdirectories: list[str] | None = None programming_language: ProgrammingLanguage = ProgrammingLanguage.PYTHON @@ -47,9 +47,7 @@ def from_path(cls, path: str, programming_language: ProgrammingLanguage | None = git_root, base_path = split_git_path(repo_path) subdirectories = [base_path] if base_path else None programming_language = programming_language or determine_project_language(repo_path) - repo_config = RepoConfig.from_repo_path(repo_path=git_root) - repo_config.language = programming_language - repo_config.subdirectories = subdirectories + repo_config = RepositoryConfig.from_path(path=repo_path, language=programming_language.value, subdirectories=subdirectories, base_path=base_path) # Create main project return cls( repo_operator=RepoOperator(repo_config=repo_config), @@ -60,9 +58,13 @@ def from_path(cls, path: str, programming_language: ProgrammingLanguage | None = @classmethod def from_repo_operator(cls, repo_operator: RepoOperator, programming_language: ProgrammingLanguage | None = None, base_path: str | None = None) -> Self: + language = programming_language or determine_project_language(repo_operator.repo_path) + repo_operator.repo_config.language = language.value + repo_operator.repo_config.subdirectories = [base_path] if base_path else None + repo_operator.repo_config.base_path = base_path return cls( repo_operator=repo_operator, - programming_language=programming_language or determine_project_language(repo_operator.repo_path), + programming_language=language, base_path=base_path, subdirectories=[base_path] if base_path else None, ) diff --git a/src/codegen/sdk/codebase/factory/codebase_factory.py b/src/codegen/sdk/codebase/factory/codebase_factory.py index 009992311..3c753a563 100644 --- a/src/codegen/sdk/codebase/factory/codebase_factory.py +++ b/src/codegen/sdk/codebase/factory/codebase_factory.py @@ -23,6 +23,6 @@ def get_codebase_from_files( config: CodebaseConfig | 
None = None, secrets: SecretsConfig | None = None, ) -> CodebaseType: - op = RepoOperator.create_from_files(repo_path=repo_path, files=files, bot_commit=bot_commit) + op = RepoOperator.create_from_files(repo_path=repo_path, files=files, bot_commit=bot_commit, programming_language=programming_language) projects = [ProjectConfig(repo_operator=op, programming_language=programming_language)] return Codebase(projects=projects, config=config, secrets=secrets) diff --git a/src/codegen/sdk/codebase/factory/get_session.py b/src/codegen/sdk/codebase/factory/get_session.py index 189eec6e6..089923427 100644 --- a/src/codegen/sdk/codebase/factory/get_session.py +++ b/src/codegen/sdk/codebase/factory/get_session.py @@ -112,7 +112,7 @@ def get_codebase_graph_session( session_options: SessionOptions = SessionOptions(), ) -> Generator[CodebaseContext, None, None]: """Gives you a Codebase2 operating on the files you provided as a dict""" - op = RepoOperator.create_from_files(repo_path=tmpdir, files=files) + op = RepoOperator.create_from_files(repo_path=tmpdir, files=files, programming_language=programming_language) projects = [ProjectConfig(repo_operator=op, programming_language=programming_language)] graph = CodebaseContext(projects=projects, config=TestFlags) with graph.session(sync_graph=sync_graph, session_options=session_options): diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index 735a8b5ea..36ca0664d 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -23,6 +23,7 @@ from typing_extensions import TypeVar, deprecated from codegen.configs.models.codebase import CodebaseConfig +from codegen.configs.models.repository import RepositoryConfig from codegen.configs.models.secrets import SecretsConfig from codegen.git.repo_operator.repo_operator import RepoOperator from codegen.git.schemas.enums import CheckoutResult, SetupOption @@ -1342,13 +1343,19 @@ def from_repo( # Setup repo path and URL repo_path = 
os.path.join(tmp_dir, repo) repo_url = f"https://github.com/{repo_full_name}.git" + repo_config = RepositoryConfig( + root_path=Path(repo_path), + path=repo_path, + owner=owner, + language=language.value if language else None, + ) logger.info(f"Will clone {repo_url} to {repo_path}") try: # Use RepoOperator to fetch the repository logger.info("Cloning repository...") if commit is None: - repo_operator = RepoOperator.create_from_repo(repo_path=repo_path, url=repo_url) + repo_operator = RepoOperator(repo_config=repo_config, setup_option=SetupOption.PULL_OR_CLONE) else: # Ensure the operator can handle remote operations access_token = secrets.github_token if secrets else None diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index 16f7f876b..9e8ebd845 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -15,12 +15,12 @@ from tree_sitter import Node as TSNode from codegen.sdk.codebase.codebase_context import CodebaseContext + from codegen.sdk.core.file import SourceFile from codegen.sdk.core.interfaces.editable import Editable from codegen.sdk.core.interfaces.exportable import Exportable from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.core.statements.import_statement import ImportStatement from codegen.sdk.python.file import PyFile - from src.codegen.sdk.core.file import SourceFile logger = get_logger(__name__) diff --git a/src/codegen/sdk/system-prompt.txt b/src/codegen/sdk/system-prompt.txt index e5b616c1b..8407ff04f 100644 --- a/src/codegen/sdk/system-prompt.txt +++ b/src/codegen/sdk/system-prompt.txt @@ -624,7 +624,7 @@ Codegen creates a custom Python environment in `.codegen/.venv`. Configure your ```bash .codegen/.venv/bin/python ``` - + Alternatively, create a `.vscode/settings.json`: ```json { @@ -646,7 +646,7 @@ Codegen creates a custom Python environment in `.codegen/.venv`. 
Configure your .codegen/.venv/bin/python ``` - + @@ -1191,8 +1191,8 @@ iconType: "solid" - Yes - [by design](/introduction/guiding-principles#python-first-composability). - + Yes - [by design](/introduction/guiding-principles#python-first-composability). + Codegen works like any other python package. It works alongside your IDE, version control system, and other development tools. - Currently, the codebase object can only parse source code files of one language at a time. This means that if you want to work with both Python and TypeScript files, you will need to create two separate codebase objects. + Currently, the codebase object can only parse source code files of one language at a time. This means that if you want to work with both Python and TypeScript files, you will need to create two separate codebase objects. ## Accessing Code @@ -2968,7 +2966,7 @@ for module, imports in module_imports.items(): Always check if imports resolve to external modules before modification to avoid breaking third-party package imports. - + ## Import Statements vs Imports @@ -3170,7 +3168,7 @@ for exp in file.exports: # Get original and current symbols current = exp.exported_symbol original = exp.resolved_symbol - + print(f"Re-exporting {original.name} from {exp.from_file.filepath}") print(f"Through: {' -> '.join(e.file.filepath for e in exp.export_chain)}") ``` @@ -3220,7 +3218,7 @@ for from_file, exports in file_exports.items(): When managing exports, consider the impact on your module's public API. Not all symbols that can be exported should be exported. - + --- title: "Inheritable Behaviors" @@ -3710,9 +3708,9 @@ If `A` depends on `B`, then `B` is used by `A`. 
This relationship is tracked in flowchart LR B(BaseClass) - - - + + + A(MyClass) B ---| used by |A A ---|depends on |B @@ -3881,7 +3879,7 @@ class A: def method_a(self): pass class B(A): - def method_b(self): + def method_b(self): self.method_a() class C(B): @@ -4771,7 +4769,7 @@ for attr in class_def.attributes: # Each attribute has an assignment property attr_type = attr.assignment.type # -> TypeAnnotation print(f"{attr.name}: {attr_type.source}") # e.g. "x: int" - + # Set attribute type attr.assignment.set_type("int") @@ -4788,7 +4786,7 @@ Union types ([UnionType](/api-reference/core/UnionType)) can be manipulated as c ```python # Get union type -union_type = function.return_type # -> A | B +union_type = function.return_type # -> A | B print(union_type.symbols) # ["A", "B"] # Add/remove options @@ -5639,13 +5637,13 @@ Here's an example of using flags during code analysis: ```python def analyze_codebase(codebase): - for function in codebase.functions: + for function in codebase.functions: # Check documentation if not function.docstring: function.flag( message="Missing docstring", ) - + # Check error handling if function.is_async and not function.has_try_catch: function.flag( @@ -6355,7 +6353,7 @@ Explore our tutorials to learn how to use Codegen for various code transformatio > Update API calls, handle breaking changes, and manage bulk updates across your codebase. - Convert Flask applications to FastAPI, updating routes and dependencies. - Migrate Python 2 code to Python 3, updating syntax and modernizing APIs. @@ -6388,9 +6386,9 @@ Explore our tutorials to learn how to use Codegen for various code transformatio > Restructure files, enforce naming conventions, and improve project layout. - Split large files, extract shared logic, and manage dependencies. 
@@ -6488,7 +6486,7 @@ The agent has access to powerful code viewing and manipulation tools powered by - `CreateFileTool`: Create new files - `DeleteFileTool`: Delete files - `RenameFileTool`: Rename files -- `EditFileTool`: Edit files +- `EditFileTool`: Edit files @@ -6995,7 +6993,7 @@ Be explicit about the changes, produce a short summary, and point out possible i Focus on facts and technical details, using code snippets where helpful. """ result = agent.run(prompt) - + # Clean up the temporary comment comment.delete() ``` @@ -7176,21 +7174,21 @@ def research(repo_name: Optional[str] = None, query: Optional[str] = None): """Start a code research session.""" # Initialize codebase codebase = initialize_codebase(repo_name) - + # Create and run the agent agent = create_research_agent(codebase) - + # Main research loop while True: if not query: query = Prompt.ask("[bold cyan]Research query[/bold cyan]") - + result = agent.invoke( {"input": query}, config={"configurable": {"thread_id": 1}} ) console.print(Markdown(result["messages"][-1].content)) - + query = None # Clear for next iteration ``` @@ -7238,7 +7236,7 @@ class CustomAnalysisTool(BaseTool): """Custom tool for specialized code analysis.""" name = "custom_analysis" description = "Performs specialized code analysis" - + def _run(self, query: str) -> str: # Custom analysis logic return results @@ -7516,7 +7514,7 @@ from codegen import Codebase # Initialize codebase codebase = Codebase("path/to/posthog/") -# Create a directed graph for representing call relationships +# Create a directed graph for representing call relationships G = nx.DiGraph() # Configuration flags @@ -7538,7 +7536,7 @@ We'll create a function that will recursively traverse the call trace of a funct ```python def create_downstream_call_trace(src_func: Function, depth: int = 0): """Creates call graph by recursively traversing function calls - + Args: src_func (Function): Starting function for call graph depth (int): Current recursion depth @@ 
-7546,7 +7544,7 @@ def create_downstream_call_trace(src_func: Function, depth: int = 0): # Prevent infinite recursion if MAX_DEPTH <= depth: return - + # External modules are not functions if isinstance(src_func, ExternalModule): return @@ -7556,12 +7554,12 @@ def create_downstream_call_trace(src_func: Function, depth: int = 0): # Skip self-recursive calls if call.name == src_func.name: continue - + # Get called function definition func = call.function_definition if not func: continue - + # Apply configured filters if isinstance(func, ExternalModule) and IGNORE_EXTERNAL_MODULE_CALLS: continue @@ -7575,7 +7573,7 @@ def create_downstream_call_trace(src_func: Function, depth: int = 0): func_name = f"{func.parent_class.name}.{func.name}" if func.is_method else func.name # Add node and edge with metadata - G.add_node(func, name=func_name, + G.add_node(func, name=func_name, color=COLOR_PALETTE.get(func.__class__.__name__)) G.add_edge(src_func, func, **generate_edge_meta(call)) @@ -7590,10 +7588,10 @@ We can enrich our edges with metadata about the function calls: ```python def generate_edge_meta(call: FunctionCall) -> dict: """Generate metadata for call graph edges - + Args: call (FunctionCall): Function call information - + Returns: dict: Edge metadata including name and location """ @@ -7612,8 +7610,8 @@ Finally, we can visualize our call graph starting from a specific function: target_class = codebase.get_class('SharingConfigurationViewSet') target_method = target_class.get_method('patch') -# Add root node -G.add_node(target_method, +# Add root node +G.add_node(target_method, name=f"{target_class.name}.{target_method.name}", color=COLOR_PALETTE["StartFunction"]) @@ -7663,7 +7661,7 @@ The core function for building our dependency graph: ```python def create_dependencies_visualization(symbol: Symbol, depth: int = 0): """Creates visualization of symbol dependencies - + Args: symbol (Symbol): Starting symbol to analyze depth (int): Current recursion depth @@ -7671,11 
+7669,11 @@ def create_dependencies_visualization(symbol: Symbol, depth: int = 0): # Prevent excessive recursion if depth >= MAX_DEPTH: return - + # Process each dependency for dep in symbol.dependencies: dep_symbol = None - + # Handle different dependency types if isinstance(dep, Symbol): # Direct symbol reference @@ -7686,13 +7684,13 @@ def create_dependencies_visualization(symbol: Symbol, depth: int = 0): if dep_symbol: # Add node with appropriate styling - G.add_node(dep_symbol, - color=COLOR_PALETTE.get(dep_symbol.__class__.__name__, + G.add_node(dep_symbol, + color=COLOR_PALETTE.get(dep_symbol.__class__.__name__, "#f694ff")) - + # Add dependency relationship G.add_edge(symbol, dep_symbol) - + # Recurse unless it's a class (avoid complexity) if not isinstance(dep_symbol, PyClass): create_dependencies_visualization(dep_symbol, depth + 1) @@ -7704,7 +7702,7 @@ Finally, we can visualize our dependency graph starting from a specific symbol: # Get target symbol target_func = codebase.get_function("get_query_runner") -# Add root node +# Add root node G.add_node(target_func, color=COLOR_PALETTE["StartFunction"]) # Generate dependency graph @@ -7747,16 +7745,16 @@ HTTP_METHODS = ["get", "put", "patch", "post", "head", "delete"] def generate_edge_meta(usage: Usage) -> dict: """Generate metadata for graph edges - + Args: usage (Usage): Usage relationship information - + Returns: dict: Edge metadata including name and location """ return { "name": usage.match.source, - "file_path": usage.match.filepath, + "file_path": usage.match.filepath, "start_point": usage.match.start_point, "end_point": usage.match.end_point, "symbol_name": usage.match.__class__.__name__ @@ -7764,10 +7762,10 @@ def generate_edge_meta(usage: Usage) -> dict: def is_http_method(symbol: PySymbol) -> bool: """Check if a symbol is an HTTP endpoint method - + Args: symbol (PySymbol): Symbol to check - + Returns: bool: True if symbol is an HTTP method """ @@ -7781,7 +7779,7 @@ The main function for creating 
our blast radius visualization: ```python def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0): """Create visualization of symbol usage relationships - + Args: symbol (PySymbol): Starting symbol to analyze depth (int): Current recursion depth @@ -7789,11 +7787,11 @@ def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0): # Prevent excessive recursion if depth >= MAX_DEPTH: return - + # Process each usage of the symbol for usage in symbol.usages: usage_symbol = usage.usage_symbol - + # Determine node color based on type if is_http_method(usage_symbol): color = COLOR_PALETTE.get("HTTP_METHOD") @@ -7803,7 +7801,7 @@ def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0): # Add node and edge to graph G.add_node(usage_symbol, color=color) G.add_edge(symbol, usage_symbol, **generate_edge_meta(usage)) - + # Recursively process usage symbol create_blast_radius_visualization(usage_symbol, depth + 1) ``` @@ -7954,7 +7952,7 @@ for call in old_api.call_sites: f"data={call.get_arg_by_parameter_name('input').value}", f"timeout={call.get_arg_by_parameter_name('wait').value}" ] - + # Replace the old call with the new API call.replace(f"new_process_data({', '.join(args)})") ``` @@ -7968,10 +7966,10 @@ When updating chained method calls, like database queries or builder patterns: for execute_call in codebase.function_calls: if execute_call.name != "execute": continue - + # Get the full chain chain = execute_call.call_chain - + # Example: Add .timeout() before .execute() if "timeout" not in {call.name for call in chain}: execute_call.insert_before("timeout(30)") @@ -7990,45 +7988,45 @@ Here's a comprehensive example: ```python def migrate_api_v1_to_v2(codebase): old_api = codebase.get_function("create_user_v1") - + # Document all existing call patterns call_patterns = {} for call in old_api.call_sites: args = [arg.source for arg in call.args] pattern = ", ".join(args) call_patterns[pattern] = call_patterns.get(pattern, 0) + 1 - + 
print("Found call patterns:") for pattern, count in call_patterns.items(): print(f" {pattern}: {count} occurrences") - + # Create new API version new_api = old_api.copy() new_api.rename("create_user_v2") - + # Update parameter types new_api.get_parameter("email").type = "EmailStr" new_api.get_parameter("role").type = "UserRole" - + # Add new required parameters new_api.add_parameter("tenant_id: UUID") - + # Update all call sites for call in old_api.call_sites: # Get current arguments email_arg = call.get_arg_by_parameter_name("email") role_arg = call.get_arg_by_parameter_name("role") - + # Build new argument list with type conversions new_args = [ f"email=EmailStr({email_arg.value})", f"role=UserRole({role_arg.value})", "tenant_id=get_current_tenant_id()" ] - + # Replace old call with new version call.replace(f"create_user_v2({', '.join(new_args)})") - + # Add deprecation notice to old version old_api.add_decorator('@deprecated("Use create_user_v2 instead")') @@ -8050,10 +8048,10 @@ migrate_api_v1_to_v2(codebase) ```python # First update parameter names param.rename("new_name") - + # Then update types param.type = "new_type" - + # Finally update call sites for call in api.call_sites: # ... update calls @@ -8063,7 +8061,7 @@ migrate_api_v1_to_v2(codebase) ```python # Add new parameter with default api.add_parameter("new_param: str = None") - + # Later make it required api.get_parameter("new_param").remove_default() ``` @@ -8078,7 +8076,7 @@ migrate_api_v1_to_v2(codebase) Remember to test thoroughly after making bulk changes to APIs. While Codegen ensures syntactic correctness, you'll want to verify the semantic correctness of the changes. 
- + --- title: "Organizing Your Codebase" @@ -8642,16 +8640,16 @@ from collections import defaultdict # Create a graph of file dependencies def create_dependency_graph(): G = nx.DiGraph() - + for file in codebase.files: # Add node for this file G.add_node(file.filepath) - + # Add edges for each import for imp in file.imports: if imp.from_file: # Skip external imports G.add_edge(file.filepath, imp.from_file.filepath) - + return G # Create and analyze the graph @@ -8680,18 +8678,18 @@ def break_circular_dependency(cycle): # Get the first two files in the cycle file1 = codebase.get_file(cycle[0]) file2 = codebase.get_file(cycle[1]) - + # Create a shared module for common code shared_dir = "shared" if not codebase.has_directory(shared_dir): codebase.create_directory(shared_dir) - + # Find symbols used by both files shared_symbols = [] for symbol in file1.symbols: if any(usage.file == file2 for usage in symbol.usages): shared_symbols.append(symbol) - + # Move shared symbols to a new file if shared_symbols: shared_file = codebase.create_file(f"{shared_dir}/shared_types.py") @@ -8713,7 +8711,7 @@ def organize_file_imports(file): std_lib_imports = [] third_party_imports = [] local_imports = [] - + for imp in file.imports: if imp.is_standard_library: std_lib_imports.append(imp) @@ -8721,26 +8719,26 @@ def organize_file_imports(file): third_party_imports.append(imp) else: local_imports.append(imp) - + # Sort each group for group in [std_lib_imports, third_party_imports, local_imports]: group.sort(key=lambda x: x.module_name) - + # Remove all existing imports for imp in file.imports: imp.remove() - + # Add imports back in organized groups if std_lib_imports: for imp in std_lib_imports: file.add_import(imp.source) file.insert_after_imports("") # Add newline - + if third_party_imports: for imp in third_party_imports: file.add_import(imp.source) file.insert_after_imports("") # Add newline - + if local_imports: for imp in local_imports: file.add_import(imp.source) @@ -8759,22 
+8757,22 @@ from collections import defaultdict def analyze_module_coupling(): coupling_scores = defaultdict(int) - + for file in codebase.files: # Count unique files imported from imported_files = {imp.from_file for imp in file.imports if imp.from_file} coupling_scores[file.filepath] = len(imported_files) - + # Count files that import this file - importing_files = {usage.file for symbol in file.symbols + importing_files = {usage.file for symbol in file.symbols for usage in symbol.usages if usage.file != file} coupling_scores[file.filepath] += len(importing_files) - + # Sort by coupling score - sorted_files = sorted(coupling_scores.items(), - key=lambda x: x[1], + sorted_files = sorted(coupling_scores.items(), + key=lambda x: x[1], reverse=True) - + print("\n🔍 Module Coupling Analysis:") print("\nMost coupled files:") for filepath, score in sorted_files[:5]: @@ -8792,9 +8790,9 @@ def extract_shared_code(file, min_usages=3): # Find symbols used by multiple files for symbol in file.symbols: # Get unique files using this symbol - using_files = {usage.file for usage in symbol.usages + using_files = {usage.file for usage in symbol.usages if usage.file != file} - + if len(using_files) >= min_usages: # Create appropriate shared module module_name = determine_shared_module(symbol) @@ -8802,7 +8800,7 @@ def extract_shared_code(file, min_usages=3): shared_file = codebase.create_file(f"shared/{module_name}.py") else: shared_file = codebase.get_file(f"shared/{module_name}.py") - + # Move symbol to shared module symbol.move_to_file(shared_file, strategy="update_all_imports") @@ -8856,7 +8854,7 @@ if feature_flag_class: # Initialize usage count for all attributes for attr in feature_flag_class.attributes: feature_flag_usage[attr.name] = 0 - + # Get all usages of the FeatureFlag class for usage in feature_flag_class.usages: usage_source = usage.usage_symbol.source if hasattr(usage, 'usage_symbol') else str(usage) @@ -9601,7 +9599,7 @@ Let's break down how this works: if
export.is_reexport() and export.is_default_export(): print(f" 🔄 Converting default export '{export.name}'") ``` - + The code identifies default exports by checking: 1. If it's a re-export (`is_reexport()`) 2. If it's a default export (`is_default_export()`) @@ -9709,7 +9707,7 @@ for file in codebase.files: print(f"✨ Fixed exports in {target_file.filepath}") -``` +``` --- title: "Creating Documentation" @@ -9798,11 +9796,11 @@ for directory in codebase.directories: # Skip test, sql and alembic directories if any(x in directory.path.lower() for x in ['test', 'sql', 'alembic']): continue - + # Get undecorated functions funcs = [f for f in directory.functions if not f.is_decorated] total = len(funcs) - + # Only analyze dirs with >10 functions if total > 10: documented = sum(1 for f in funcs if f.docstring) @@ -9817,12 +9815,12 @@ for directory in codebase.directories: if dir_stats: lowest_dir = min(dir_stats.items(), key=lambda x: x[1]['coverage']) path, stats = lowest_dir - + print(f"📉 Lowest coverage directory: '{path}'") print(f" • Total functions: {stats['total']}") print(f" • Documented: {stats['documented']}") print(f" • Coverage: {stats['coverage']:.1f}%") - + # Print all directory stats for comparison print("\n📊 All directory coverage rates:") for path, stats in sorted(dir_stats.items(), key=lambda x: x[1]['coverage']): @@ -10610,7 +10608,7 @@ iconType: "solid" -Import loops occur when two or more Python modules depend on each other, creating a circular dependency. While some import cycles can be harmless, others can lead to runtime errors and make code harder to maintain. +Import loops occur when two or more Python modules depend on each other, creating a circular dependency. While some import cycles can be harmless, others can lead to runtime errors and make code harder to maintain. In this tutorial, we'll explore how to identify and fix problematic import cycles using Codegen.
@@ -11507,10 +11505,10 @@ Match (s: Func )-[r: CALLS]-> (e:Func) RETURN s, e LIMIT 10 ```cypher Match path = (:(Method|Func)) -[:CALLS*5..10]-> (:(Method|Func)) -Return path +Return path LIMIT 20 ``` - \ No newline at end of file + diff --git a/src/codegen/visualizations/visualization_manager.py b/src/codegen/visualizations/visualization_manager.py index 7be3cf8fb..114beb226 100644 --- a/src/codegen/visualizations/visualization_manager.py +++ b/src/codegen/visualizations/visualization_manager.py @@ -22,7 +22,7 @@ def __init__( @property def viz_path(self) -> str: - return os.path.join(self.op.base_dir, "codegen-graphviz") + return os.path.join(self.op.repo_config.base_dir, "codegen-graphviz") @property def viz_file_path(self) -> str: diff --git a/tests/integration/codegen/git/conftest.py b/tests/integration/codegen/git/conftest.py index e3846935d..2577e899a 100644 --- a/tests/integration/codegen/git/conftest.py +++ b/tests/integration/codegen/git/conftest.py @@ -2,7 +2,9 @@ import pytest -from codegen.git.schemas.repo_config import RepoConfig +from codegen.configs.models.repository import RepositoryConfig +from codegen.git.repo_operator.repo_operator import RepoOperator +from codegen.git.schemas.enums import SetupOption @pytest.fixture() @@ -16,9 +18,15 @@ def mock_config(): @pytest.fixture() def repo_config(tmpdir): - repo_config = RepoConfig( + repo_config = RepositoryConfig( name="Kevin-s-Adventure-Game", - full_name="codegen-sh/Kevin-s-Adventure-Game", + owner="codegen-sh", base_dir=str(tmpdir), ) yield repo_config + + +@pytest.fixture +def op(repo_config, request): + op = RepoOperator(repo_config=repo_config, shallow=request.param if hasattr(request, "param") else True, bot_commit=False, setup_option=SetupOption.PULL_OR_CLONE) + yield op diff --git a/tests/integration/codegen/git/repo_operator/test_remote_repo_operator.py b/tests/integration/codegen/git/repo_operator/test_remote_repo_operator.py index 76d7a6292..5bbf8ef87 100644 --- 
a/tests/integration/codegen/git/repo_operator/test_remote_repo_operator.py +++ b/tests/integration/codegen/git/repo_operator/test_remote_repo_operator.py @@ -4,18 +4,12 @@ from github.MainClass import Github from codegen.git.repo_operator.repo_operator import RepoOperator -from codegen.git.schemas.enums import CheckoutResult, SetupOption +from codegen.git.schemas.enums import CheckoutResult from codegen.git.utils.file_utils import create_files shallow_options = [True, False] -@pytest.fixture -def op(repo_config, request): - op = RepoOperator(repo_config, shallow=request.param if hasattr(request, "param") else True, bot_commit=False, setup_option=SetupOption.PULL_OR_CLONE) - yield op - - @pytest.mark.parametrize("op", shallow_options, ids=lambda x: f"shallow={x}", indirect=True) @patch("codegen.git.clients.github_client.Github") def test_checkout_branch(mock_git_client, op: RepoOperator): diff --git a/tests/integration/codegen/runner/conftest.py b/tests/integration/codegen/runner/conftest.py index 5f16fa4f4..0d7b20d89 100644 --- a/tests/integration/codegen/runner/conftest.py +++ b/tests/integration/codegen/runner/conftest.py @@ -3,37 +3,35 @@ import pytest +from codegen.configs.models.repository import RepositoryConfig from codegen.git.clients.git_repo_client import GitRepoClient from codegen.git.repo_operator.repo_operator import RepoOperator from codegen.git.schemas.enums import SetupOption -from codegen.git.schemas.repo_config import RepoConfig from codegen.runner.clients.codebase_client import CodebaseClient -from codegen.shared.enums.programming_language import ProgrammingLanguage from codegen.shared.network.port import get_free_port @pytest.fixture() -def repo_config(tmpdir) -> Generator[RepoConfig, None, None]: - yield RepoConfig( - name="Kevin-s-Adventure-Game", - full_name="codegen-sh/Kevin-s-Adventure-Game", - language=ProgrammingLanguage.PYTHON, - base_dir=str(tmpdir), +def repo_config(tmpdir) -> Generator[RepositoryConfig, None, None]: + yield 
RepositoryConfig( + path=str(tmpdir / "Kevin-s-Adventure-Game"), + owner="codegen-sh", + language="PYTHON", ) @pytest.fixture -def op(repo_config: RepoConfig) -> Generator[RepoOperator, None, None]: +def op(repo_config: RepositoryConfig) -> Generator[RepoOperator, None, None]: yield RepoOperator(repo_config=repo_config, setup_option=SetupOption.PULL_OR_CLONE) @pytest.fixture -def git_repo_client(op: RepoOperator, repo_config: RepoConfig) -> Generator[GitRepoClient, None, None]: - yield GitRepoClient(repo_config=repo_config, access_token=op.access_token) +def git_repo_client(op: RepoOperator, repo_config: RepositoryConfig) -> Generator[GitRepoClient, None, None]: + yield GitRepoClient(repo_full_name=repo_config.full_name, access_token=op.access_token) @pytest.fixture -def codebase_client(repo_config: RepoConfig) -> Generator[CodebaseClient, None, None]: +def codebase_client(repo_config: RepositoryConfig) -> Generator[CodebaseClient, None, None]: sb_client = CodebaseClient(repo_config=repo_config, port=get_free_port()) sb_client.runner = Mock() yield sb_client diff --git a/tests/shared/skills/verify_skill_output.py b/tests/shared/skills/verify_skill_output.py index 00e30c3e0..32eca65bb 100644 --- a/tests/shared/skills/verify_skill_output.py +++ b/tests/shared/skills/verify_skill_output.py @@ -50,7 +50,7 @@ def verify_skill_output(codebase: Codebase, skill, test_case, get_diff, snapshot elif test_case.graph: # I want to save test_case graph locally and compare it with the graph generated by the skill # Read the generated graph JSON - graph_json = open(f"{codebase.op.base_dir}/codegen-graphviz/graph.json").read() + graph_json = open(f"{codebase.op.repo_config.base_dir}/codegen-graphviz/graph.json").read() # Compare with snapshot snapshot.assert_match(graph_json, f"{skill.name}_{test_case.name or 'unnamed'}.json") diff --git a/tests/unit/codegen/runner/sandbox/conftest.py b/tests/unit/codegen/runner/sandbox/conftest.py index efc4913cf..a9a5c98cc 100644 --- 
a/tests/unit/codegen/runner/sandbox/conftest.py +++ b/tests/unit/codegen/runner/sandbox/conftest.py @@ -13,7 +13,7 @@ @pytest.fixture def codebase(tmpdir) -> Codebase: - op = RepoOperator.create_from_files(repo_path=f"{tmpdir}/test-repo", files={"test.py": "a = 1"}, bot_commit=True) + op = RepoOperator.create_from_files(repo_path=f"{tmpdir}/test-repo", files={"test.py": "a = 1"}, bot_commit=True, programming_language=ProgrammingLanguage.PYTHON) projects = [ProjectConfig(repo_operator=op, programming_language=ProgrammingLanguage.PYTHON)] codebase = Codebase(projects=projects) return codebase