diff --git a/docs/mint.json b/docs/mint.json index 0b7d4767d..36f963671 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -16,7 +16,7 @@ "og:locale": "en_US", "og:logo": "https://i.imgur.com/f4OVOqI.png", "article:publisher": "Codegen, Inc.", - "twitter:site": "@codegen", + "twitter:site": "@codegen" }, "favicon": "/favicon.svg", "colors": { diff --git a/src/codegen/git/repo_operator/local_repo_operator.py b/src/codegen/git/repo_operator/local_repo_operator.py index 254b013aa..e16714d0a 100644 --- a/src/codegen/git/repo_operator/local_repo_operator.py +++ b/src/codegen/git/repo_operator/local_repo_operator.py @@ -1,3 +1,4 @@ +import logging import os from functools import cached_property from typing import Self, override @@ -6,13 +7,19 @@ from git import Remote from git import Repo as GitCLI from git.remote import PushInfoList +from github import Github +from github.PullRequest import PullRequest +from codegen.git.clients.git_repo_client import GitRepoClient from codegen.git.repo_operator.repo_operator import RepoOperator from codegen.git.schemas.enums import FetchResult +from codegen.git.schemas.github import GithubType from codegen.git.schemas.repo_config import BaseRepoConfig from codegen.git.utils.clone_url import url_to_github from codegen.git.utils.file_utils import create_files +logger = logging.getLogger(__name__) + class OperatorIsLocal(Exception): """Error raised while trying to do a remote operation on a local operator""" @@ -29,20 +36,54 @@ class LocalRepoOperator(RepoOperator): _repo_name: str _git_cli: GitCLI repo_config: BaseRepoConfig + _github_api_key: str | None + _remote_git_repo: GitRepoClient | None = None def __init__( self, repo_path: str, # full path to the repo + github_api_key: str | None = None, repo_config: BaseRepoConfig | None = None, bot_commit: bool = False, ) -> None: self._repo_path = repo_path self._repo_name = os.path.basename(repo_path) + self._github_api_key = github_api_key + self.github_type = GithubType.Github + 
self._remote_git_repo = None os.makedirs(self.repo_path, exist_ok=True) GitCLI.init(self.repo_path) repo_config = repo_config or BaseRepoConfig() super().__init__(repo_config, self.repo_path, bot_commit) + #################################################################################################################### + # PROPERTIES + #################################################################################################################### + + @property + def remote_git_repo(self) -> GitRepoClient: + if self._remote_git_repo is None: + if not self._github_api_key: + return None + + if not (base_url := self.base_url): + msg = "Could not determine GitHub URL from remotes" + raise ValueError(msg) + + # Extract owner and repo from the base URL + # Format: https://github.com/owner/repo + parts = base_url.split("/") + if len(parts) < 2: + msg = f"Invalid GitHub URL format: {base_url}" + raise ValueError(msg) + + owner = parts[-4] + repo = parts[-3] + + github = Github(self._github_api_key) + self._remote_git_repo = github.get_repo(f"{owner}/{repo}") + return self._remote_git_repo + #################################################################################################################### # CLASS METHODS #################################################################################################################### @@ -70,9 +111,16 @@ def create_from_files(cls, repo_path: str, files: dict[str, str], bot_commit: bo return op @classmethod - def create_from_commit(cls, repo_path: str, commit: str, url: str) -> Self: - """Do a shallow checkout of a particular commit to get a repository from a given remote URL.""" - op = cls(repo_config=BaseRepoConfig(), repo_path=repo_path, bot_commit=False) + def create_from_commit(cls, repo_path: str, commit: str, url: str, github_api_key: str | None = None) -> Self: + """Do a shallow checkout of a particular commit to get a repository from a given remote URL. 
+ + Args: + repo_path (str): Path where the repo should be cloned + commit (str): The commit hash to checkout + url (str): Git URL of the repository + github_api_key (str | None): Optional GitHub API key for operations that need GitHub access + """ + op = cls(repo_path=repo_path, bot_commit=False, github_api_key=github_api_key) op.discard_changes() if op.get_active_branch_or_commit() != commit: op.create_remote("origin", url) @@ -81,12 +129,13 @@ def create_from_commit(cls, repo_path: str, commit: str, url: str) -> Self: return op @classmethod - def create_from_repo(cls, repo_path: str, url: str) -> Self: + def create_from_repo(cls, repo_path: str, url: str, github_api_key: str | None = None) -> Self: """Create a fresh clone of a repository or use existing one if up to date. Args: repo_path (str): Path where the repo should be cloned url (str): Git URL of the repository + github_api_key (str | None): Optional GitHub API key for operations that need GitHub access """ # Check if repo already exists if os.path.exists(repo_path): @@ -102,7 +151,7 @@ def create_from_repo(cls, repo_path: str, url: str) -> Self: remote_head = git_cli.remotes.origin.refs[git_cli.active_branch.name].commit # If up to date, use existing repo if local_head.hexsha == remote_head.hexsha: - return cls(repo_config=BaseRepoConfig(), repo_path=repo_path, bot_commit=False) + return cls(repo_path=repo_path, bot_commit=False, github_api_key=github_api_key) except Exception: # If any git operations fail, fallback to fresh clone pass @@ -113,13 +162,13 @@ def create_from_repo(cls, repo_path: str, url: str) -> Self: shutil.rmtree(repo_path) - # Do a fresh clone with depth=1 to get latest commit + # Clone the repository GitCLI.clone_from(url=url, to_path=repo_path, depth=1) # Initialize with the cloned repo git_cli = GitCLI(repo_path) - return cls(repo_config=BaseRepoConfig(), repo_path=repo_path, bot_commit=False) + return cls(repo_path=repo_path, bot_commit=False, github_api_key=github_api_key) 
#################################################################################################################### # PROPERTIES @@ -153,3 +202,26 @@ def pull_repo(self) -> None: def fetch_remote(self, remote_name: str = "origin", refspec: str | None = None, force: bool = True) -> FetchResult: raise OperatorIsLocal() + + def get_pull_request(self, pr_number: int) -> PullRequest | None: + """Get a GitHub Pull Request object for the given PR number. + + Args: + pr_number (int): The PR number to fetch + + Returns: + PullRequest | None: The PyGitHub PullRequest object if found, None otherwise + + Note: + This requires a GitHub API key to be set when creating the LocalRepoOperator + """ + try: + # Create GitHub client and get the PR + repo = self.remote_git_repo + if repo is None: + logger.warning("GitHub API key is required to fetch pull requests") + return None + return repo.get_pull(pr_number) + except Exception as e: + logger.warning(f"Failed to get PR {pr_number}: {e!s}") + return None diff --git a/src/codegen/git/utils/pr_review.py b/src/codegen/git/utils/pr_review.py new file mode 100644 index 000000000..271a594f5 --- /dev/null +++ b/src/codegen/git/utils/pr_review.py @@ -0,0 +1,129 @@ +from typing import TYPE_CHECKING + +import requests +from github import Repository +from github.PullRequest import PullRequest +from unidiff import PatchSet + +from codegen.git.models.pull_request_context import PullRequestContext +from codegen.git.repo_operator.local_repo_operator import LocalRepoOperator +from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator + +if TYPE_CHECKING: + from codegen.sdk.core.codebase import Codebase, Editable, File, Symbol + + +def get_merge_base(git_repo_client: Repository, pull: PullRequest | PullRequestContext) -> str: + """Gets the merge base of a pull request using a remote GitHub API client. + + Args: + git_repo_client (GitRepoClient): The GitHub repository client. + pull (PullRequest): The pull request object. 
def get_merge_base(git_repo_client: Repository, pull: PullRequest | PullRequestContext) -> str:
    """Get the merge base of a pull request via the remote GitHub API.

    Args:
        git_repo_client: Repository handle exposing ``compare(base, head)``.
        pull: Pull request whose ``base.sha`` and ``head.sha`` are compared.

    Returns:
        str: The SHA of the merge base commit reported by the comparison.
    """
    comparison = git_repo_client.compare(pull.base.sha, pull.head.sha)
    return comparison.merge_base_commit.sha


def get_file_to_changed_ranges(pull_patch_set: PatchSet) -> dict[str, list[range]]:
    """Map each surviving file in a patch to the line ranges its hunks touch.

    Files removed by the patch are skipped. Each hunk contributes the half-open
    range ``[target_start, target_start + target_length)``, i.e. line numbers on
    the post-change side of the diff.

    Args:
        pull_patch_set: Iterable of patched files, each iterable over its hunks.

    Returns:
        Mapping from file path to one ``range`` per hunk, in hunk order.
    """
    return {
        patched_file.path: [range(hunk.target_start, hunk.target_start + hunk.target_length) for hunk in patched_file]
        for patched_file in pull_patch_set
        if not patched_file.is_removed_file
    }


def get_pull_patch_set(
    op: LocalRepoOperator | RemoteRepoOperator,
    pull: PullRequest | PullRequestContext,
    timeout: float = 30.0,
) -> PatchSet:
    """Fetch a pull request's diff and parse it into a ``PatchSet``.

    Args:
        op: Repo operator; it must expose a truthy ``remote_git_repo``.
        pull: Pull request whose ``raw_data["diff_url"]`` is downloaded when
            present; otherwise ``pull.get_patch()`` supplies the diff text.
        timeout: Seconds to wait for the diff download before giving up, so a
            stalled connection cannot hang the caller indefinitely.

    Raises:
        ValueError: If the operator has no GitHub client.
        requests.HTTPError: If downloading the diff returns an error status.
    """
    if not op.remote_git_repo:
        msg = "GitHub API client is required to get PR diffs"
        raise ValueError(msg)

    diff_url = pull.raw_data.get("diff_url")
    if diff_url:
        # NOTE(review): the request carries no credentials - confirm this works
        # for private repositories.
        response = requests.get(diff_url, timeout=timeout)
        response.raise_for_status()
        diff = response.text
    else:
        # No diff_url available: ask the pull object for the patch directly.
        diff = pull.get_patch()

    return PatchSet(diff)


def to_1_indexed(zero_indexed_range: range) -> range:
    """Shift a range up by one on both ends (n-indexed -> n+1-indexed).

    Primarily used to convert 0-indexed ranges to 1-indexed ones.
    """
    return range(zero_indexed_range.start + 1, zero_indexed_range.stop + 1)


def overlaps(range1: range, range2: range) -> bool:
    """Return True if the two half-open ranges share at least one value."""
    return max(range1.start, range2.start) < min(range1.stop, range2.stop)


class CodegenPR:
    """Wrapper around PRs - enables codemods to interact with them"""

    _gh_pr: PullRequest
    _codebase: "Codebase"
    _op: LocalRepoOperator | RemoteRepoOperator

    # =====[ Computed ]=====
    # None means "not computed yet"; an empty dict is a valid computed result.
    _modified_file_ranges: dict[str, list[range]] | None = None

    def __init__(self, op: LocalRepoOperator | RemoteRepoOperator, codebase: "Codebase", pr: PullRequest):
        self._op = op
        self._gh_pr = pr
        self._codebase = codebase
        self._modified_file_ranges = None

    @property
    def modified_file_ranges(self) -> dict[str, list[range]]:
        """Files and the ranges within that are modified (computed once, then cached)"""
        # Compare against None rather than truthiness: a PR that changes no lines
        # yields {} and must not trigger a new download on every access.
        if self._modified_file_ranges is None:
            pull_patch_set = get_pull_patch_set(op=self._op, pull=self._gh_pr)
            self._modified_file_ranges = get_file_to_changed_ranges(pull_patch_set)
        return self._modified_file_ranges

    @property
    def modified_files(self) -> list["File | None"]:
        """Codebase files for every modified path; an entry is None when the lookup finds no file."""
        return [self._codebase.get_file(filename, optional=True) for filename in self.modified_file_ranges]

    def is_modified(self, editable: "Editable") -> bool:
        """Returns True if the Editable's range contains any modified lines"""
        # Go through the property so the ranges are computed on demand; reading the
        # private attribute directly fails when nothing has populated it yet.
        changed_ranges = self.modified_file_ranges.get(editable.filepath, [])
        # The editable's line_range is shifted by one before being compared with
        # the diff's hunk ranges.
        symbol_range = to_1_indexed(editable.line_range)
        return any(overlaps(symbol_range, changed_range) for changed_range in changed_ranges)

    @property
    def modified_symbols(self) -> list["Symbol"]:
        """Symbols of modified source files whose line range overlaps a changed range."""
        # Import SourceFile locally to avoid circular dependencies
        from codegen.sdk.core.file import SourceFile

        all_modified = []
        for file in self.modified_files:
            if file is None:
                print("Warning: File is None")
                continue
            if not isinstance(file, SourceFile):
                continue
            all_modified.extend(symbol for symbol in file.symbols if self.is_modified(symbol))
        return all_modified
Defaults to HEAD shallow (bool): Whether to do a shallow clone. Defaults to True programming_language (ProgrammingLanguage | None): The programming language of the repo. Defaults to None. + config (CodebaseConfig): Configuration for the codebase. Defaults to DefaultConfig. Returns: Codebase: A Codebase instance initialized with the cloned repository @@ -1198,26 +1209,28 @@ def from_repo(cls, repo_name: str, *, tmp_dir: str | None = None, commit: str | # Use LocalRepoOperator to fetch the repository logger.info("Cloning repository...") if commit is None: - repo_operator = LocalRepoOperator.create_from_repo(repo_path=repo_path, url=repo_url) + repo_operator = LocalRepoOperator.create_from_repo(repo_path=repo_path, url=repo_url, github_api_key=config.secrets.github_api_key if config.secrets else None) else: # Ensure the operator can handle remote operations - repo_operator = LocalRepoOperator.create_from_commit( - repo_path=repo_path, - commit=commit, - url=repo_url, - ) + repo_operator = LocalRepoOperator.create_from_commit(repo_path=repo_path, commit=commit, url=repo_url, github_api_key=config.secrets.github_api_key if config.secrets else None) logger.info("Clone completed successfully") # Initialize and return codebase with proper context logger.info("Initializing Codebase...") project = ProjectConfig.from_repo_operator(repo_operator=repo_operator, programming_language=programming_language) - codebase = Codebase(projects=[project], config=DefaultConfig) + codebase = Codebase(projects=[project], config=config) logger.info("Codebase initialization complete") return codebase except Exception as e: logger.exception(f"Failed to initialize codebase: {e}") raise + def get_modified_symbols_in_pr(self, pr_id: int) -> list[Symbol]: + """Get all modified symbols in a pull request""" + pr = self._op.get_pull_request(pr_id) + cg_pr = CodegenPR(self._op, self, pr) + return cg_pr.modified_symbols + # The last 2 lines of code are added to the runner. 
@dataclass
class Secrets:
    """Optional credentials carried by the codebase configuration.

    Every field defaults to ``None``, meaning the credential is not configured.
    """

    # presumably the OpenAI API key - its use is not visible here; confirm against callers
    openai_key: str | None = None
    # GitHub API key; Codebase.from_repo forwards it to LocalRepoOperator
    github_api_key: str | None = None