diff --git a/scripts/profiling/apis.py b/scripts/profiling/apis.py
deleted file mode 100644
index 24c6ff782..000000000
--- a/scripts/profiling/apis.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import base64
-import logging
-import pickle
-from pathlib import Path
-
-import networkx as nx
-from tabulate import tabulate
-
-from codegen.sdk.codebase.factory.get_dev_customer_codebase import get_codebase_codegen
-from codegen.shared.enums.programming_language import ProgrammingLanguage
-
-logging.basicConfig(level=logging.INFO)
-codegen = get_codebase_codegen("../codegen", ".")
-res = []
-
-# Create a directed graph
-G = nx.DiGraph()
-
-# Iterate through all files in the codebase
-for file in codegen.files:
- if "test" not in file.filepath:
- if file.ctx.repo_name == "codegen":
- color = "yellow"
- elif file.ctx.repo_name == "codegen-sdk":
- color = "red"
- if file.ctx.base_path == "codegen-git":
- color = "green"
- elif file.ctx.base_path == "codegen-utils":
- color = "purple"
- else:
- color = "yellow"
- if file.ctx.programming_language == ProgrammingLanguage.TYPESCRIPT:
- color = "blue"
-
- # Add the file as a node
- G.add_node(file.filepath, color=color)
-
- # Iterate through all imports in the file
- for imp in file.imports:
- if imp.from_file:
- # print("way")
- # Add an edge from the current file to the imported file
- G.add_edge(file.filepath, imp.from_file.filepath)
-
-for url, func in codegen.global_context.multigraph.api_definitions.items():
- usages = codegen.global_context.multigraph.usages.get(url, None)
- if usages:
- res.append((url, func.name, func.filepath, [usage.filepath + "::" + usage.name for usage in usages]))
- for usage in usages:
- G.add_edge(usage.filepath, func.filepath)
-print(tabulate(res, headers=["URL", "Function Name", "Filepath", "Usages"]))
-# Visualize the graph
-# codebase.visualize(G)
-
-# Print some basic statistics
-print(f"Number of files: {G.number_of_nodes()}")
-print(f"Number of import relationships: {G.number_of_edges()}")
-
-# Serialize the graph to a pickle
-graph_pickle = pickle.dumps(G)
-
-# Convert to base64
-graph_base64 = base64.b64encode(graph_pickle).decode("utf-8")
-
-print(f"Base64 encoded graph size: {len(graph_base64)} bytes š")
-print(f"Base64 string: {graph_base64}")
-Path("out.txt").write_text(graph_base64)
-print()
diff --git a/src/codegen/cli/auth/session.py b/src/codegen/cli/auth/session.py
index d06454330..8dfa2ebf3 100644
--- a/src/codegen/cli/auth/session.py
+++ b/src/codegen/cli/auth/session.py
@@ -7,7 +7,7 @@
from codegen.cli.git.repo import get_git_repo
from codegen.cli.rich.codeblocks import format_command
-from codegen.configs.constants import CODEGEN_DIR_NAME, ENV_FILENAME
+from codegen.configs.constants import CODEGEN_DIR_NAME
from codegen.configs.session_manager import session_manager
from codegen.configs.user_config import UserConfig
from codegen.git.repo_operator.local_git_repo import LocalGitRepo
@@ -30,7 +30,7 @@ def __init__(self, repo_path: Path, git_token: str | None = None) -> None:
self.repo_path = repo_path
self.local_git = LocalGitRepo(repo_path=repo_path)
self.codegen_dir = repo_path / CODEGEN_DIR_NAME
- self.config = UserConfig(env_filepath=repo_path / ENV_FILENAME)
+ self.config = UserConfig(root_path=repo_path)
self.config.secrets.github_token = git_token or self.config.secrets.github_token
self.existing = session_manager.get_session(repo_path) is not None
diff --git a/src/codegen/cli/commands/config/main.py b/src/codegen/cli/commands/config/main.py
index b4ec3f3d7..9fce6f9b4 100644
--- a/src/codegen/cli/commands/config/main.py
+++ b/src/codegen/cli/commands/config/main.py
@@ -4,7 +4,7 @@
import rich_click as click
from rich.table import Table
-from codegen.configs.constants import ENV_FILENAME, GLOBAL_ENV_FILE
+from codegen.configs.constants import ENV_FILENAME, GLOBAL_CONFIG_DIR
from codegen.configs.user_config import UserConfig
from codegen.shared.path import get_git_root_path
@@ -117,8 +117,8 @@ def set_command(key: str, value: str):
def _get_user_config() -> UserConfig:
if (project_root := get_git_root_path()) is None:
- env_filepath = GLOBAL_ENV_FILE
+ root_path = GLOBAL_CONFIG_DIR
else:
- env_filepath = project_root / ENV_FILENAME
+ root_path = project_root
- return UserConfig(env_filepath)
+ return UserConfig(root_path)
diff --git a/src/codegen/cli/commands/run/run_local.py b/src/codegen/cli/commands/run/run_local.py
index 4ca737dd1..0ae1a47cf 100644
--- a/src/codegen/cli/commands/run/run_local.py
+++ b/src/codegen/cli/commands/run/run_local.py
@@ -6,8 +6,8 @@
from codegen.cli.auth.session import CodegenSession
from codegen.cli.utils.function_finder import DecoratedFunction
+from codegen.configs.models.repository import RepositoryConfig
from codegen.git.repo_operator.repo_operator import RepoOperator
-from codegen.git.schemas.repo_config import RepoConfig
from codegen.git.utils.language import determine_project_language
from codegen.sdk.codebase.config import ProjectConfig
from codegen.sdk.core.codebase import Codebase
@@ -30,7 +30,7 @@ def parse_codebase(
codebase = Codebase(
projects=[
ProjectConfig(
- repo_operator=RepoOperator(repo_config=RepoConfig.from_repo_path(repo_path=repo_path)),
+ repo_operator=RepoOperator(repo_config=RepositoryConfig.from_path(path=repo_path)),
subdirectories=subdirectories,
programming_language=language or determine_project_language(repo_path),
)
diff --git a/src/codegen/cli/commands/start/main.py b/src/codegen/cli/commands/start/main.py
index 652e400e1..af5989630 100644
--- a/src/codegen/cli/commands/start/main.py
+++ b/src/codegen/cli/commands/start/main.py
@@ -10,9 +10,9 @@
from codegen.cli.commands.start.docker_container import DockerContainer
from codegen.cli.commands.start.docker_fleet import CODEGEN_RUNNER_IMAGE
+from codegen.configs.models.repository import RepositoryConfig
from codegen.configs.models.secrets import SecretsConfig
from codegen.git.repo_operator.local_git_repo import LocalGitRepo
-from codegen.git.schemas.repo_config import RepoConfig
from codegen.shared.network.port import get_free_port
_default_host = "0.0.0.0"
@@ -26,7 +26,7 @@
def start_command(port: int | None, detached: bool = False, skip_build: bool = False, force: bool = False) -> None:
"""Starts a local codegen server"""
repo_path = Path.cwd().resolve()
- repo_config = RepoConfig.from_repo_path(str(repo_path))
+ repo_config = LocalGitRepo(repo_path=repo_path).get_repo_config()
if (container := DockerContainer.get(repo_config.name)) is not None:
if force:
rich.print(f"[yellow]Removing existing runner {repo_config.name} to force restart[/yellow]")
@@ -50,7 +50,7 @@ def start_command(port: int | None, detached: bool = False, skip_build: bool = F
raise click.Abort()
-def _handle_existing_container(repo_config: RepoConfig, container: DockerContainer) -> None:
+def _handle_existing_container(repo_config: RepositoryConfig, container: DockerContainer) -> None:
if container.is_running():
rich.print(
Panel(
@@ -122,20 +122,21 @@ def _get_platform() -> str:
return "linux/amd64"
-def _run_docker_container(repo_config: RepoConfig, port: int, detached: bool) -> None:
+def _run_docker_container(repo_config: RepositoryConfig, port: int, detached: bool) -> None:
rich.print("[bold blue]Starting Docker container...[/bold blue]")
container_repo_path = f"/app/git/{repo_config.name}"
name_args = ["--name", f"{repo_config.name}"]
+ repo_path = Path(repo_config.path)
envvars = {
- "REPOSITORY_LANGUAGE": repo_config.language.value,
- "REPOSITORY_OWNER": LocalGitRepo(repo_config.repo_path).owner,
+ "REPOSITORY_LANGUAGE": repo_config.language,
+ "REPOSITORY_OWNER": LocalGitRepo(repo_path=repo_path).owner,
"REPOSITORY_PATH": container_repo_path,
- "GITHUB_TOKEN": SecretsConfig().github_token,
+ "GITHUB_TOKEN": SecretsConfig(root_path=repo_path).github_token,
"PYTHONUNBUFFERED": "1", # Ensure Python output is unbuffered
"CODEBASE_SYNC_ENABLED": "True",
}
envvars_args = [arg for k, v in envvars.items() for arg in ("--env", f"{k}={v}")]
- mount_args = ["-v", f"{repo_config.repo_path}:{container_repo_path}"]
+ mount_args = ["-v", f"{repo_config.path}:{container_repo_path}"]
entry_point = f"uv run --frozen uvicorn codegen.runner.servers.local_daemon:app --host {_default_host} --port {port}"
port_args = ["-p", f"{port}:{port}"]
detached_args = ["-d"] if detached else []
diff --git a/src/codegen/cli/mcp/resources/system_prompt.py b/src/codegen/cli/mcp/resources/system_prompt.py
index 9c7e23c6b..7da246c9a 100644
--- a/src/codegen/cli/mcp/resources/system_prompt.py
+++ b/src/codegen/cli/mcp/resources/system_prompt.py
@@ -1417,7 +1417,7 @@ def baz():
```python
from codegen import Codebase
+from codegen.configs.models.repository import RepositoryConfig
from codegen.git.repo_operator.repo_operator import RepoOperator
-from codegen.git.schemas.repo_config import RepoConfig
from codegen.sdk.codebase.config import ProjectConfig
from codegen.shared.enums.programming_language import ProgrammingLanguage
@@ -1425,7 +1425,7 @@ def baz():
projects = [
ProjectConfig(
repo_operator=RepoOperator(
- repo_config=RepoConfig(name="codegen-sdk"),
+ repo_config=RepositoryConfig.from_path(path="/tmp/codegen-sdk"),
bot_commit=True
),
programming_language=ProgrammingLanguage.TYPESCRIPT,
diff --git a/src/codegen/configs/models/base_config.py b/src/codegen/configs/models/base_config.py
index 3a82223de..2b85a4efb 100644
--- a/src/codegen/configs/models/base_config.py
+++ b/src/codegen/configs/models/base_config.py
@@ -16,7 +16,8 @@ class BaseConfig(BaseSettings, ABC):
model_config = SettingsConfigDict(extra="ignore", case_sensitive=False)
- def __init__(self, prefix: str, env_filepath: Path | None = None, *args, **kwargs) -> None:
+ def __init__(self, prefix: str, root_path: Path | None = None, *args, **kwargs) -> None:
+ env_filepath = root_path / ENV_FILENAME if root_path else None
if env_filepath is None:
root_path = get_git_root_path()
if root_path is not None:
diff --git a/src/codegen/configs/models/repository.py b/src/codegen/configs/models/repository.py
index d4960c503..1aaeb170c 100644
--- a/src/codegen/configs/models/repository.py
+++ b/src/codegen/configs/models/repository.py
@@ -1,4 +1,6 @@
import os
+from pathlib import Path
+from typing import Self
from codegen.configs.models.base_config import BaseConfig
@@ -13,16 +15,15 @@ class RepositoryConfig(BaseConfig):
language: str | None = None
user_name: str | None = None
user_email: str | None = None
+ subdirectories: list[str] | None = None
+ base_path: str | None = None # root module of the parsed codebase
def __init__(self, prefix: str = "REPOSITORY", *args, **kwargs) -> None:
super().__init__(prefix=prefix, *args, **kwargs)
- def _initialize(
- self,
- ) -> None:
- """Initialize the repository config"""
- if self.path is None:
- self.path = os.getcwd()
+ @classmethod
+ def from_path(cls, path: str | Path) -> Self:  # Build a config rooted at `path`; str or Path both work since the value is normalised via Path()/str() below.
+ return cls(root_path=Path(path), path=str(path))
@property
def base_dir(self) -> str:
diff --git a/src/codegen/configs/user_config.py b/src/codegen/configs/user_config.py
index ecebec4d1..cc3219dca 100644
--- a/src/codegen/configs/user_config.py
+++ b/src/codegen/configs/user_config.py
@@ -3,6 +3,7 @@
from pydantic import Field
+from codegen.configs.constants import ENV_FILENAME
from codegen.configs.models.codebase import CodebaseConfig
from codegen.configs.models.repository import RepositoryConfig
from codegen.configs.models.secrets import SecretsConfig
@@ -14,11 +15,11 @@ class UserConfig:
codebase: CodebaseConfig = Field(default_factory=CodebaseConfig)
secrets: SecretsConfig = Field(default_factory=SecretsConfig)
- def __init__(self, env_filepath: Path) -> None:
- self.env_filepath = env_filepath
- self.secrets = SecretsConfig(env_filepath=env_filepath)
- self.repository = RepositoryConfig(env_filepath=env_filepath)
- self.codebase = CodebaseConfig(env_filepath=env_filepath)
+ def __init__(self, root_path: Path) -> None:
+ self.env_filepath = root_path / ENV_FILENAME
+ self.secrets = SecretsConfig(root_path=root_path)
+ self.repository = RepositoryConfig(root_path=root_path)
+ self.codebase = CodebaseConfig(root_path=root_path)
def save(self) -> None:
"""Save configuration to the config file."""
diff --git a/src/codegen/extensions/events/modal/base.py b/src/codegen/extensions/events/modal/base.py
index 64bdf5b28..c7bc19457 100644
--- a/src/codegen/extensions/events/modal/base.py
+++ b/src/codegen/extensions/events/modal/base.py
@@ -8,7 +8,6 @@
from codegen.extensions.events.codegen_app import CodegenApp
from codegen.extensions.events.modal.request_util import fastapi_request_adapter
from codegen.git.clients.git_repo_client import GitRepoClient
-from codegen.git.schemas.repo_config import RepoConfig
logging.basicConfig(level=logging.INFO, force=True)
logger = logging.getLogger(__name__)
@@ -36,17 +35,11 @@ def get_event_handler_cls(self) -> modal.Cls:
raise NotImplementedError(msg)
async def handle_event(self, org: str, repo: str, provider: Literal["slack", "github", "linear"], request: Request):
- repo_config = RepoConfig(
- name=repo,
- full_name=f"{org}/{repo}",
- )
-
repo_snapshotdict = modal.Dict.from_name(self.snapshot_index_id, {}, create_if_missing=True)
-
last_snapshot_commit = repo_snapshotdict.get(f"{org}/{repo}", None)
if last_snapshot_commit is None:
- git_client = GitRepoClient(repo_config=repo_config, access_token=os.environ["GITHUB_ACCESS_TOKEN"])
+ git_client = GitRepoClient(repo_full_name=f"{org}/{repo}", access_token=os.environ["GITHUB_ACCESS_TOKEN"])
branch = git_client.get_branch_safe(git_client.default_branch)
last_snapshot_commit = branch.commit.sha if branch and branch.commit else None
@@ -76,15 +69,8 @@ def refresh_repository_snapshots(self, snapshot_index_id: str):
try:
# Parse the repository full name to get org and repo
org, repo = repo_full_name.split("/")
-
- # Create a RepoConfig for the repository
- repo_config = RepoConfig(
- name=repo,
- full_name=repo_full_name,
- )
-
# Initialize the GitRepoClient to fetch the latest commit
- git_client = GitRepoClient(repo_config=repo_config, access_token=os.environ["GITHUB_ACCESS_TOKEN"])
+ git_client = GitRepoClient(repo_full_name=repo_full_name, access_token=os.environ["GITHUB_ACCESS_TOKEN"])
# Get the default branch and its latest commit
branch = git_client.get_branch_safe(git_client.default_branch)
diff --git a/src/codegen/git/clients/git_repo_client.py b/src/codegen/git/clients/git_repo_client.py
index 79c8df2a4..c1441e4c6 100644
--- a/src/codegen/git/clients/git_repo_client.py
+++ b/src/codegen/git/clients/git_repo_client.py
@@ -16,7 +16,6 @@
from codegen.configs.models.secrets import SecretsConfig
from codegen.git.clients.github_client import GithubClient
-from codegen.git.schemas.repo_config import RepoConfig
from codegen.git.utils.format import format_comparison
from codegen.shared.logging.get_logger import get_logger
@@ -26,12 +25,12 @@
class GitRepoClient:
"""Wrapper around PyGithub's Remote Repository."""
- repo_config: RepoConfig
+ repo_full_name: str
gh_client: GithubClient
_repo: Repository
- def __init__(self, repo_config: RepoConfig, access_token: str | None = None) -> None:
- self.repo_config = repo_config
+ def __init__(self, repo_full_name: str, access_token: str | None = None) -> None:
+ self.repo_full_name = repo_full_name
self.gh_client = self._create_github_client(token=access_token or SecretsConfig().github_token)
self._repo = self._create_client()
@@ -39,9 +38,9 @@ def _create_github_client(self, token: str) -> GithubClient:
return GithubClient(token=token)
def _create_client(self) -> Repository:
- client = self.gh_client.get_repo_by_full_name(self.repo_config.full_name)
+ client = self.gh_client.get_repo_by_full_name(self.repo_full_name)
if not client:
- msg = f"Repo {self.repo_config.full_name} not found!"
+ msg = f"Repo {self.repo_full_name} not found!"
raise ValueError(msg)
return client
diff --git a/src/codegen/git/repo_operator/local_git_repo.py b/src/codegen/git/repo_operator/local_git_repo.py
index a5c4acea3..dc3c6a443 100644
--- a/src/codegen/git/repo_operator/local_git_repo.py
+++ b/src/codegen/git/repo_operator/local_git_repo.py
@@ -6,8 +6,8 @@
from git import Repo
from git.remote import Remote
+from codegen.configs.models.repository import RepositoryConfig
from codegen.git.clients.git_repo_client import GitRepoClient
-from codegen.git.schemas.repo_config import RepoConfig
from codegen.git.utils.language import determine_project_language
@@ -20,7 +20,11 @@ def __init__(self, repo_path: Path):
@cached_property
def git_cli(self) -> Repo:
- return Repo(self.repo_path)
+ if not os.path.exists(self.repo_path):
+ os.makedirs(self.repo_path)
+ return Repo.init(self.repo_path)
+ else:
+ return Repo(self.repo_path)
@cached_property
def name(self) -> str:
@@ -72,9 +76,7 @@ def user_email(self) -> str | None:
def get_language(self, access_token: str | None = None) -> str:
"""Returns the majority language of the repository"""
if access_token is not None:
- repo_config = RepoConfig.from_repo_path(repo_path=str(self.repo_path))
- repo_config.full_name = self.full_name
- remote_git = GitRepoClient(repo_config=repo_config, access_token=access_token)
+ remote_git = GitRepoClient(repo_full_name=self.full_name, access_token=access_token)
if (language := remote_git.repo.language) is not None:
return language.upper()
@@ -82,3 +84,18 @@ def get_language(self, access_token: str | None = None) -> str:
def has_remote(self) -> bool:
return bool(self.git_cli.remotes)
+
+    def get_repo_config(self, access_token: str | None = None, repo_config: RepositoryConfig | None = None) -> RepositoryConfig:
+        """Return a RepositoryConfig whose unset fields are filled in from this local repo.
+
+        Values already set on ``repo_config`` are kept; the passed object is updated in place and returned.
+        ``access_token`` is forwarded to ``get_language`` and is only used when ``language`` is unset.
+        """
+        config = repo_config if repo_config else RepositoryConfig()
+        # Each fallback on the right is evaluated only when the existing value is falsy.
+        config.path = config.path if config.path else str(self.repo_path)
+        config.owner = config.owner if config.owner else self.owner
+        config.user_name = config.user_name if config.user_name else self.user_name
+        config.user_email = config.user_email if config.user_email else self.user_email
+        config.language = config.language if config.language else self.get_language(access_token=access_token).upper()
+        return config
diff --git a/src/codegen/git/repo_operator/repo_operator.py b/src/codegen/git/repo_operator/repo_operator.py
index 9f1f58007..dad10e62c 100644
--- a/src/codegen/git/repo_operator/repo_operator.py
+++ b/src/codegen/git/repo_operator/repo_operator.py
@@ -5,6 +5,7 @@
from collections.abc import Generator
from datetime import UTC, datetime
from functools import cached_property
+from pathlib import Path
from time import perf_counter
from typing import Self
@@ -16,17 +17,18 @@
from github.IssueComment import IssueComment
from github.PullRequest import PullRequest
+from codegen.configs.models.repository import RepositoryConfig
from codegen.configs.models.secrets import SecretsConfig
from codegen.git.clients.git_repo_client import GitRepoClient
from codegen.git.configs.constants import CODEGEN_BOT_EMAIL, CODEGEN_BOT_NAME
from codegen.git.repo_operator.local_git_repo import LocalGitRepo
from codegen.git.schemas.enums import CheckoutResult, FetchResult, SetupOption
-from codegen.git.schemas.repo_config import RepoConfig
from codegen.git.utils.clone import clone_or_pull_repo, clone_repo, pull_repo
-from codegen.git.utils.clone_url import add_access_token_to_url, get_authenticated_clone_url_for_repo_config, get_clone_url_for_repo_config, url_to_github
+from codegen.git.utils.clone_url import get_authenticated_clone_url_for_repo_config, get_clone_url_for_repo_config, url_to_github
from codegen.git.utils.codeowner_utils import create_codeowners_parser_for_repo
from codegen.git.utils.file_utils import create_files
from codegen.git.utils.remote_progress import CustomRemoteProgress
+from codegen.shared.enums.programming_language import ProgrammingLanguage
from codegen.shared.logging.get_logger import get_logger
from codegen.shared.performance.stopwatch_utils import stopwatch
from codegen.shared.performance.time_utils import humanize_duration
@@ -37,11 +39,10 @@
class RepoOperator:
"""A wrapper around GitPython to make it easier to interact with a repo."""
- repo_config: RepoConfig
- base_dir: str
- bot_commit: bool = True
- access_token: str | None = None
-
+ repo_config: RepositoryConfig
+ bot_commit: bool
+ access_token: str | None
+ respect_gitignore: bool
# lazy attributes
_codeowners_parser: CodeOwnersParser | None = None
_default_branch: str | None = None
@@ -50,17 +51,18 @@ class RepoOperator:
def __init__(
self,
- repo_config: RepoConfig,
+ repo_config: RepositoryConfig,
access_token: str | None = None,
bot_commit: bool = False,
+ respect_gitignore: bool = True,
setup_option: SetupOption | None = None,
shallow: bool | None = None,
) -> None:
- assert repo_config is not None
- self.repo_config = repo_config
- self.access_token = access_token or SecretsConfig().github_token
- self.base_dir = repo_config.base_dir
+ self.access_token = access_token or SecretsConfig(root_path=Path(repo_config.path)).github_token
+ self._local_git_repo = LocalGitRepo(repo_path=Path(repo_config.path))
+ self.repo_config = self._local_git_repo.get_repo_config(self.access_token, repo_config)
self.bot_commit = bot_commit
+ self.respect_gitignore = respect_gitignore
if setup_option:
if shallow is not None:
@@ -68,13 +70,6 @@ def __init__(
else:
self.setup_repo_dir(setup_option=setup_option)
- else:
- os.makedirs(self.repo_path, exist_ok=True)
- GitCLI.init(self.repo_path)
- self._local_git_repo = LocalGitRepo(repo_path=repo_config.repo_path)
- if self.repo_config.full_name is None:
- self.repo_config.full_name = self._local_git_repo.full_name
-
####################################################################################################################
# PROPERTIES
####################################################################################################################
@@ -85,7 +80,11 @@ def repo_name(self) -> str:
@property
def repo_path(self) -> str:
- return os.path.join(self.base_dir, self.repo_name)
+ return self.repo_config.path
+
+ @property
+ def repo_full_name(self) -> str:
+ return self.repo_config.full_name or f"{self._local_git_repo.owner}/{self.repo_name}"
@property
def remote_git_repo(self) -> GitRepoClient:
@@ -94,18 +93,18 @@ def remote_git_repo(self) -> GitRepoClient:
raise ValueError(msg)
if not self._remote_git_repo:
- self._remote_git_repo = GitRepoClient(self.repo_config, access_token=self.access_token)
+ self._remote_git_repo = GitRepoClient(self.repo_full_name, access_token=self.access_token)
return self._remote_git_repo
@property
def clone_url(self) -> str:
if self.access_token:
return get_authenticated_clone_url_for_repo_config(repo=self.repo_config, token=self.access_token)
- return f"https://github.com/{self.repo_config.full_name}.git"
+ return get_clone_url_for_repo_config(repo_config=self.repo_config)  # helper's parameter is `repo_config`; passing `repo=` raises TypeError
@property
def viz_path(self) -> str:
- return os.path.join(self.base_dir, "codegen-graphviz")
+ return os.path.join(self.repo_config.base_dir, "codegen-graphviz")
@property
def viz_file_path(self) -> str:
@@ -203,8 +202,7 @@ def codeowners_parser(self) -> CodeOwnersParser | None:
# SET UP
####################################################################################################################
def setup_repo_dir(self, setup_option: SetupOption = SetupOption.PULL_OR_CLONE, shallow: bool = True) -> None:
- os.makedirs(self.base_dir, exist_ok=True)
- os.chdir(self.base_dir)
+ os.chdir(self.repo_config.base_dir)
if setup_option is SetupOption.CLONE:
# if repo exists delete, then clone, else clone
clone_repo(shallow=shallow, repo_path=self.repo_path, clone_url=self.clone_url)
@@ -278,8 +276,6 @@ def clone_repo(self, shallow: bool = True) -> None:
def clone_or_pull_repo(self, shallow: bool = True) -> None:
"""If repo exists, pulls changes. otherwise, clones the repo."""
# TODO(CG-7804): if repo is not valid we should delete it and re-clone. maybe we can create a pull_repo util + use the existing clone_repo util
- if self.repo_exists():
- self.clean_repo()
clone_or_pull_repo(repo_path=self.repo_path, clone_url=self.clone_url, shallow=shallow)
####################################################################################################################
@@ -577,7 +573,7 @@ def delete_file(self, path: str) -> None:
def get_filepaths_for_repo(self, ignore_list):
# Get list of files to iterate over based on gitignore setting
- if self.repo_config.respect_gitignore:
+ if self.respect_gitignore:
# ls-file flags:
# -c: show cached files
# -o: show other / untracked files
@@ -808,7 +804,7 @@ def get_pull_request(self, pr_number: int) -> PullRequest | None:
# CLASS METHODS
####################################################################################################################
@classmethod
- def create_from_files(cls, repo_path: str, files: dict[str, str], bot_commit: bool = True) -> Self:
+ def create_from_files(cls, repo_path: str, files: dict[str, str], programming_language: ProgrammingLanguage, bot_commit: bool = True) -> Self:
"""Used when you want to create a directory from a set of files and then create a RepoOperator that points to that directory.
Use cases:
- Unit testing
@@ -824,7 +820,9 @@ def create_from_files(cls, repo_path: str, files: dict[str, str], bot_commit: bo
create_files(base_dir=repo_path, files=files)
# Step 2: Init git repo
- op = cls(repo_config=RepoConfig.from_repo_path(repo_path), bot_commit=bot_commit)
+ repo_config = RepositoryConfig.from_path(path=repo_path)
+ repo_config.language = programming_language
+ op = cls(repo_config=repo_config, bot_commit=bot_commit)
if op.stage_and_commit_all_changes("[Codegen] initial commit"):
op.checkout_branch(None, create_if_missing=True)
return op
@@ -839,60 +837,14 @@ def create_from_commit(cls, repo_path: str, commit: str, url: str, access_token:
url (str): Git URL of the repository
access_token (str | None): Optional GitHub API key for operations that need GitHub access
"""
- op = cls(repo_config=RepoConfig.from_repo_path(repo_path, full_name=full_name), bot_commit=False, access_token=access_token)
+ repo_config = RepositoryConfig.from_path(path=repo_path)
+ if full_name:
+ repo_config.owner = full_name.split("/")[0]
+ op = cls(repo_config=repo_config, bot_commit=False, access_token=access_token)
op.discard_changes()
if op.get_active_branch_or_commit() != commit:
op.create_remote("origin", url)
op.git_cli.remotes["origin"].fetch(commit, depth=1)
op.checkout_commit(commit)
return op
-
- @classmethod
- def create_from_repo(cls, repo_path: str, url: str, access_token: str | None = None) -> Self | None:
- """Create a fresh clone of a repository or use existing one if up to date.
-
- Args:
- repo_path (str): Path where the repo should be cloned
- url (str): Git URL of the repository
- access_token (str | None): Optional GitHub API key for operations that need GitHub access
- """
- access_token = access_token or SecretsConfig().github_token
- if access_token:
- url = add_access_token_to_url(url=url, token=access_token)
-
- # Check if repo already exists
- if os.path.exists(repo_path):
- try:
- # Try to initialize git repo from existing path
- git_cli = GitCLI(repo_path)
- # Check if it has our remote URL
- if any(remote.url == url for remote in git_cli.remotes):
- # Fetch to check for updates
- git_cli.remotes.origin.fetch()
- # Get current and remote HEADs
- local_head = git_cli.head.commit
- remote_head = git_cli.remotes.origin.refs[git_cli.active_branch.name].commit
- # If up to date, use existing repo
- if local_head.hexsha == remote_head.hexsha:
- return cls(repo_config=RepoConfig.from_repo_path(repo_path), bot_commit=False, access_token=access_token)
- except Exception:
- # If any git operations fail, fallback to fresh clone
- pass
-
- # If we get here, repo exists but is not up to date or valid
- # Remove the existing directory to do a fresh clone
- import shutil
-
- shutil.rmtree(repo_path)
- try:
- # Clone the repository
- GitCLI.clone_from(url=url, to_path=repo_path, depth=1)
-
- # Initialize with the cloned repo
- git_cli = GitCLI(repo_path)
- except (GitCommandError, ValueError) as e:
- logger.exception("Failed to initialize Git repository:")
- logger.exception("Please authenticate with a valid token and ensure the repository is properly initialized.")
- return None
- return cls(repo_config=RepoConfig.from_repo_path(repo_path), bot_commit=False, access_token=access_token)
diff --git a/src/codegen/git/schemas/repo_config.py b/src/codegen/git/schemas/repo_config.py
deleted file mode 100644
index 0e54b8362..000000000
--- a/src/codegen/git/schemas/repo_config.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import os.path
-from pathlib import Path
-
-from pydantic import BaseModel
-
-from codegen.configs.models.repository import RepositoryConfig
-from codegen.git.schemas.enums import RepoVisibility
-from codegen.shared.enums.programming_language import ProgrammingLanguage
-from codegen.shared.logging.get_logger import get_logger
-
-logger = get_logger(__name__)
-
-
-class RepoConfig(BaseModel):
- """All the information about the repo needed to build a codebase"""
-
- name: str
- full_name: str | None = None
- visibility: RepoVisibility | None = None
-
- # Codebase fields
- base_dir: str = "/tmp" # parent directory of the git repo
- language: ProgrammingLanguage = ProgrammingLanguage.PYTHON
- respect_gitignore: bool = True
- base_path: str | None = None # root directory of the codebase within the repo
- subdirectories: list[str] | None = None
-
- @classmethod
- def from_envs(cls) -> "RepoConfig":
- default_repo_config = RepositoryConfig()
- return RepoConfig(
- name=default_repo_config.name,
- full_name=default_repo_config.full_name,
- base_dir=os.path.dirname(default_repo_config.path),
- language=ProgrammingLanguage(default_repo_config.language.upper()),
- )
-
- @classmethod
- def from_repo_path(cls, repo_path: str, full_name: str | None = None) -> "RepoConfig":
- name = os.path.basename(repo_path)
- base_dir = os.path.dirname(repo_path)
- return cls(name=name, base_dir=base_dir, full_name=full_name)
-
- @property
- def repo_path(self) -> Path:
- return Path(f"{self.base_dir}/{self.name}")
-
- @property
- def organization_name(self) -> str | None:
- if self.full_name is not None:
- return self.full_name.split("/")[0]
-
- return None
diff --git a/src/codegen/git/utils/clone_url.py b/src/codegen/git/utils/clone_url.py
index 21cd80cfb..1d3636b83 100644
--- a/src/codegen/git/utils/clone_url.py
+++ b/src/codegen/git/utils/clone_url.py
@@ -1,6 +1,6 @@
from urllib.parse import urlparse
-from codegen.git.schemas.repo_config import RepoConfig
+from codegen.configs.models.repository import RepositoryConfig
# TODO: move out doesn't belong here
@@ -9,11 +9,11 @@ def url_to_github(url: str, branch: str) -> str:
return f"{clone_url}/blob/{branch}"
-def get_clone_url_for_repo_config(repo_config: RepoConfig) -> str:
+def get_clone_url_for_repo_config(repo_config: RepositoryConfig) -> str:
return f"https://github.com/{repo_config.full_name}.git"
-def get_authenticated_clone_url_for_repo_config(repo: RepoConfig, token: str) -> str:
+def get_authenticated_clone_url_for_repo_config(repo: RepositoryConfig, token: str) -> str:
git_url = get_clone_url_for_repo_config(repo)
return add_access_token_to_url(git_url, token)
diff --git a/src/codegen/git/utils/language.py b/src/codegen/git/utils/language.py
index 551ac4212..c0191040a 100644
--- a/src/codegen/git/utils/language.py
+++ b/src/codegen/git/utils/language.py
@@ -2,6 +2,7 @@
from pathlib import Path
from typing import Literal
+from codegen.configs.models.repository import RepositoryConfig
from codegen.git.utils.file_utils import split_git_path
from codegen.shared.enums.programming_language import ProgrammingLanguage
from codegen.shared.logging.get_logger import get_logger
@@ -108,7 +109,6 @@ def _determine_language_by_git_file_count(folder_path: str) -> ProgrammingLangua
or if less than MIN_LANGUAGE_RATIO of files match the dominant language
"""
from codegen.git.repo_operator.repo_operator import RepoOperator
- from codegen.git.schemas.repo_config import RepoConfig
from codegen.sdk.codebase.codebase_context import GLOBAL_FILE_IGNORE_LIST
from codegen.sdk.python import PyFile
from codegen.sdk.typescript.file import TSFile
@@ -129,7 +129,7 @@ def _determine_language_by_git_file_count(folder_path: str) -> ProgrammingLangua
# Initiate RepoOperator
git_root, base_path = split_git_path(folder_path)
- repo_config = RepoConfig.from_repo_path(repo_path=git_root)
+ repo_config = RepositoryConfig.from_path(path=git_root)
repo_operator = RepoOperator(repo_config=repo_config)
# Walk through the directory
diff --git a/src/codegen/runner/clients/codebase_client.py b/src/codegen/runner/clients/codebase_client.py
index 7b4bf16ce..c059c5690 100644
--- a/src/codegen/runner/clients/codebase_client.py
+++ b/src/codegen/runner/clients/codebase_client.py
@@ -3,9 +3,10 @@
import os
import subprocess
import time
+from pathlib import Path
+from codegen.configs.models.repository import RepositoryConfig
from codegen.configs.models.secrets import SecretsConfig
-from codegen.git.schemas.repo_config import RepoConfig
from codegen.runner.clients.client import Client
from codegen.runner.models.apis import SANDBOX_SERVER_PORT
from codegen.shared.logging.get_logger import get_logger
@@ -21,9 +22,9 @@
class CodebaseClient(Client):
"""Client for interacting with the locally hosted sandbox server."""
- repo_config: RepoConfig
+ repo_config: RepositoryConfig
- def __init__(self, repo_config: RepoConfig, host: str = "127.0.0.1", port: int = SANDBOX_SERVER_PORT, server_path: str = RUNNER_SERVER_PATH):
+ def __init__(self, repo_config: RepositoryConfig, host: str = "127.0.0.1", port: int = SANDBOX_SERVER_PORT, server_path: str = RUNNER_SERVER_PATH):
super().__init__(host=host, port=port)
self.repo_config = repo_config
self._process = None
@@ -39,6 +40,8 @@ def _start_server(self, server_path: str) -> None:
"""Start the FastAPI server in a subprocess"""
envs = self._get_envs()
logger.info(f"Starting local server on {self.base_url} with envvars: {envs}")
+ for key, value in envs.items():
+ logger.info(f"{key}={value}")
self._process = subprocess.Popen(
[
@@ -66,10 +69,10 @@ def _wait_for_server(self, timeout: int = 30, interval: float = 0.3) -> None:
def _get_envs(self) -> dict:
envs = os.environ.copy()
codebase_envs = {
- "REPOSITORY_PATH": str(self.repo_config.repo_path),
- "REPOSITORY_OWNER": self.repo_config.organization_name,
- "REPOSITORY_LANGUAGE": self.repo_config.language.value,
- "GITHUB_TOKEN": SecretsConfig().github_token,
+ "REPOSITORY_PATH": self.repo_config.path,
+ "REPOSITORY_OWNER": self.repo_config.owner,
+ "REPOSITORY_LANGUAGE": self.repo_config.language,
+ "GITHUB_TOKEN": SecretsConfig(root_path=Path(self.repo_config.path)).github_token,
}
envs.update(codebase_envs)
@@ -77,7 +80,5 @@ def _get_envs(self) -> dict:
if __name__ == "__main__":
- test_config = RepoConfig.from_repo_path("/Users/caroljung/git/codegen/codegen-agi")
- test_config.full_name = "codegen-sh/codegen-agi"
- client = CodebaseClient(test_config)
+ client = CodebaseClient(repo_config=RepositoryConfig.from_path("/Users/caroljung/git/codegen/codegen-agi"))
print(client.is_running())
diff --git a/src/codegen/runner/sandbox/runner.py b/src/codegen/runner/sandbox/runner.py
index 4a86bc618..fc0438f8f 100644
--- a/src/codegen/runner/sandbox/runner.py
+++ b/src/codegen/runner/sandbox/runner.py
@@ -1,15 +1,16 @@
import sys
from codegen.configs.models.codebase import CodebaseConfig
+from codegen.configs.models.repository import RepositoryConfig
from codegen.git.repo_operator.repo_operator import RepoOperator
from codegen.git.schemas.enums import SetupOption
-from codegen.git.schemas.repo_config import RepoConfig
from codegen.runner.models.apis import CreateBranchRequest, CreateBranchResponse, GetDiffRequest, GetDiffResponse
from codegen.runner.sandbox.executor import SandboxExecutor
from codegen.sdk.codebase.config import ProjectConfig, SessionOptions
from codegen.sdk.codebase.factory.codebase_factory import CodebaseType
from codegen.sdk.core.codebase import Codebase
from codegen.shared.compilation.string_to_code import create_execute_function_from_codeblock
+from codegen.shared.enums.programming_language import ProgrammingLanguage
from codegen.shared.logging.get_logger import get_logger
logger = get_logger(__name__)
@@ -19,16 +20,16 @@ class SandboxRunner:
"""Responsible for orchestrating the lifecycle of a warmed sandbox"""
# =====[ __init__ instance attributes ]=====
- repo: RepoConfig
+ repo: RepositoryConfig
op: RepoOperator | None
# =====[ computed instance attributes ]=====
codebase: CodebaseType
executor: SandboxExecutor
- def __init__(self, repo_config: RepoConfig, op: RepoOperator | None = None) -> None:
- self.repo = repo_config
- self.op = op or RepoOperator(repo_config=self.repo, setup_option=SetupOption.PULL_OR_CLONE, bot_commit=True)
+ def __init__(self, repo_config: RepositoryConfig, op: RepoOperator | None = None) -> None:
+ self.op = op or RepoOperator(repo_config=repo_config, setup_option=SetupOption.PULL_OR_CLONE, bot_commit=True)
+ self.repo = self.op.repo_config
async def warmup(self, codebase_config: CodebaseConfig | None = None) -> None:
"""Warms up this runner by cloning the repo and parsing the graph."""
@@ -40,7 +41,8 @@ async def warmup(self, codebase_config: CodebaseConfig | None = None) -> None:
async def _build_graph(self, codebase_config: CodebaseConfig | None = None) -> Codebase:
logger.info("> Building graph...")
- projects = [ProjectConfig(programming_language=self.repo.language, repo_operator=self.op, base_path=self.repo.base_path, subdirectories=self.repo.subdirectories)]
+ programming_language = ProgrammingLanguage(self.repo.language.upper())
+ projects = [ProjectConfig(programming_language=programming_language, repo_operator=self.op, base_path=self.repo.base_path, subdirectories=self.repo.subdirectories)]
return Codebase(projects=projects, config=codebase_config)
async def get_diff(self, request: GetDiffRequest) -> GetDiffResponse:
diff --git a/src/codegen/runner/sandbox/server.py b/src/codegen/runner/sandbox/server.py
index a6a346fcf..3f57f9718 100644
--- a/src/codegen/runner/sandbox/server.py
+++ b/src/codegen/runner/sandbox/server.py
@@ -1,10 +1,8 @@
-import os
from contextlib import asynccontextmanager
from fastapi import FastAPI
from codegen.configs.models.repository import RepositoryConfig
-from codegen.git.schemas.repo_config import RepoConfig
from codegen.runner.enums.warmup_state import WarmupState
from codegen.runner.models.apis import (
BRANCH_ENDPOINT,
@@ -17,7 +15,6 @@
)
from codegen.runner.sandbox.middlewares import CodemodRunMiddleware
from codegen.runner.sandbox.runner import SandboxRunner
-from codegen.shared.enums.programming_language import ProgrammingLanguage
from codegen.shared.logging.get_logger import get_logger
logger = get_logger(__name__)
@@ -31,18 +28,11 @@ async def lifespan(server: FastAPI):
global server_info
global runner
- default_repo_config = RepositoryConfig()
- repo_name = default_repo_config.full_name or default_repo_config.name
- server_info = ServerInfo(repo_name=repo_name)
try:
- logger.info(f"Starting up sandbox fastapi server for repo_name={repo_name}")
- repo_config = RepoConfig(
- name=default_repo_config.name,
- full_name=default_repo_config.full_name,
- base_dir=os.path.dirname(default_repo_config.path),
- language=ProgrammingLanguage(default_repo_config.language.upper()),
- )
+ repo_config = RepositoryConfig()
runner = SandboxRunner(repo_config=repo_config)
+ server_info = ServerInfo(repo_name=runner.repo.full_name or runner.repo.name)
+ logger.info(f"Starting up sandbox fastapi server for repo_name={server_info.repo_name}")
server_info.warmup_state = WarmupState.PENDING
await runner.warmup()
server_info.synced_commit = runner.op.git_cli.head.commit.hexsha
diff --git a/src/codegen/runner/servers/local_daemon.py b/src/codegen/runner/servers/local_daemon.py
index 1d24006ae..00d8ec12e 100644
--- a/src/codegen/runner/servers/local_daemon.py
+++ b/src/codegen/runner/servers/local_daemon.py
@@ -4,10 +4,10 @@
from fastapi import FastAPI
from codegen.configs.models.codebase import DefaultCodebaseConfig
+from codegen.configs.models.repository import RepositoryConfig
from codegen.git.configs.constants import CODEGEN_BOT_EMAIL, CODEGEN_BOT_NAME
from codegen.git.repo_operator.repo_operator import RepoOperator
from codegen.git.schemas.enums import SetupOption
-from codegen.git.schemas.repo_config import RepoConfig
from codegen.runner.enums.warmup_state import WarmupState
from codegen.runner.models.apis import (
RUN_FUNCTION_ENDPOINT,
@@ -37,18 +37,17 @@ async def lifespan(server: FastAPI):
global runner
try:
- repo_config = RepoConfig.from_envs()
- server_info = ServerInfo(repo_name=repo_config.full_name or repo_config.name)
-
- # Set the bot email and username
+ repo_config = RepositoryConfig()
op = RepoOperator(repo_config=repo_config, setup_option=SetupOption.SKIP, bot_commit=True)
runner = SandboxRunner(repo_config=repo_config, op=op)
+ server_info = ServerInfo(repo_name=runner.repo.full_name or runner.repo.name)
+
logger.info(f"Configuring git user config to {CODEGEN_BOT_EMAIL} and {CODEGEN_BOT_NAME}")
runner.op.git_cli.git.config("user.email", CODEGEN_BOT_EMAIL)
runner.op.git_cli.git.config("user.name", CODEGEN_BOT_NAME)
# Parse the codebase with sync enabled
- logger.info(f"Starting up fastapi server for repo_name={repo_config.name}")
+ logger.info(f"Starting up fastapi server for repo_name={server_info.repo_name}")
server_info.warmup_state = WarmupState.PENDING
codebase_config = DefaultCodebaseConfig.model_copy(update={"sync_enabled": True})
await runner.warmup(codebase_config=codebase_config)
diff --git a/src/codegen/sdk/code_generation/current_code_codebase.py b/src/codegen/sdk/code_generation/current_code_codebase.py
index bfcee2232..9b418861f 100644
--- a/src/codegen/sdk/code_generation/current_code_codebase.py
+++ b/src/codegen/sdk/code_generation/current_code_codebase.py
@@ -5,9 +5,9 @@
from typing import TypedDict
from codegen.configs.models.codebase import CodebaseConfig
+from codegen.configs.models.repository import RepositoryConfig
from codegen.configs.models.secrets import SecretsConfig
from codegen.git.repo_operator.repo_operator import RepoOperator
-from codegen.git.schemas.repo_config import RepoConfig
from codegen.sdk.codebase.config import ProjectConfig
from codegen.sdk.core.codebase import Codebase, CodebaseType
from codegen.shared.decorators.docs import DocumentedObject, apidoc_objects, no_apidoc_objects, py_apidoc_objects, ts_apidoc_objects
@@ -41,11 +41,10 @@ def get_current_code_codebase(config: CodebaseConfig | None = None, secrets: Sec
base_dir = get_codegen_codebase_base_path()
logger.info(f"Creating codebase from repo at: {codegen_repo_path} with base_path {base_dir}")
- repo_config = RepoConfig.from_repo_path(codegen_repo_path)
- repo_config.respect_gitignore = False
- op = RepoOperator(repo_config=repo_config, bot_commit=False)
+ repo_config = RepositoryConfig.from_path(path=codegen_repo_path)
+ op = RepoOperator(repo_config=repo_config, bot_commit=False, respect_gitignore=False)
- config = (config or CodebaseConfig()).model_copy(update={"base_path": base_dir})
+ config = config or CodebaseConfig()
projects = [ProjectConfig(repo_operator=op, programming_language=ProgrammingLanguage.PYTHON, subdirectories=subdirectories, base_path=base_dir)]
codebase = Codebase(projects=projects, config=config, secrets=secrets)
return codebase
diff --git a/src/codegen/sdk/codebase/config.py b/src/codegen/sdk/codebase/config.py
index 25f3c2e0e..e9898c7df 100644
--- a/src/codegen/sdk/codebase/config.py
+++ b/src/codegen/sdk/codebase/config.py
@@ -6,8 +6,8 @@
from pydantic.fields import Field
from codegen.configs.models.codebase import DefaultCodebaseConfig
+from codegen.configs.models.repository import RepositoryConfig
from codegen.git.repo_operator.repo_operator import RepoOperator
-from codegen.git.schemas.repo_config import RepoConfig
from codegen.git.utils.file_utils import split_git_path
from codegen.git.utils.language import determine_project_language
from codegen.shared.enums.programming_language import ProgrammingLanguage
@@ -35,7 +35,7 @@ class ProjectConfig(BaseModel):
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
repo_operator: RepoOperator
- # TODO: clean up these fields. Duplicated across RepoConfig and CodebaseContext
+ # TODO: clean up these fields. Duplicated across RepositoryConfig and CodebaseContext
base_path: str | None = None
subdirectories: list[str] | None = None
programming_language: ProgrammingLanguage = ProgrammingLanguage.PYTHON
@@ -47,9 +47,7 @@ def from_path(cls, path: str, programming_language: ProgrammingLanguage | None =
git_root, base_path = split_git_path(repo_path)
subdirectories = [base_path] if base_path else None
programming_language = programming_language or determine_project_language(repo_path)
- repo_config = RepoConfig.from_repo_path(repo_path=git_root)
- repo_config.language = programming_language
- repo_config.subdirectories = subdirectories
+ repo_config = RepositoryConfig.from_path(path=repo_path, language=programming_language.value, subdirectories=subdirectories, base_path=base_path)
# Create main project
return cls(
repo_operator=RepoOperator(repo_config=repo_config),
@@ -60,9 +58,13 @@ def from_path(cls, path: str, programming_language: ProgrammingLanguage | None =
@classmethod
def from_repo_operator(cls, repo_operator: RepoOperator, programming_language: ProgrammingLanguage | None = None, base_path: str | None = None) -> Self:
+ language = programming_language or determine_project_language(repo_operator.repo_path)
+ repo_operator.repo_config.language = language.value
+ repo_operator.repo_config.subdirectories = [base_path] if base_path else None
+ repo_operator.repo_config.base_path = base_path
return cls(
repo_operator=repo_operator,
- programming_language=programming_language or determine_project_language(repo_operator.repo_path),
+ programming_language=language,
base_path=base_path,
subdirectories=[base_path] if base_path else None,
)
diff --git a/src/codegen/sdk/codebase/factory/codebase_factory.py b/src/codegen/sdk/codebase/factory/codebase_factory.py
index 009992311..3c753a563 100644
--- a/src/codegen/sdk/codebase/factory/codebase_factory.py
+++ b/src/codegen/sdk/codebase/factory/codebase_factory.py
@@ -23,6 +23,6 @@ def get_codebase_from_files(
config: CodebaseConfig | None = None,
secrets: SecretsConfig | None = None,
) -> CodebaseType:
- op = RepoOperator.create_from_files(repo_path=repo_path, files=files, bot_commit=bot_commit)
+ op = RepoOperator.create_from_files(repo_path=repo_path, files=files, bot_commit=bot_commit, programming_language=programming_language)
projects = [ProjectConfig(repo_operator=op, programming_language=programming_language)]
return Codebase(projects=projects, config=config, secrets=secrets)
diff --git a/src/codegen/sdk/codebase/factory/get_session.py b/src/codegen/sdk/codebase/factory/get_session.py
index 189eec6e6..089923427 100644
--- a/src/codegen/sdk/codebase/factory/get_session.py
+++ b/src/codegen/sdk/codebase/factory/get_session.py
@@ -112,7 +112,7 @@ def get_codebase_graph_session(
session_options: SessionOptions = SessionOptions(),
) -> Generator[CodebaseContext, None, None]:
"""Gives you a Codebase2 operating on the files you provided as a dict"""
- op = RepoOperator.create_from_files(repo_path=tmpdir, files=files)
+ op = RepoOperator.create_from_files(repo_path=tmpdir, files=files, programming_language=programming_language)
projects = [ProjectConfig(repo_operator=op, programming_language=programming_language)]
graph = CodebaseContext(projects=projects, config=TestFlags)
with graph.session(sync_graph=sync_graph, session_options=session_options):
diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py
index 735a8b5ea..36ca0664d 100644
--- a/src/codegen/sdk/core/codebase.py
+++ b/src/codegen/sdk/core/codebase.py
@@ -23,6 +23,7 @@
from typing_extensions import TypeVar, deprecated
from codegen.configs.models.codebase import CodebaseConfig
+from codegen.configs.models.repository import RepositoryConfig
from codegen.configs.models.secrets import SecretsConfig
from codegen.git.repo_operator.repo_operator import RepoOperator
from codegen.git.schemas.enums import CheckoutResult, SetupOption
@@ -1342,13 +1343,19 @@ def from_repo(
# Setup repo path and URL
repo_path = os.path.join(tmp_dir, repo)
repo_url = f"https://github.com/{repo_full_name}.git"
+ repo_config = RepositoryConfig(
+ root_path=Path(repo_path),
+ path=repo_path,
+ owner=owner,
+ language=language.value if language else None,
+ )
logger.info(f"Will clone {repo_url} to {repo_path}")
try:
# Use RepoOperator to fetch the repository
logger.info("Cloning repository...")
if commit is None:
- repo_operator = RepoOperator.create_from_repo(repo_path=repo_path, url=repo_url)
+ repo_operator = RepoOperator(repo_config=repo_config, setup_option=SetupOption.PULL_OR_CLONE)
else:
# Ensure the operator can handle remote operations
access_token = secrets.github_token if secrets else None
diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py
index 16f7f876b..9e8ebd845 100644
--- a/src/codegen/sdk/python/import_resolution.py
+++ b/src/codegen/sdk/python/import_resolution.py
@@ -15,12 +15,12 @@
from tree_sitter import Node as TSNode
from codegen.sdk.codebase.codebase_context import CodebaseContext
+ from codegen.sdk.core.file import SourceFile
from codegen.sdk.core.interfaces.editable import Editable
from codegen.sdk.core.interfaces.exportable import Exportable
from codegen.sdk.core.node_id_factory import NodeId
from codegen.sdk.core.statements.import_statement import ImportStatement
from codegen.sdk.python.file import PyFile
- from src.codegen.sdk.core.file import SourceFile
logger = get_logger(__name__)
diff --git a/src/codegen/sdk/system-prompt.txt b/src/codegen/sdk/system-prompt.txt
index e5b616c1b..8407ff04f 100644
--- a/src/codegen/sdk/system-prompt.txt
+++ b/src/codegen/sdk/system-prompt.txt
@@ -624,7 +624,7 @@ Codegen creates a custom Python environment in `.codegen/.venv`. Configure your
```bash
.codegen/.venv/bin/python
```
-
+
Alternatively, create a `.vscode/settings.json`:
```json
{
@@ -646,7 +646,7 @@ Codegen creates a custom Python environment in `.codegen/.venv`. Configure your
.codegen/.venv/bin/python
```
-
+
@@ -1191,8 +1191,8 @@ iconType: "solid"
- Yes - [by design](/introduction/guiding-principles#python-first-composability).
-
+ Yes - [by design](/introduction/guiding-principles#python-first-composability).
+
Codegen works like any other python package. It works alongside your IDE, version control system, and other development tools.
- Currently, the codebase object can only parse source code files of one language at a time. This means that if you want to work with both Python and TypeScript files, you will need to create two separate codebase objects.
+ Currently, the codebase object can only parse source code files of one language at a time. This means that if you want to work with both Python and TypeScript files, you will need to create two separate codebase objects.
## Accessing Code
@@ -2968,7 +2966,7 @@ for module, imports in module_imports.items():
Always check if imports resolve to external modules before modification to avoid breaking third-party package imports.
-
+
## Import Statements vs Imports
@@ -3170,7 +3168,7 @@ for exp in file.exports:
# Get original and current symbols
current = exp.exported_symbol
original = exp.resolved_symbol
-
+
print(f"Re-exporting {original.name} from {exp.from_file.filepath}")
print(f"Through: {' -> '.join(e.file.filepath for e in exp.export_chain)}")
```
@@ -3220,7 +3218,7 @@ for from_file, exports in file_exports.items():
When managing exports, consider the impact on your module's public API. Not all symbols that can be exported should be exported.
-
+
---
title: "Inheritable Behaviors"
@@ -3710,9 +3708,9 @@ If `A` depends on `B`, then `B` is used by `A`. This relationship is tracked in
flowchart LR
B(BaseClass)
-
-
-
+
+
+
A(MyClass)
B ---| used by |A
A ---|depends on |B
@@ -3881,7 +3879,7 @@ class A:
def method_a(self): pass
class B(A):
- def method_b(self):
+ def method_b(self):
self.method_a()
class C(B):
@@ -4771,7 +4769,7 @@ for attr in class_def.attributes:
# Each attribute has an assignment property
attr_type = attr.assignment.type # -> TypeAnnotation
print(f"{attr.name}: {attr_type.source}") # e.g. "x: int"
-
+
# Set attribute type
attr.assignment.set_type("int")
@@ -4788,7 +4786,7 @@ Union types ([UnionType](/api-reference/core/UnionType)) can be manipulated as c
```python
# Get union type
-union_type = function.return_type # -> A | B
+union_type = function.return_type # -> A | B
print(union_type.symbols) # ["A", "B"]
# Add/remove options
@@ -5639,13 +5637,13 @@ Here's an example of using flags during code analysis:
```python
def analyze_codebase(codebase):
- for function in codebase.functions:
+ for function in codebase.functions:
# Check documentation
if not function.docstring:
function.flag(
message="Missing docstring",
)
-
+
# Check error handling
if function.is_async and not function.has_try_catch:
function.flag(
@@ -6355,7 +6353,7 @@ Explore our tutorials to learn how to use Codegen for various code transformatio
>
Update API calls, handle breaking changes, and manage bulk updates across your codebase.
-
Convert Flask applications to FastAPI, updating routes and dependencies.
-
Migrate Python 2 code to Python 3, updating syntax and modernizing APIs.
@@ -6388,9 +6386,9 @@ Explore our tutorials to learn how to use Codegen for various code transformatio
>
Restructure files, enforce naming conventions, and improve project layout.
-
Split large files, extract shared logic, and manage dependencies.
@@ -6488,7 +6486,7 @@ The agent has access to powerful code viewing and manipulation tools powered by
- `CreateFileTool`: Create new files
- `DeleteFileTool`: Delete files
- `RenameFileTool`: Rename files
-- `EditFileTool`: Edit files
+- `EditFileTool`: Edit files
@@ -6995,7 +6993,7 @@ Be explicit about the changes, produce a short summary, and point out possible i
Focus on facts and technical details, using code snippets where helpful.
"""
result = agent.run(prompt)
-
+
# Clean up the temporary comment
comment.delete()
```
@@ -7176,21 +7174,21 @@ def research(repo_name: Optional[str] = None, query: Optional[str] = None):
"""Start a code research session."""
# Initialize codebase
codebase = initialize_codebase(repo_name)
-
+
# Create and run the agent
agent = create_research_agent(codebase)
-
+
# Main research loop
while True:
if not query:
query = Prompt.ask("[bold cyan]Research query[/bold cyan]")
-
+
result = agent.invoke(
{"input": query},
config={"configurable": {"thread_id": 1}}
)
console.print(Markdown(result["messages"][-1].content))
-
+
query = None # Clear for next iteration
```
@@ -7238,7 +7236,7 @@ class CustomAnalysisTool(BaseTool):
"""Custom tool for specialized code analysis."""
name = "custom_analysis"
description = "Performs specialized code analysis"
-
+
def _run(self, query: str) -> str:
# Custom analysis logic
return results
@@ -7516,7 +7514,7 @@ from codegen import Codebase
# Initialize codebase
codebase = Codebase("path/to/posthog/")
-# Create a directed graph for representing call relationships
+# Create a directed graph for representing call relationships
G = nx.DiGraph()
# Configuration flags
@@ -7538,7 +7536,7 @@ We'll create a function that will recursively traverse the call trace of a funct
```python
def create_downstream_call_trace(src_func: Function, depth: int = 0):
"""Creates call graph by recursively traversing function calls
-
+
Args:
src_func (Function): Starting function for call graph
depth (int): Current recursion depth
@@ -7546,7 +7544,7 @@ def create_downstream_call_trace(src_func: Function, depth: int = 0):
# Prevent infinite recursion
if MAX_DEPTH <= depth:
return
-
+
# External modules are not functions
if isinstance(src_func, ExternalModule):
return
@@ -7556,12 +7554,12 @@ def create_downstream_call_trace(src_func: Function, depth: int = 0):
# Skip self-recursive calls
if call.name == src_func.name:
continue
-
+
# Get called function definition
func = call.function_definition
if not func:
continue
-
+
# Apply configured filters
if isinstance(func, ExternalModule) and IGNORE_EXTERNAL_MODULE_CALLS:
continue
@@ -7575,7 +7573,7 @@ def create_downstream_call_trace(src_func: Function, depth: int = 0):
func_name = f"{func.parent_class.name}.{func.name}" if func.is_method else func.name
# Add node and edge with metadata
- G.add_node(func, name=func_name,
+ G.add_node(func, name=func_name,
color=COLOR_PALETTE.get(func.__class__.__name__))
G.add_edge(src_func, func, **generate_edge_meta(call))
@@ -7590,10 +7588,10 @@ We can enrich our edges with metadata about the function calls:
```python
def generate_edge_meta(call: FunctionCall) -> dict:
"""Generate metadata for call graph edges
-
+
Args:
call (FunctionCall): Function call information
-
+
Returns:
dict: Edge metadata including name and location
"""
@@ -7612,8 +7610,8 @@ Finally, we can visualize our call graph starting from a specific function:
target_class = codebase.get_class('SharingConfigurationViewSet')
target_method = target_class.get_method('patch')
-# Add root node
-G.add_node(target_method,
+# Add root node
+G.add_node(target_method,
name=f"{target_class.name}.{target_method.name}",
color=COLOR_PALETTE["StartFunction"])
@@ -7663,7 +7661,7 @@ The core function for building our dependency graph:
```python
def create_dependencies_visualization(symbol: Symbol, depth: int = 0):
"""Creates visualization of symbol dependencies
-
+
Args:
symbol (Symbol): Starting symbol to analyze
depth (int): Current recursion depth
@@ -7671,11 +7669,11 @@ def create_dependencies_visualization(symbol: Symbol, depth: int = 0):
# Prevent excessive recursion
if depth >= MAX_DEPTH:
return
-
+
# Process each dependency
for dep in symbol.dependencies:
dep_symbol = None
-
+
# Handle different dependency types
if isinstance(dep, Symbol):
# Direct symbol reference
@@ -7686,13 +7684,13 @@ def create_dependencies_visualization(symbol: Symbol, depth: int = 0):
if dep_symbol:
# Add node with appropriate styling
- G.add_node(dep_symbol,
- color=COLOR_PALETTE.get(dep_symbol.__class__.__name__,
+ G.add_node(dep_symbol,
+ color=COLOR_PALETTE.get(dep_symbol.__class__.__name__,
"#f694ff"))
-
+
# Add dependency relationship
G.add_edge(symbol, dep_symbol)
-
+
# Recurse unless it's a class (avoid complexity)
if not isinstance(dep_symbol, PyClass):
create_dependencies_visualization(dep_symbol, depth + 1)
@@ -7704,7 +7702,7 @@ Finally, we can visualize our dependency graph starting from a specific symbol:
# Get target symbol
target_func = codebase.get_function("get_query_runner")
-# Add root node
+# Add root node
G.add_node(target_func, color=COLOR_PALETTE["StartFunction"])
# Generate dependency graph
@@ -7747,16 +7745,16 @@ HTTP_METHODS = ["get", "put", "patch", "post", "head", "delete"]
def generate_edge_meta(usage: Usage) -> dict:
"""Generate metadata for graph edges
-
+
Args:
usage (Usage): Usage relationship information
-
+
Returns:
dict: Edge metadata including name and location
"""
return {
"name": usage.match.source,
- "file_path": usage.match.filepath,
+ "file_path": usage.match.filepath,
"start_point": usage.match.start_point,
"end_point": usage.match.end_point,
"symbol_name": usage.match.__class__.__name__
@@ -7764,10 +7762,10 @@ def generate_edge_meta(usage: Usage) -> dict:
def is_http_method(symbol: PySymbol) -> bool:
"""Check if a symbol is an HTTP endpoint method
-
+
Args:
symbol (PySymbol): Symbol to check
-
+
Returns:
bool: True if symbol is an HTTP method
"""
@@ -7781,7 +7779,7 @@ The main function for creating our blast radius visualization:
```python
def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0):
"""Create visualization of symbol usage relationships
-
+
Args:
symbol (PySymbol): Starting symbol to analyze
depth (int): Current recursion depth
@@ -7789,11 +7787,11 @@ def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0):
# Prevent excessive recursion
if depth >= MAX_DEPTH:
return
-
+
# Process each usage of the symbol
for usage in symbol.usages:
usage_symbol = usage.usage_symbol
-
+
# Determine node color based on type
if is_http_method(usage_symbol):
color = COLOR_PALETTE.get("HTTP_METHOD")
@@ -7803,7 +7801,7 @@ def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0):
# Add node and edge to graph
G.add_node(usage_symbol, color=color)
G.add_edge(symbol, usage_symbol, **generate_edge_meta(usage))
-
+
# Recursively process usage symbol
create_blast_radius_visualization(usage_symbol, depth + 1)
```
@@ -7954,7 +7952,7 @@ for call in old_api.call_sites:
f"data={call.get_arg_by_parameter_name('input').value}",
f"timeout={call.get_arg_by_parameter_name('wait').value}"
]
-
+
# Replace the old call with the new API
call.replace(f"new_process_data({', '.join(args)})")
```
@@ -7968,10 +7966,10 @@ When updating chained method calls, like database queries or builder patterns:
for execute_call in codebase.function_calls:
if execute_call.name != "execute":
continue
-
+
# Get the full chain
chain = execute_call.call_chain
-
+
# Example: Add .timeout() before .execute()
if "timeout" not in {call.name for call in chain}:
execute_call.insert_before("timeout(30)")
@@ -7990,45 +7988,45 @@ Here's a comprehensive example:
```python
def migrate_api_v1_to_v2(codebase):
old_api = codebase.get_function("create_user_v1")
-
+
# Document all existing call patterns
call_patterns = {}
for call in old_api.call_sites:
args = [arg.source for arg in call.args]
pattern = ", ".join(args)
call_patterns[pattern] = call_patterns.get(pattern, 0) + 1
-
+
print("Found call patterns:")
for pattern, count in call_patterns.items():
print(f" {pattern}: {count} occurrences")
-
+
# Create new API version
new_api = old_api.copy()
new_api.rename("create_user_v2")
-
+
# Update parameter types
new_api.get_parameter("email").type = "EmailStr"
new_api.get_parameter("role").type = "UserRole"
-
+
# Add new required parameters
new_api.add_parameter("tenant_id: UUID")
-
+
# Update all call sites
for call in old_api.call_sites:
# Get current arguments
email_arg = call.get_arg_by_parameter_name("email")
role_arg = call.get_arg_by_parameter_name("role")
-
+
# Build new argument list with type conversions
new_args = [
f"email=EmailStr({email_arg.value})",
f"role=UserRole({role_arg.value})",
"tenant_id=get_current_tenant_id()"
]
-
+
# Replace old call with new version
call.replace(f"create_user_v2({', '.join(new_args)})")
-
+
# Add deprecation notice to old version
old_api.add_decorator('@deprecated("Use create_user_v2 instead")')
@@ -8050,10 +8048,10 @@ migrate_api_v1_to_v2(codebase)
```python
# First update parameter names
param.rename("new_name")
-
+
# Then update types
param.type = "new_type"
-
+
# Finally update call sites
for call in api.call_sites:
# ... update calls
@@ -8063,7 +8061,7 @@ migrate_api_v1_to_v2(codebase)
```python
# Add new parameter with default
api.add_parameter("new_param: str = None")
-
+
# Later make it required
api.get_parameter("new_param").remove_default()
```
@@ -8078,7 +8076,7 @@ migrate_api_v1_to_v2(codebase)
Remember to test thoroughly after making bulk changes to APIs. While Codegen ensures syntactic correctness, you'll want to verify the semantic correctness of the changes.
-
+
---
title: "Organizing Your Codebase"
@@ -8642,16 +8640,16 @@ from collections import defaultdict
# Create a graph of file dependencies
def create_dependency_graph():
G = nx.DiGraph()
-
+
for file in codebase.files:
# Add node for this file
G.add_node(file.filepath)
-
+
# Add edges for each import
for imp in file.imports:
if imp.from_file: # Skip external imports
G.add_edge(file.filepath, imp.from_file.filepath)
-
+
return G
# Create and analyze the graph
@@ -8680,18 +8678,18 @@ def break_circular_dependency(cycle):
# Get the first two files in the cycle
file1 = codebase.get_file(cycle[0])
file2 = codebase.get_file(cycle[1])
-
+
# Create a shared module for common code
shared_dir = "shared"
if not codebase.has_directory(shared_dir):
codebase.create_directory(shared_dir)
-
+
# Find symbols used by both files
shared_symbols = []
for symbol in file1.symbols:
if any(usage.file == file2 for usage in symbol.usages):
shared_symbols.append(symbol)
-
+
# Move shared symbols to a new file
if shared_symbols:
shared_file = codebase.create_file(f"{shared_dir}/shared_types.py")
@@ -8713,7 +8711,7 @@ def organize_file_imports(file):
std_lib_imports = []
third_party_imports = []
local_imports = []
-
+
for imp in file.imports:
if imp.is_standard_library:
std_lib_imports.append(imp)
@@ -8721,26 +8719,26 @@ def organize_file_imports(file):
third_party_imports.append(imp)
else:
local_imports.append(imp)
-
+
# Sort each group
for group in [std_lib_imports, third_party_imports, local_imports]:
group.sort(key=lambda x: x.module_name)
-
+
# Remove all existing imports
for imp in file.imports:
imp.remove()
-
+
# Add imports back in organized groups
if std_lib_imports:
for imp in std_lib_imports:
file.add_import(imp.source)
file.insert_after_imports("") # Add newline
-
+
if third_party_imports:
for imp in third_party_imports:
file.add_import(imp.source)
file.insert_after_imports("") # Add newline
-
+
if local_imports:
for imp in local_imports:
file.add_import(imp.source)
@@ -8759,22 +8757,22 @@ from collections import defaultdict
def analyze_module_coupling():
coupling_scores = defaultdict(int)
-
+
for file in codebase.files:
# Count unique files imported from
imported_files = {imp.from_file for imp in file.imports if imp.from_file}
coupling_scores[file.filepath] = len(imported_files)
-
+
# Count files that import this file
- importing_files = {usage.file for symbol in file.symbols
+ importing_files = {usage.file for symbol in file.symbols
for usage in symbol.usages if usage.file != file}
coupling_scores[file.filepath] += len(importing_files)
-
+
# Sort by coupling score
- sorted_files = sorted(coupling_scores.items(),
- key=lambda x: x[1],
+ sorted_files = sorted(coupling_scores.items(),
+ key=lambda x: x[1],
reverse=True)
-
+
print("\nš Module Coupling Analysis:")
print("\nMost coupled files:")
for filepath, score in sorted_files[:5]:
@@ -8792,9 +8790,9 @@ def extract_shared_code(file, min_usages=3):
# Find symbols used by multiple files
for symbol in file.symbols:
# Get unique files using this symbol
- using_files = {usage.file for usage in symbol.usages
+ using_files = {usage.file for usage in symbol.usages
if usage.file != file}
-
+
if len(using_files) >= min_usages:
# Create appropriate shared module
module_name = determine_shared_module(symbol)
@@ -8802,7 +8800,7 @@ def extract_shared_code(file, min_usages=3):
shared_file = codebase.create_file(f"shared/{module_name}.py")
else:
shared_file = codebase.get_file(f"shared/{module_name}.py")
-
+
# Move symbol to shared module
symbol.move_to_file(shared_file, strategy="update_all_imports")
@@ -8856,7 +8854,7 @@ if feature_flag_class:
# Initialize usage count for all attributes
for attr in feature_flag_class.attributes:
feature_flag_usage[attr.name] = 0
-
+
# Get all usages of the FeatureFlag class
for usage in feature_flag_class.usages:
usage_source = usage.usage_symbol.source if hasattr(usage, 'usage_symbol') else str(usage)
@@ -9601,7 +9599,7 @@ Let's break down how this works:
if export.is_reexport() and export.is_default_export():
print(f" š Converting default export '{export.name}'")
```
-
+
The code identifies default exports by checking:
1. If it's a re-export (`is_reexport()`)
2. If it's a default export (`is_default_export()`)
@@ -9709,7 +9707,7 @@ for file in codebase.files:
print(f"⨠Fixed exports in {target_file.filepath}")
-```
+```
---
title: "Creating Documentation"
@@ -9798,11 +9796,11 @@ for directory in codebase.directories:
# Skip test, sql and alembic directories
if any(x in directory.path.lower() for x in ['test', 'sql', 'alembic']):
continue
-
+
# Get undecorated functions
funcs = [f for f in directory.functions if not f.is_decorated]
total = len(funcs)
-
+
# Only analyze dirs with >10 functions
if total > 10:
documented = sum(1 for f in funcs if f.docstring)
@@ -9817,12 +9815,12 @@ for directory in codebase.directories:
if dir_stats:
lowest_dir = min(dir_stats.items(), key=lambda x: x[1]['coverage'])
path, stats = lowest_dir
-
+
print(f"š Lowest coverage directory: '{path}'")
print(f" ⢠Total functions: {stats['total']}")
print(f" ⢠Documented: {stats['documented']}")
print(f" ⢠Coverage: {stats['coverage']:.1f}%")
-
+
# Print all directory stats for comparison
print("\nš All directory coverage rates:")
for path, stats in sorted(dir_stats.items(), key=lambda x: x[1]['coverage']):
@@ -10610,7 +10608,7 @@ iconType: "solid"
-Import loops occur when two or more Python modules depend on each other, creating a circular dependency. While some import cycles can be harmless, others can lead to runtime errors and make code harder to maintain.
+Import loops occur when two or more Python modules depend on each other, creating a circular dependency. While some import cycles can be harmless, others can lead to runtime errors and make code harder to maintain.
In this tutorial, we'll explore how to identify and fix problematic import cycles using Codegen.
@@ -11507,10 +11505,10 @@ Match (s: Func )-[r: CALLS]-> (e:Func) RETURN s, e LIMIT 10
```cypher
Match path = (:(Method|Func)) -[:CALLS*5..10]-> (:(Method|Func))
-Return path
+Return path
LIMIT 20
```
-
\ No newline at end of file
+
diff --git a/src/codegen/visualizations/visualization_manager.py b/src/codegen/visualizations/visualization_manager.py
index 7be3cf8fb..114beb226 100644
--- a/src/codegen/visualizations/visualization_manager.py
+++ b/src/codegen/visualizations/visualization_manager.py
@@ -22,7 +22,7 @@ def __init__(
@property
def viz_path(self) -> str:
- return os.path.join(self.op.base_dir, "codegen-graphviz")
+ return os.path.join(self.op.repo_config.base_dir, "codegen-graphviz")
@property
def viz_file_path(self) -> str:
diff --git a/tests/integration/codegen/git/conftest.py b/tests/integration/codegen/git/conftest.py
index e3846935d..2577e899a 100644
--- a/tests/integration/codegen/git/conftest.py
+++ b/tests/integration/codegen/git/conftest.py
@@ -2,7 +2,9 @@
import pytest
-from codegen.git.schemas.repo_config import RepoConfig
+from codegen.configs.models.repository import RepositoryConfig
+from codegen.git.repo_operator.repo_operator import RepoOperator
+from codegen.git.schemas.enums import SetupOption
@pytest.fixture()
@@ -16,9 +18,15 @@ def mock_config():
@pytest.fixture()
def repo_config(tmpdir):
- repo_config = RepoConfig(
+ repo_config = RepositoryConfig(
name="Kevin-s-Adventure-Game",
- full_name="codegen-sh/Kevin-s-Adventure-Game",
+ owner="codegen-sh",
base_dir=str(tmpdir),
)
yield repo_config
+
+
+@pytest.fixture
+def op(repo_config, request):
+ op = RepoOperator(repo_config=repo_config, shallow=request.param if hasattr(request, "param") else True, bot_commit=False, setup_option=SetupOption.PULL_OR_CLONE)
+ yield op
diff --git a/tests/integration/codegen/git/repo_operator/test_remote_repo_operator.py b/tests/integration/codegen/git/repo_operator/test_remote_repo_operator.py
index 76d7a6292..5bbf8ef87 100644
--- a/tests/integration/codegen/git/repo_operator/test_remote_repo_operator.py
+++ b/tests/integration/codegen/git/repo_operator/test_remote_repo_operator.py
@@ -4,18 +4,12 @@
from github.MainClass import Github
from codegen.git.repo_operator.repo_operator import RepoOperator
-from codegen.git.schemas.enums import CheckoutResult, SetupOption
+from codegen.git.schemas.enums import CheckoutResult
from codegen.git.utils.file_utils import create_files
shallow_options = [True, False]
-@pytest.fixture
-def op(repo_config, request):
- op = RepoOperator(repo_config, shallow=request.param if hasattr(request, "param") else True, bot_commit=False, setup_option=SetupOption.PULL_OR_CLONE)
- yield op
-
-
@pytest.mark.parametrize("op", shallow_options, ids=lambda x: f"shallow={x}", indirect=True)
@patch("codegen.git.clients.github_client.Github")
def test_checkout_branch(mock_git_client, op: RepoOperator):
diff --git a/tests/integration/codegen/runner/conftest.py b/tests/integration/codegen/runner/conftest.py
index 5f16fa4f4..0d7b20d89 100644
--- a/tests/integration/codegen/runner/conftest.py
+++ b/tests/integration/codegen/runner/conftest.py
@@ -3,37 +3,35 @@
import pytest
+from codegen.configs.models.repository import RepositoryConfig
from codegen.git.clients.git_repo_client import GitRepoClient
from codegen.git.repo_operator.repo_operator import RepoOperator
from codegen.git.schemas.enums import SetupOption
-from codegen.git.schemas.repo_config import RepoConfig
from codegen.runner.clients.codebase_client import CodebaseClient
-from codegen.shared.enums.programming_language import ProgrammingLanguage
from codegen.shared.network.port import get_free_port
@pytest.fixture()
-def repo_config(tmpdir) -> Generator[RepoConfig, None, None]:
- yield RepoConfig(
- name="Kevin-s-Adventure-Game",
- full_name="codegen-sh/Kevin-s-Adventure-Game",
- language=ProgrammingLanguage.PYTHON,
- base_dir=str(tmpdir),
+def repo_config(tmpdir) -> Generator[RepositoryConfig, None, None]:
+ yield RepositoryConfig(
+ path=str(tmpdir / "Kevin-s-Adventure-Game"),
+ owner="codegen-sh",
+ language="PYTHON",
)
@pytest.fixture
-def op(repo_config: RepoConfig) -> Generator[RepoOperator, None, None]:
+def op(repo_config: RepositoryConfig) -> Generator[RepoOperator, None, None]:
yield RepoOperator(repo_config=repo_config, setup_option=SetupOption.PULL_OR_CLONE)
@pytest.fixture
-def git_repo_client(op: RepoOperator, repo_config: RepoConfig) -> Generator[GitRepoClient, None, None]:
- yield GitRepoClient(repo_config=repo_config, access_token=op.access_token)
+def git_repo_client(op: RepoOperator, repo_config: RepositoryConfig) -> Generator[GitRepoClient, None, None]:
+ yield GitRepoClient(repo_full_name=repo_config.full_name, access_token=op.access_token)
@pytest.fixture
-def codebase_client(repo_config: RepoConfig) -> Generator[CodebaseClient, None, None]:
+def codebase_client(repo_config: RepositoryConfig) -> Generator[CodebaseClient, None, None]:
sb_client = CodebaseClient(repo_config=repo_config, port=get_free_port())
sb_client.runner = Mock()
yield sb_client
diff --git a/tests/shared/skills/verify_skill_output.py b/tests/shared/skills/verify_skill_output.py
index 00e30c3e0..32eca65bb 100644
--- a/tests/shared/skills/verify_skill_output.py
+++ b/tests/shared/skills/verify_skill_output.py
@@ -50,7 +50,7 @@ def verify_skill_output(codebase: Codebase, skill, test_case, get_diff, snapshot
elif test_case.graph:
# I want to save test_case graph locally and compare it with the graph generated by the skill
# Read the generated graph JSON
- graph_json = open(f"{codebase.op.base_dir}/codegen-graphviz/graph.json").read()
+ graph_json = open(f"{codebase.op.repo_config.base_dir}/codegen-graphviz/graph.json").read()
# Compare with snapshot
snapshot.assert_match(graph_json, f"{skill.name}_{test_case.name or 'unnamed'}.json")
diff --git a/tests/unit/codegen/runner/sandbox/conftest.py b/tests/unit/codegen/runner/sandbox/conftest.py
index efc4913cf..a9a5c98cc 100644
--- a/tests/unit/codegen/runner/sandbox/conftest.py
+++ b/tests/unit/codegen/runner/sandbox/conftest.py
@@ -13,7 +13,7 @@
@pytest.fixture
def codebase(tmpdir) -> Codebase:
- op = RepoOperator.create_from_files(repo_path=f"{tmpdir}/test-repo", files={"test.py": "a = 1"}, bot_commit=True)
+ op = RepoOperator.create_from_files(repo_path=f"{tmpdir}/test-repo", files={"test.py": "a = 1"}, bot_commit=True, programming_language=ProgrammingLanguage.PYTHON)
projects = [ProjectConfig(repo_operator=op, programming_language=ProgrammingLanguage.PYTHON)]
codebase = Codebase(projects=projects)
return codebase