diff --git a/src/codegen/runner/sandbox/repo.py b/src/codegen/runner/sandbox/repo.py index 297b86030..b3d06c37f 100644 --- a/src/codegen/runner/sandbox/repo.py +++ b/src/codegen/runner/sandbox/repo.py @@ -1,8 +1,8 @@ import logging -from codegen.git.schemas.enums import FetchResult -from codegen.git.utils.branch_sync import BranchSyncResult, fetch_highside_branch, get_highside_origin +from codegen.git.schemas.github import GithubType from codegen.runner.models.codemod import Codemod +from codegen.runner.utils.branch_sync import get_remote_for_github_type from codegen.sdk.codebase.factory.codebase_factory import CodebaseType logger = logging.getLogger(__name__) @@ -22,9 +22,12 @@ def set_up_base_branch(self, base_branch: str | None) -> None: if self.codebase.op.is_branch_checked_out(base_branch): return - res = self._pull_highside_to_lowside(base_branch) - if res is BranchSyncResult.SUCCESS: - self.codebase.checkout(branch=base_branch, remote=True) + # fetch the base branch from highside (do not checkout yet) + highside_remote = get_remote_for_github_type(op=self.codebase.op, github_type=GithubType.Github) + self.codebase.op.fetch_remote(highside_remote.name, refspec=f"{base_branch}:{base_branch}") + + # checkout the base branch (and possibly sync graph) + self.codebase.checkout(branch=base_branch) def set_up_head_branch(self, head_branch: str, force_push_head_branch: bool): """Set-up head branch by pushing latest highside branch to lowside and fetching the branch (so that it can be checked out later).""" @@ -43,22 +46,9 @@ def set_up_head_branch(self, head_branch: str, force_push_head_branch: bool): if force_push_head_branch: return - res = self._pull_highside_to_lowside(head_branch) - if res is BranchSyncResult.SUCCESS: - self.codebase.op.fetch_remote("origin", refspec=f"{head_branch}:{head_branch}") - - def _pull_highside_to_lowside(self, branch_name: str): - """Grabs the latest highside branch `branch_name` and pushes it to the lowside.""" - # Step 1: checkout branch that tracks highside remote - res = fetch_highside_branch(op=self.codebase.op, branch_name=branch_name) - if res == FetchResult.REFSPEC_NOT_FOUND: - return BranchSyncResult.BRANCH_NOT_FOUND - - # Step 2: push branch up to lowside - logger.info(f"Pushing branch: {branch_name} from highside to lowside w/ force=False ...") - lowside_origin = self.codebase.op.git_cli.remote("origin") - self.codebase.op.push_changes(remote=lowside_origin, refspec=f"{branch_name}:{branch_name}", force=False) - return BranchSyncResult.SUCCESS + # fetch the head branch from highside (do not checkout yet) + highside_remote = get_remote_for_github_type(op=self.codebase.op, github_type=GithubType.Github) + self.codebase.op.fetch_remote(highside_remote.name, refspec=f"{head_branch}:{head_branch}") def reset_branch(self, base_branch: str, head_branch: str) -> None: logger.info(f"Checking out base branch {base_branch} ...") @@ -76,8 +66,8 @@ def push_changes_to_remote(self, codemod: Codemod, head_branch: str, force_push: return False # =====[ Push changes highside ]===== - highside_origin = get_highside_origin(self.codebase.op) - highside_res = self.codebase.op.push_changes(remote=highside_origin, refspec=f"{head_branch}:{head_branch}", force=force_push) + highside_remote = get_remote_for_github_type(op=self.codebase.op, github_type=GithubType.Github) + highside_res = self.codebase.op.push_changes(remote=highside_remote, refspec=f"{head_branch}:{head_branch}", force=force_push) return not any(push_info.flags & push_info.ERROR for push_info in highside_res) # TODO: move bunch of codebase git operations into this class. diff --git a/src/codegen/runner/sandbox/runner.py b/src/codegen/runner/sandbox/runner.py index 13affd9b6..846a839ef 100644 --- a/src/codegen/runner/sandbox/runner.py +++ b/src/codegen/runner/sandbox/runner.py @@ -5,6 +5,7 @@ from git import Commit as GitCommit from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator +from codegen.git.schemas.github import GithubType from codegen.git.schemas.repo_config import RepoConfig from codegen.runner.models.apis import CreateBranchRequest, CreateBranchResponse, GetDiffRequest, GetDiffResponse from codegen.runner.models.configs import get_codebase_config @@ -39,7 +40,7 @@ def __init__( ) -> None: self.container_id = container_id self.repo = repo_config - self.op = RemoteRepoOperator(repo_config, base_dir=repo_config.base_dir) + self.op = RemoteRepoOperator(repo_config, base_dir=repo_config.base_dir, github_type=GithubType.Github) self.commit = self.op.git_cli.head.commit async def warmup(self) -> None: @@ -76,13 +77,13 @@ def reset_runner(self) -> None: self.codebase.checkout(branch=self.codebase.default_branch, create_if_missing=True) @staticmethod - def _set_sentry_tags(epic_id: int, is_customer: bool) -> None: + def _set_sentry_tags(epic_id: int, is_admin: bool) -> None: """Set the sentry tags for a CodemodRun""" sentry_sdk.set_tag("epic_id", epic_id) # To easily get to the epic in the UI - sentry_sdk.set_tag("is_customer", is_customer) # To filter "prod" level errors, ex if customer hits an error vs an admin + sentry_sdk.set_tag("is_admin", is_admin) # To filter "prod" level errors, ex if customer hits an error vs an admin async def get_diff(self, request: GetDiffRequest) -> GetDiffResponse: - self._set_sentry_tags(epic_id=request.codemod.epic_id, is_customer=request.codemod.is_customer) + self._set_sentry_tags(epic_id=request.codemod.epic_id, is_admin=request.codemod.is_admin) custom_scope = {"context": request.codemod.codemod_context} if request.codemod.codemod_context else {} code_to_exec = create_execute_function_from_codeblock(codeblock=request.codemod.user_code, custom_scope=custom_scope) session_options = SessionOptions(max_transactions=request.max_transactions, max_seconds=request.max_seconds) @@ -92,7 +93,7 @@ async def get_diff(self, request: GetDiffRequest) -> GetDiffResponse: return GetDiffResponse(result=res) async def create_branch(self, request: CreateBranchRequest) -> CreateBranchResponse: - self._set_sentry_tags(epic_id=request.codemod.epic_id, is_customer=request.codemod.is_customer) + self._set_sentry_tags(epic_id=request.codemod.epic_id, is_admin=request.codemod.is_admin) custom_scope = {"context": request.codemod.codemod_context} if request.codemod.codemod_context else {} code_to_exec = create_execute_function_from_codeblock(codeblock=request.codemod.user_code, custom_scope=custom_scope) branch_config = request.branch_config diff --git a/src/codegen/runner/utils/branch_sync.py b/src/codegen/runner/utils/branch_sync.py new file mode 100644 index 000000000..798621b87 --- /dev/null +++ b/src/codegen/runner/utils/branch_sync.py @@ -0,0 +1,21 @@ +from git.remote import Remote + +from codegen.git.configs.constants import HIGHSIDE_REMOTE_NAME, LOWSIDE_REMOTE_NAME +from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator +from codegen.git.schemas.github import GithubScope, GithubType +from codegen.git.utils.clone_url import get_authenticated_clone_url_for_repo_config + + +def get_remote_for_github_type(op: RemoteRepoOperator, github_type: GithubType = GithubType.GithubEnterprise) -> Remote: + if op.github_type == github_type: + return op.git_cli.remote(name="origin") + + remote_name = HIGHSIDE_REMOTE_NAME if github_type == GithubType.Github else LOWSIDE_REMOTE_NAME + remote_url = get_authenticated_clone_url_for_repo_config(repo=op.repo_config, github_type=github_type, github_scope=GithubScope.WRITE) + + if remote_name in op.git_cli.remotes: + remote = op.git_cli.remote(remote_name) + remote.set_url(remote_url) + else: + remote = op.git_cli.create_remote(remote_name, remote_url) + return remote