Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/codegen/git/clients/git_repo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,25 @@
class GitRepoClient:
"""Wrapper around PyGithub's Remote Repository."""

repo: RepoConfig
repo_config: RepoConfig
github_type: GithubType = GithubType.GithubEnterprise
gh_client: GithubClientType
read_client: Repository
access_scope: GithubScope
__write_client: Repository | None # Will not be initialized if access scope is read-only

def __init__(self, repo_config: RepoConfig, github_type: GithubType = GithubType.GithubEnterprise, access_scope: GithubScope = GithubScope.READ) -> None:
self.repo = repo_config
self.repo_config = repo_config
self.github_type = github_type
self.gh_client = GithubClientFactory.create_from_repo(self.repo, github_type)
self.gh_client = GithubClientFactory.create_from_repo(self.repo_config, github_type)
self.read_client = self._create_client(GithubScope.READ)
self.__write_client = self._create_client(GithubScope.WRITE) if access_scope == GithubScope.WRITE else None
self.access_scope = access_scope

def _create_client(self, github_scope: GithubScope = GithubScope.READ) -> Repository:
client = self.gh_client.get_repo_by_full_name(self.repo.full_name, github_scope=github_scope)
client = self.gh_client.get_repo_by_full_name(self.repo_config.full_name, github_scope=github_scope)
if not client:
msg = f"Repo {self.repo.full_name} not found in {self.github_type.value}!"
msg = f"Repo {self.repo_config.full_name} not found in {self.github_type.value}!"
raise ValueError(msg)
return client

Expand All @@ -61,7 +61,7 @@ def _write_client(self) -> Repository:

@property
def id(self) -> int:
return self.repo.id
return self.repo_config.id

@property
def default_branch(self) -> str:
Expand Down Expand Up @@ -160,7 +160,7 @@ def get_pull_by_branch_and_state(
if not base_branch_name:
base_branch_name = self.default_branch

head_branch_name = f"{self.repo.organization_name}:{head_branch_name}"
head_branch_name = f"{self.repo_config.organization_name}:{head_branch_name}"

# retrieve all pulls ordered by created descending
prs = self.read_client.get_pulls(base=base_branch_name, head=head_branch_name, state=state, sort="created", direction="desc")
Expand Down
4 changes: 2 additions & 2 deletions src/codegen/git/utils/codeowner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ def create_codeowners_parser_for_repo(py_github_repo: GitRepoClient) -> CodeOwne
return codeowners
except Exception as e:
continue
logger.info(f"Failed to create CODEOWNERS parser for repo: {py_github_repo.repo.id}. Returning None.")
logger.info(f"Failed to create CODEOWNERS parser for repo: {py_github_repo.repo_config.id}. Returning None.")
return None


def get_codeowners_for_pull(repo: GitRepoClient, pull: PullRequest) -> list[str]:
codeowners_parser = create_codeowners_parser_for_repo(repo)
if not codeowners_parser:
logger.warning(f"Failed to create codeowners parser for repo: {repo.repo.id}. Returning empty list.")
logger.warning(f"Failed to create codeowners parser for repo: {repo.repo_config.id}. Returning empty list.")
return []
codeowners_for_pull_set = set()
pull_files = pull.get_files()
Expand Down
Loading