From 2f7f3abee494977aa2adf2f50ceef0fe0c406800 Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Fri, 14 Nov 2025 19:32:04 +0100 Subject: [PATCH 01/19] Apply lockdown mode to issues and pull requests --- pkg/github/issues.go | 51 ++++++++++++++++++++++++++----- pkg/github/pullrequests.go | 53 ++++++++++++++++++++++++++++----- pkg/github/pullrequests_test.go | 24 +++++++-------- pkg/github/tools.go | 2 +- pkg/lockdown/lockdown.go | 21 ++----------- 5 files changed, 104 insertions(+), 47 deletions(-) diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 1c4f9514c..14bc593b7 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -299,11 +299,11 @@ Options are: case "get": return GetIssue(ctx, client, gqlClient, owner, repo, issueNumber, flags) case "get_comments": - return GetIssueComments(ctx, client, owner, repo, issueNumber, pagination, flags) + return GetIssueComments(ctx, client, gqlClient, owner, repo, issueNumber, pagination, flags) case "get_sub_issues": - return GetSubIssues(ctx, client, owner, repo, issueNumber, pagination, flags) + return GetSubIssues(ctx, client, gqlClient, owner, repo, issueNumber, pagination, flags) case "get_labels": - return GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber, flags) + return GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber) default: return mcp.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil } @@ -327,11 +327,11 @@ func GetIssue(ctx context.Context, client *github.Client, gqlClient *githubv4.Cl if flags.LockdownMode { if issue.User != nil { - shouldRemoveContent, err := lockdown.ShouldRemoveContent(ctx, gqlClient, *issue.User.Login, owner, repo) + isPrivate, hasPushAccess, err := lockdown.GetRepoAccessInfo(ctx, gqlClient, *issue.User.Login, owner, repo) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil } - if shouldRemoveContent { + if !isPrivate && !hasPushAccess { return mcp.NewToolResultError("access to issue details is restricted by lockdown mode"), nil } } @@ -355,7 +355,7 @@ func GetIssue(ctx context.Context, client *github.Client, gqlClient *githubv4.Cl return mcp.NewToolResultText(string(r)), nil } -func GetIssueComments(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int, pagination PaginationParams, _ FeatureFlags) (*mcp.CallToolResult, error) { +func GetIssueComments(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner string, repo string, issueNumber int, pagination PaginationParams, flags FeatureFlags) (*mcp.CallToolResult, error) { opts := &github.IssueListCommentsOptions{ ListOptions: github.ListOptions{ Page: pagination.Page, @@ -376,6 +376,23 @@ func GetIssueComments(ctx context.Context, client *github.Client, owner string, } return mcp.NewToolResultError(fmt.Sprintf("failed to get issue comments: %s", string(body))), nil } + if flags.LockdownMode { + filteredComments := []*github.IssueComment{} + for _, comment := range comments { + isPrivate, hasPushAccess, err := lockdown.GetRepoAccessInfo(ctx, gqlClient, *comment.User.Login, owner, repo) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil + } + // Do not filter content for private repositories + if isPrivate { + break + } + if hasPushAccess { + filteredComments = append(filteredComments, comment) + } + } + comments = filteredComments + } r, err := json.Marshal(comments) if err != nil { @@ -385,7 +402,7 @@ func GetIssueComments(ctx context.Context, client *github.Client, owner string, return mcp.NewToolResultText(string(r)), nil } -func GetSubIssues(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int, pagination PaginationParams, _ FeatureFlags) (*mcp.CallToolResult, error) { +func GetSubIssues(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner string, repo string, issueNumber int, pagination PaginationParams, featureFlags FeatureFlags) (*mcp.CallToolResult, error) { opts := &github.IssueListOptions{ ListOptions: github.ListOptions{ Page: pagination.Page, @@ -412,6 +429,24 @@ func GetSubIssues(ctx context.Context, client *github.Client, owner string, repo return mcp.NewToolResultError(fmt.Sprintf("failed to list sub-issues: %s", string(body))), nil } + if featureFlags.LockdownMode { + filteredSubIssues := []*github.SubIssue{} + for _, subIssue := range subIssues { + isPrivate, hasPushAccess, err := lockdown.GetRepoAccessInfo(ctx, gqlClient, *subIssue.User.Login, owner, repo) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil + } + // Repo is private, do not filter content + if isPrivate { + break + } + if hasPushAccess { + filteredSubIssues = append(filteredSubIssues, subIssue) + } + } + subIssues = filteredSubIssues + } + r, err := json.Marshal(subIssues) if err != nil { return nil, fmt.Errorf("failed to marshal response: %w", err) @@ -420,7 +455,7 @@ func GetSubIssues(ctx context.Context, client *github.Client, owner string, repo return mcp.NewToolResultText(string(r)), nil } -func GetIssueLabels(ctx context.Context, client *githubv4.Client, owner string, repo string, issueNumber int, _ FeatureFlags) (*mcp.CallToolResult, error) { +func GetIssueLabels(ctx context.Context, client *githubv4.Client, owner string, repo string, issueNumber int) (*mcp.CallToolResult, error) { // Get current labels on the issue using GraphQL var query struct { Repository struct { diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index e64ae03e4..dee949789 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -14,12 +14,13 @@ import ( "github.com/shurcooL/githubv4" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/sanitize" "github.com/github/github-mcp-server/pkg/translations" ) // GetPullRequest creates a tool to get details of a specific pull request. -func PullRequestRead(getClient GetClientFn, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, server.ToolHandlerFunc) { +func PullRequestRead(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("pull_request_read", mcp.WithDescription(t("TOOL_PULL_REQUEST_READ_DESCRIPTION", "Get information on a specific pull request in GitHub repository.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -83,10 +84,15 @@ Possible options: return nil, fmt.Errorf("failed to get GitHub client: %w", err) } + gqlClient, err := getGQLClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub GraphQL client: %w", err) + } + switch method { case "get": - return GetPullRequest(ctx, client, owner, repo, pullNumber) + return GetPullRequest(ctx, client, gqlClient, owner, repo, pullNumber, flags) case "get_diff": return GetPullRequestDiff(ctx, client, owner, repo, pullNumber) case "get_status": @@ -94,18 +100,18 @@ Possible options: case "get_files": return GetPullRequestFiles(ctx, client, owner, repo, pullNumber, pagination) case "get_review_comments": - return GetPullRequestReviewComments(ctx, client, owner, repo, pullNumber, pagination) + return GetPullRequestReviewComments(ctx, client, gqlClient, owner, repo, pullNumber, pagination, flags) case "get_reviews": - return GetPullRequestReviews(ctx, client, owner, repo, pullNumber) + return GetPullRequestReviews(ctx, client, gqlClient, owner, repo, pullNumber, flags) case "get_comments": - return GetIssueComments(ctx, client, owner, repo, pullNumber, pagination, flags) + return GetIssueComments(ctx, client, gqlClient, owner, repo, pullNumber, pagination, flags) default: return nil, fmt.Errorf("unknown method: %s", method) } } } -func GetPullRequest(ctx context.Context, client *github.Client, owner, repo string, pullNumber int) (*mcp.CallToolResult, error) { +func GetPullRequest(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner, repo string, pullNumber int, ff FeatureFlags) (*mcp.CallToolResult, error) { pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, @@ -134,6 +140,17 @@ func GetPullRequest(ctx context.Context, client *github.Client, owner, repo stri } } + if ff.LockdownMode { + isPrivate, hasPushAccess, err := lockdown.GetRepoAccessInfo(ctx, gqlClient, pr.GetUser().GetLogin(), owner, repo) + if err != nil { + return nil, fmt.Errorf("failed to check content removal: %w", err) + } + + if !isPrivate && !hasPushAccess { + return mcp.NewToolResultError("access to pull request is restricted by lockdown mode"), nil + } + } + r, err := json.Marshal(pr) if err != nil { return nil, fmt.Errorf("failed to marshal response: %w", err) @@ -249,7 +266,7 @@ func GetPullRequestFiles(ctx context.Context, client *github.Client, owner, repo return mcp.NewToolResultText(string(r)), nil } -func GetPullRequestReviewComments(ctx context.Context, client *github.Client, owner, repo string, pullNumber int, pagination PaginationParams) (*mcp.CallToolResult, error) { +func GetPullRequestReviewComments(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner, repo string, pullNumber int, pagination PaginationParams, ff FeatureFlags) (*mcp.CallToolResult, error) { opts := &github.PullRequestListCommentsOptions{ ListOptions: github.ListOptions{ PerPage: pagination.PerPage, @@ -275,6 +292,16 @@ func GetPullRequestReviewComments(ctx context.Context, client *github.Client, ow return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request review comments: %s", string(body))), nil } + if ff.LockdownMode { + isPrivate, hasPushAccess, err := lockdown.GetRepoAccessInfo(ctx, gqlClient, comments[0].GetUser().GetLogin(), owner, repo) + if err != nil { + return nil, fmt.Errorf("failed to check content removal: %w", err) + } + if !isPrivate && !hasPushAccess { + return mcp.NewToolResultError("access to pull request review comments is restricted by lockdown mode"), nil + } + } + r, err := json.Marshal(comments) if err != nil { return nil, fmt.Errorf("failed to marshal response: %w", err) @@ -283,7 +310,7 @@ func GetPullRequestReviewComments(ctx context.Context, client *github.Client, ow return mcp.NewToolResultText(string(r)), nil } -func GetPullRequestReviews(ctx context.Context, client *github.Client, owner, repo string, pullNumber int) (*mcp.CallToolResult, error) { +func GetPullRequestReviews(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner, repo string, pullNumber int, ff FeatureFlags) (*mcp.CallToolResult, error) { reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, @@ -302,6 +329,16 @@ func GetPullRequestReviews(ctx context.Context, client *github.Client, owner, re return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request reviews: %s", string(body))), nil } + if ff.LockdownMode { + isPrivate, hasPushAccess, err := lockdown.GetRepoAccessInfo(ctx, gqlClient, reviews[0].GetUser().GetLogin(), owner, repo) + if err != nil { + return nil, fmt.Errorf("failed to check content removal: %w", err) + } + if !isPrivate && !hasPushAccess { + return mcp.NewToolResultError("access to pull request reviews is restricted by lockdown mode"), nil + } + } + r, err := json.Marshal(reviews) if err != nil { return nil, fmt.Errorf("failed to marshal response: %w", err) diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 347bce672..d8a123e2f 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -21,7 +21,7 @@ import ( func Test_GetPullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -102,7 +102,7 @@ func Test_GetPullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1133,7 +1133,7 @@ func Test_SearchPullRequests(t *testing.T) { func Test_GetPullRequestFiles(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1236,7 +1236,7 @@ func Test_GetPullRequestFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1277,7 +1277,7 @@ func Test_GetPullRequestFiles(t *testing.T) { func Test_GetPullRequestStatus(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1404,7 +1404,7 @@ func Test_GetPullRequestStatus(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1566,7 +1566,7 @@ func Test_UpdatePullRequestBranch(t *testing.T) { func Test_GetPullRequestComments(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1658,7 +1658,7 @@ func Test_GetPullRequestComments(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1700,7 +1700,7 @@ func Test_GetPullRequestComments(t *testing.T) { func Test_GetPullRequestReviews(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1788,7 +1788,7 @@ func Test_GetPullRequestReviews(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -2789,7 +2789,7 @@ func TestGetPullRequestDiff(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -2847,7 +2847,7 @@ index 5d6e7b2..8a4f5c3 100644 // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 0594f2f94..01d18852c 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -224,7 +224,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG ) pullRequests := toolsets.NewToolset(ToolsetMetadataPullRequests.ID, ToolsetMetadataPullRequests.Description). AddReadTools( - toolsets.NewServerTool(PullRequestRead(getClient, t, flags)), + toolsets.NewServerTool(PullRequestRead(getClient, getGQLClient, t, flags)), toolsets.NewServerTool(ListPullRequests(getClient, t)), toolsets.NewServerTool(SearchPullRequests(getClient, t)), ). diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index 5a474f73c..9a68289ad 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -8,24 +8,9 @@ import ( "github.com/shurcooL/githubv4" ) -// ShouldRemoveContent determines if content should be removed based on -// lockdown mode rules. It checks if the repository is private and if the user -// has push access to the repository. -func ShouldRemoveContent(ctx context.Context, client *githubv4.Client, username, owner, repo string) (bool, error) { - isPrivate, hasPushAccess, err := repoAccessInfo(ctx, client, username, owner, repo) - if err != nil { - return false, err - } - - // Do not filter content for private repositories - if isPrivate { - return false, nil - } - - return !hasPushAccess, nil -} - -func repoAccessInfo(ctx context.Context, client *githubv4.Client, username, owner, repo string) (bool, bool, error) { +// GetRepoAccessInfo retrieves whether the repository is private and whether +// the user has push access to the repository. +func GetRepoAccessInfo(ctx context.Context, client *githubv4.Client, username, owner, repo string) (bool, bool, error) { if client == nil { return false, false, fmt.Errorf("nil GraphQL client") } From 55623350cc385c1a0ab031fc8547d64e038cd550 Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Mon, 17 Nov 2025 16:53:35 +0100 Subject: [PATCH 02/19] Add cache --- cmd/github-mcp-server/generate_docs.go | 7 +- cmd/github-mcp-server/main.go | 5 + internal/ghmcp/server.go | 51 ++++--- pkg/github/issues.go | 52 +++++-- pkg/github/issues_test.go | 16 +- pkg/github/pullrequests.go | 73 +++++---- pkg/github/pullrequests_test.go | 24 +-- pkg/github/server_test.go | 5 + pkg/github/tools.go | 7 +- pkg/lockdown/lockdown.go | 198 ++++++++++++++++++++++++- pkg/lockdown/lockdown_test.go | 149 +++++++++++++++++++ 11 files changed, 496 insertions(+), 91 deletions(-) create mode 100644 pkg/lockdown/lockdown_test.go diff --git a/cmd/github-mcp-server/generate_docs.go b/cmd/github-mcp-server/generate_docs.go index 359370760..ee41b8493 100644 --- a/cmd/github-mcp-server/generate_docs.go +++ b/cmd/github-mcp-server/generate_docs.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" @@ -64,7 +65,8 @@ func generateReadmeDocs(readmePath string) error { t, _ := translations.TranslationHelper() // Create toolset group with mock clients - tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{}) + repoAccessCache := lockdown.NewRepoAccessCache(nil) + tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{}, repoAccessCache) // Generate toolsets documentation toolsetsDoc := generateToolsetsDoc(tsg) @@ -302,7 +304,8 @@ func generateRemoteToolsetsDoc() string { t, _ := translations.TranslationHelper() // Create toolset group with mock clients - tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{}) + repoAccessCache := lockdown.NewRepoAccessCache(nil) + tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{}, repoAccessCache) // Generate table header buf.WriteString("| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) |\n") diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 125cd5a8d..3d4113644 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "strings" + "time" "github.com/github/github-mcp-server/internal/ghmcp" "github.com/github/github-mcp-server/pkg/github" @@ -50,6 +51,7 @@ var ( enabledToolsets = []string{github.ToolsetMetadataDefault.ID} } + ttl := viper.GetDuration("repo-access-cache-ttl") stdioServerConfig := ghmcp.StdioServerConfig{ Version: version, Host: viper.GetString("host"), @@ -62,6 +64,7 @@ var ( LogFilePath: viper.GetString("log-file"), ContentWindowSize: viper.GetInt("content-window-size"), LockdownMode: viper.GetBool("lockdown-mode"), + RepoAccessCacheTTL: &ttl, } return ghmcp.RunStdioServer(stdioServerConfig) }, @@ -84,6 +87,7 @@ func init() { rootCmd.PersistentFlags().String("gh-host", "", "Specify the GitHub hostname (for GitHub Enterprise etc.)") rootCmd.PersistentFlags().Int("content-window-size", 5000, "Specify the content window size") rootCmd.PersistentFlags().Bool("lockdown-mode", false, "Enable lockdown mode") + rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)") // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) @@ -95,6 +99,7 @@ func init() { _ = viper.BindPFlag("host", rootCmd.PersistentFlags().Lookup("gh-host")) _ = viper.BindPFlag("content-window-size", rootCmd.PersistentFlags().Lookup("content-window-size")) _ = viper.BindPFlag("lockdown-mode", rootCmd.PersistentFlags().Lookup("lockdown-mode")) + _ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl")) // Add subcommands rootCmd.AddCommand(stdioCmd) diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 1067a222f..f82fa0553 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -16,6 +16,7 @@ import ( "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/lockdown" mcplog "github.com/github/github-mcp-server/pkg/log" "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/translations" @@ -54,6 +55,9 @@ type MCPServerConfig struct { // LockdownMode indicates if we should enable lockdown mode LockdownMode bool + + // RepoAccessTTL overrides the default TTL for repository access cache entries. + RepoAccessTTL *time.Duration } const stdioServerLogPrefix = "stdioserver" @@ -80,6 +84,14 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { }, } // We're going to wrap the Transport later in beforeInit gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient) + repoAccessOpts := []lockdown.RepoAccessOption{} + if cfg.RepoAccessTTL != nil { + repoAccessOpts = append(repoAccessOpts, lockdown.WithTTL(*cfg.RepoAccessTTL)) + } + var repoAccessCache *lockdown.RepoAccessCache + if cfg.LockdownMode { + repoAccessCache = lockdown.NewRepoAccessCache(gqlClient, repoAccessOpts...) + } // When a client send an initialize request, update the user agent to include the client info. beforeInit := func(_ context.Context, _ any, message *mcp.InitializeRequest) { @@ -165,6 +177,7 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { cfg.Translator, cfg.ContentWindowSize, github.FeatureFlags{LockdownMode: cfg.LockdownMode}, + repoAccessCache, ) err = tsg.EnableToolsets(enabledToolsets, nil) @@ -219,6 +232,9 @@ type StdioServerConfig struct { // LockdownMode indicates if we should enable lockdown mode LockdownMode bool + + // RepoAccessCacheTTL overrides the default TTL for repository access cache entries. + RepoAccessCacheTTL *time.Duration } // RunStdioServer is not concurrent safe. @@ -229,23 +245,6 @@ func RunStdioServer(cfg StdioServerConfig) error { t, dumpTranslations := translations.TranslationHelper() - ghServer, err := NewMCPServer(MCPServerConfig{ - Version: cfg.Version, - Host: cfg.Host, - Token: cfg.Token, - EnabledToolsets: cfg.EnabledToolsets, - DynamicToolsets: cfg.DynamicToolsets, - ReadOnly: cfg.ReadOnly, - Translator: t, - ContentWindowSize: cfg.ContentWindowSize, - LockdownMode: cfg.LockdownMode, - }) - if err != nil { - return fmt.Errorf("failed to create MCP server: %w", err) - } - - stdioServer := server.NewStdioServer(ghServer) - var slogHandler slog.Handler var logOutput io.Writer if cfg.LogFilePath != "" { @@ -262,6 +261,24 @@ func RunStdioServer(cfg StdioServerConfig) error { logger := slog.New(slogHandler) logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "dynamicToolsets", cfg.DynamicToolsets, "readOnly", cfg.ReadOnly, "lockdownEnabled", cfg.LockdownMode) stdLogger := log.New(logOutput, stdioServerLogPrefix, 0) + + ghServer, err := NewMCPServer(MCPServerConfig{ + Version: cfg.Version, + Host: cfg.Host, + Token: cfg.Token, + EnabledToolsets: cfg.EnabledToolsets, + DynamicToolsets: cfg.DynamicToolsets, + ReadOnly: cfg.ReadOnly, + Translator: t, + ContentWindowSize: cfg.ContentWindowSize, + LockdownMode: cfg.LockdownMode, + RepoAccessTTL: cfg.RepoAccessCacheTTL, + }) + if err != nil { + return fmt.Errorf("failed to create MCP server: %w", err) + } + + stdioServer := server.NewStdioServer(ghServer) stdioServer.SetErrorLogger(stdLogger) if cfg.ExportTranslations { diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 14bc593b7..0af68e712 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -228,7 +228,7 @@ func fragmentToIssue(fragment IssueFragment) *github.Issue { } // GetIssue creates a tool to get details of a specific issue in a GitHub repository. -func IssueRead(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc, flags FeatureFlags) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func IssueRead(getClient GetClientFn, getGQLClient GetGQLClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("issue_read", mcp.WithDescription(t("TOOL_ISSUE_READ_DESCRIPTION", "Get information about a specific issue in a GitHub repository.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -297,11 +297,11 @@ Options are: switch method { case "get": - return GetIssue(ctx, client, gqlClient, owner, repo, issueNumber, flags) + return GetIssue(ctx, client, cache, owner, repo, issueNumber, flags) case "get_comments": - return GetIssueComments(ctx, client, gqlClient, owner, repo, issueNumber, pagination, flags) + return GetIssueComments(ctx, client, cache, owner, repo, issueNumber, pagination, flags) case "get_sub_issues": - return GetSubIssues(ctx, client, gqlClient, owner, repo, issueNumber, pagination, flags) + return GetSubIssues(ctx, client, cache, owner, repo, issueNumber, pagination, flags) case "get_labels": return GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber) default: @@ -310,7 +310,7 @@ Options are: } } -func GetIssue(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner string, repo string, issueNumber int, flags FeatureFlags) (*mcp.CallToolResult, error) { +func GetIssue(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner string, repo string, issueNumber int, flags FeatureFlags) (*mcp.CallToolResult, error) { issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) if err != nil { return nil, fmt.Errorf("failed to get issue: %w", err) @@ -326,8 +326,12 @@ func GetIssue(ctx context.Context, client *github.Client, gqlClient *githubv4.Cl } if flags.LockdownMode { - if issue.User != nil { - isPrivate, hasPushAccess, err := lockdown.GetRepoAccessInfo(ctx, gqlClient, *issue.User.Login, owner, repo) + if cache == nil { + return nil, fmt.Errorf("lockdown cache is not configured") + } + login := issue.GetUser().GetLogin() + if login != "" { + isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil } @@ -355,7 +359,7 @@ func GetIssue(ctx context.Context, client *github.Client, gqlClient *githubv4.Cl return mcp.NewToolResultText(string(r)), nil } -func GetIssueComments(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner string, repo string, issueNumber int, pagination PaginationParams, flags FeatureFlags) (*mcp.CallToolResult, error) { +func GetIssueComments(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner string, repo string, issueNumber int, pagination PaginationParams, flags FeatureFlags) (*mcp.CallToolResult, error) { opts := &github.IssueListCommentsOptions{ ListOptions: github.ListOptions{ Page: pagination.Page, @@ -377,9 +381,20 @@ func GetIssueComments(ctx context.Context, client *github.Client, gqlClient *git return mcp.NewToolResultError(fmt.Sprintf("failed to get issue comments: %s", string(body))), nil } if flags.LockdownMode { - filteredComments := []*github.IssueComment{} + if cache == nil { + return nil, fmt.Errorf("lockdown cache is not configured") + } + filteredComments := make([]*github.IssueComment, 0, len(comments)) for _, comment := range comments { - isPrivate, hasPushAccess, err := lockdown.GetRepoAccessInfo(ctx, gqlClient, *comment.User.Login, owner, repo) + user := comment.User + if user == nil { + continue + } + login := user.GetLogin() + if login == "" { + continue + } + isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil } @@ -402,7 +417,7 @@ func GetIssueComments(ctx context.Context, client *github.Client, gqlClient *git return mcp.NewToolResultText(string(r)), nil } -func GetSubIssues(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner string, repo string, issueNumber int, pagination PaginationParams, featureFlags FeatureFlags) (*mcp.CallToolResult, error) { +func GetSubIssues(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner string, repo string, issueNumber int, pagination PaginationParams, featureFlags FeatureFlags) (*mcp.CallToolResult, error) { opts := &github.IssueListOptions{ ListOptions: github.ListOptions{ Page: pagination.Page, @@ -430,9 +445,20 @@ func GetSubIssues(ctx context.Context, client *github.Client, gqlClient *githubv } if featureFlags.LockdownMode { - filteredSubIssues := []*github.SubIssue{} + if cache == nil { + return nil, fmt.Errorf("lockdown cache is not configured") + } + filteredSubIssues := make([]*github.SubIssue, 0, len(subIssues)) for _, subIssue := range subIssues { - isPrivate, hasPushAccess, err := lockdown.GetRepoAccessInfo(ctx, gqlClient, *subIssue.User.Login, owner, repo) + user := subIssue.User + if user == nil { + continue + } + login := user.GetLogin() + if login == "" { + continue + } + isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil } diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index 4cc3a1302..b28e8dd82 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -23,7 +23,7 @@ func Test_GetIssue(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) defaultGQLClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(defaultGQLClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(defaultGQLClient), stubRepoAccessCache(defaultGQLClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -212,7 +212,7 @@ func Test_GetIssue(t *testing.T) { } flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, flags) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, flags) request := createMCPRequest(tc.requestArgs) result, err := handler(context.Background(), request) @@ -1710,7 +1710,7 @@ func Test_GetIssueComments(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) gqlClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -1816,7 +1816,7 @@ func Test_GetIssueComments(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1853,7 +1853,7 @@ func Test_GetIssueLabels(t *testing.T) { // Verify tool definition mockGQClient := githubv4.NewClient(nil) mockClient := github.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQClient), stubRepoAccessCache(mockGQClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -1928,7 +1928,7 @@ func Test_GetIssueLabels(t *testing.T) { t.Run(tc.name, func(t *testing.T) { gqlClient := githubv4.NewClient(tc.mockedClient) client := github.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) request := createMCPRequest(tc.requestArgs) result, err := handler(context.Background(), request) @@ -2619,7 +2619,7 @@ func Test_GetSubIssues(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) gqlClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -2816,7 +2816,7 @@ func Test_GetSubIssues(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index dee949789..3d805bb00 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -20,7 +20,7 @@ import ( ) // GetPullRequest creates a tool to get details of a specific pull request. -func PullRequestRead(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, server.ToolHandlerFunc) { +func PullRequestRead(getClient GetClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("pull_request_read", mcp.WithDescription(t("TOOL_PULL_REQUEST_READ_DESCRIPTION", "Get information on a specific pull request in GitHub repository.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -84,15 +84,10 @@ Possible options: return nil, fmt.Errorf("failed to get GitHub client: %w", err) } - gqlClient, err := getGQLClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub GraphQL client: %w", err) - } - switch method { case "get": - return GetPullRequest(ctx, client, gqlClient, owner, repo, pullNumber, flags) + return GetPullRequest(ctx, client, cache, owner, repo, pullNumber, flags) case "get_diff": return GetPullRequestDiff(ctx, client, owner, repo, pullNumber) case "get_status": @@ -100,18 +95,18 @@ Possible options: case "get_files": return GetPullRequestFiles(ctx, client, owner, repo, pullNumber, pagination) case "get_review_comments": - return GetPullRequestReviewComments(ctx, client, gqlClient, owner, repo, pullNumber, pagination, flags) + return GetPullRequestReviewComments(ctx, client, cache, owner, repo, pullNumber, pagination, flags) case "get_reviews": - return GetPullRequestReviews(ctx, client, gqlClient, owner, repo, pullNumber, flags) + return GetPullRequestReviews(ctx, client, cache, owner, repo, pullNumber, flags) case "get_comments": - return GetIssueComments(ctx, client, gqlClient, owner, repo, pullNumber, pagination, flags) + return GetIssueComments(ctx, client, cache, owner, repo, pullNumber, pagination, flags) default: return nil, fmt.Errorf("unknown method: %s", method) } } } -func GetPullRequest(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner, repo string, pullNumber int, ff FeatureFlags) (*mcp.CallToolResult, error) { +func GetPullRequest(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner, repo string, pullNumber int, ff FeatureFlags) (*mcp.CallToolResult, error) { pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, @@ -141,13 +136,19 @@ func GetPullRequest(ctx context.Context, client *github.Client, gqlClient *githu } if ff.LockdownMode { - isPrivate, hasPushAccess, err := lockdown.GetRepoAccessInfo(ctx, gqlClient, pr.GetUser().GetLogin(), owner, repo) - if err != nil { - return nil, fmt.Errorf("failed to check content removal: %w", err) + if cache == nil { + return nil, fmt.Errorf("lockdown cache is not configured") } + login := pr.GetUser().GetLogin() + if login != "" { + isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) + if err != nil { + return nil, fmt.Errorf("failed to check content removal: %w", err) + } - if !isPrivate && !hasPushAccess { - return mcp.NewToolResultError("access to pull request is restricted by lockdown mode"), nil + if !isPrivate && !hasPushAccess { + return mcp.NewToolResultError("access to pull request is restricted by lockdown mode"), nil + } } } @@ -266,7 +267,7 @@ func GetPullRequestFiles(ctx context.Context, client *github.Client, owner, repo return mcp.NewToolResultText(string(r)), nil } -func GetPullRequestReviewComments(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner, repo string, pullNumber int, pagination PaginationParams, ff FeatureFlags) (*mcp.CallToolResult, error) { +func GetPullRequestReviewComments(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner, repo string, pullNumber int, pagination PaginationParams, ff FeatureFlags) (*mcp.CallToolResult, error) { opts := &github.PullRequestListCommentsOptions{ ListOptions: github.ListOptions{ PerPage: pagination.PerPage, @@ -293,12 +294,20 @@ func GetPullRequestReviewComments(ctx context.Context, client *github.Client, gq } if ff.LockdownMode { - isPrivate, hasPushAccess, err := lockdown.GetRepoAccessInfo(ctx, gqlClient, comments[0].GetUser().GetLogin(), owner, repo) - if err != nil { - return nil, fmt.Errorf("failed to check content removal: %w", err) + if cache == nil { + return nil, fmt.Errorf("lockdown cache is not configured") } - if !isPrivate && !hasPushAccess { - return mcp.NewToolResultError("access to pull request review comments is restricted by lockdown mode"), nil + if len(comments) > 0 { + login := comments[0].GetUser().GetLogin() + if login != "" { + isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) + if err != nil { + return nil, fmt.Errorf("failed to check content removal: %w", err) + } + if !isPrivate && !hasPushAccess { + return mcp.NewToolResultError("access to pull request review comments is restricted by lockdown mode"), nil + } + } } } @@ -310,7 +319,7 @@ func GetPullRequestReviewComments(ctx context.Context, client *github.Client, gq return mcp.NewToolResultText(string(r)), nil } -func GetPullRequestReviews(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner, repo string, pullNumber int, ff FeatureFlags) (*mcp.CallToolResult, error) { +func GetPullRequestReviews(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner, repo string, pullNumber int, ff FeatureFlags) (*mcp.CallToolResult, error) { reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, @@ -330,12 +339,20 @@ func GetPullRequestReviews(ctx context.Context, client *github.Client, gqlClient } if ff.LockdownMode { - isPrivate, hasPushAccess, err := lockdown.GetRepoAccessInfo(ctx, gqlClient, reviews[0].GetUser().GetLogin(), owner, repo) - if err != nil { - return nil, fmt.Errorf("failed to check content removal: %w", err) + if cache == nil { + return nil, fmt.Errorf("lockdown cache is not configured") } - if !isPrivate && !hasPushAccess { - return mcp.NewToolResultError("access to pull request reviews is restricted by lockdown mode"), nil + if len(reviews) > 0 { + login := reviews[0].GetUser().GetLogin() + if login != "" { + isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) + if err != nil { + return nil, fmt.Errorf("failed to check content removal: %w", err) + } + if !isPrivate && !hasPushAccess { + return mcp.NewToolResultError("access to pull request reviews is restricted by lockdown mode"), nil + } + } } } diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index d8a123e2f..22dbec5d0 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -21,7 +21,7 @@ import ( func Test_GetPullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -102,7 +102,7 @@ func Test_GetPullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1133,7 +1133,7 @@ func Test_SearchPullRequests(t *testing.T) { func Test_GetPullRequestFiles(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1236,7 +1236,7 @@ func Test_GetPullRequestFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1277,7 +1277,7 @@ func Test_GetPullRequestFiles(t *testing.T) { func Test_GetPullRequestStatus(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1404,7 +1404,7 @@ func Test_GetPullRequestStatus(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1566,7 +1566,7 @@ func Test_UpdatePullRequestBranch(t *testing.T) { func Test_GetPullRequestComments(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1658,7 +1658,7 @@ func Test_GetPullRequestComments(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1700,7 +1700,7 @@ func Test_GetPullRequestComments(t *testing.T) { func Test_GetPullRequestReviews(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1788,7 +1788,7 @@ func Test_GetPullRequestReviews(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -2789,7 +2789,7 @@ func TestGetPullRequestDiff(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -2847,7 +2847,7 @@ index 5d6e7b2..8a4f5c3 100644 // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 3bfc1ef94..20e20b2bb 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -8,6 +8,7 @@ import ( "net/http" "testing" + "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/raw" "github.com/google/go-github/v79/github" "github.com/shurcooL/githubv4" @@ -38,6 +39,10 @@ func stubGetGQLClientFn(client *githubv4.Client) GetGQLClientFn { } } +func stubRepoAccessCache(client *githubv4.Client) *lockdown.RepoAccessCache { + return lockdown.NewRepoAccessCache(client) +} + func stubFeatureFlags(enabledFlags map[string]bool) FeatureFlags { return FeatureFlags{ LockdownMode: enabledFlags["lockdown-mode"], diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 01d18852c..74f3d52f2 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -5,6 +5,7 @@ import ( "fmt" "strings" + "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" @@ -159,7 +160,7 @@ func GetDefaultToolsetIDs() []string { } } -func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetGQLClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc, contentWindowSize int, flags FeatureFlags) *toolsets.ToolsetGroup { +func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetGQLClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc, contentWindowSize int, flags FeatureFlags, cache *lockdown.RepoAccessCache) *toolsets.ToolsetGroup { tsg := toolsets.NewToolsetGroup(readOnly) // Define all available features with their default state (disabled) @@ -199,7 +200,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG ) issues := toolsets.NewToolset(ToolsetMetadataIssues.ID, ToolsetMetadataIssues.Description). AddReadTools( - toolsets.NewServerTool(IssueRead(getClient, getGQLClient, t, flags)), + toolsets.NewServerTool(IssueRead(getClient, getGQLClient, cache, t, flags)), toolsets.NewServerTool(SearchIssues(getClient, t)), toolsets.NewServerTool(ListIssues(getGQLClient, t)), toolsets.NewServerTool(ListIssueTypes(getClient, t)), @@ -224,7 +225,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG ) pullRequests := toolsets.NewToolset(ToolsetMetadataPullRequests.ID, ToolsetMetadataPullRequests.Description). AddReadTools( - toolsets.NewServerTool(PullRequestRead(getClient, getGQLClient, t, flags)), + toolsets.NewServerTool(PullRequestRead(getClient, cache, t, flags)), toolsets.NewServerTool(ListPullRequests(getClient, t)), toolsets.NewServerTool(SearchPullRequests(getClient, t)), ). diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index 9a68289ad..1749f311d 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -3,15 +3,176 @@ package lockdown import ( "context" "fmt" + "log/slog" "strings" + "sync" + "time" "github.com/shurcooL/githubv4" ) -// GetRepoAccessInfo retrieves whether the repository is private and whether -// the user has push access to the repository. -func GetRepoAccessInfo(ctx context.Context, client *githubv4.Client, username, owner, repo string) (bool, bool, error) { - if client == nil { +// RepoAccessCache caches repository metadata related to lockdown checks so that +// multiple tools can reuse the same access information safely across goroutines. +type RepoAccessCache struct { + client *githubv4.Client + mu sync.Mutex + cache map[string]*repoAccessCacheEntry + ttl time.Duration + logger *slog.Logger +} + +type repoAccessCacheEntry struct { + isPrivate bool + knownUsers map[string]bool // normalized login -> has push access + ready bool + timer *time.Timer +} + +const defaultRepoAccessTTL = 5 * time.Minute + +// RepoAccessOption configures RepoAccessCache at construction time. +type RepoAccessOption func(*RepoAccessCache) + +// WithTTL overrides the default TTL applied to cache entries. A non-positive +// duration disables expiration. +func WithTTL(ttl time.Duration) RepoAccessOption { + return func(c *RepoAccessCache) { + c.ttl = ttl + } +} + +// WithLogger sets the logger used for cache diagnostics. +func WithLogger(logger *slog.Logger) RepoAccessOption { + return func(c *RepoAccessCache) { + c.logger = logger + } +} + +// NewRepoAccessCache returns a cache bound to the provided GitHub GraphQL +// client. The cache is safe for concurrent use. +func NewRepoAccessCache(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessCache { + c := &RepoAccessCache{ + client: client, + cache: make(map[string]*repoAccessCacheEntry), + ttl: defaultRepoAccessTTL, + } + for _, opt := range opts { + if opt != nil { + opt(c) + } + } + c.logInfo("repo access cache initialized", "ttl", c.ttl) + return c +} + +// SetTTL overrides the default time-to-live used for cache entries. A non-positive +// duration disables expiration. +func (c *RepoAccessCache) SetTTL(ttl time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + c.ttl = ttl + c.logInfo("repo access cache TTL updated", "ttl", ttl) + for key, entry := range c.cache { + entry.scheduleExpiry(c, key) + } +} + +// SetLogger updates the logger used for cache diagnostics. +func (c *RepoAccessCache) SetLogger(logger *slog.Logger) { + c.mu.Lock() + c.logger = logger + c.mu.Unlock() +} + +// CacheStats summarizes cache activity counters. +type CacheStats struct { + Hits int64 + Misses int64 + Evictions int64 +} + +// GetRepoAccessInfo returns the repository's privacy status and whether the +// specified user has push permissions. Results are cached per repository to +// avoid repeated GraphQL round-trips. +func (c *RepoAccessCache) GetRepoAccessInfo(ctx context.Context, username, owner, repo string) (bool, bool, error) { + if c == nil { + return false, false, fmt.Errorf("nil repo access cache") + } + + key := cacheKey(owner, repo) + userKey := strings.ToLower(username) + c.mu.Lock() + entry := c.ensureEntry(key) + if entry.ready { + if cachedHasPush, known := entry.knownUsers[userKey]; known { + entry.scheduleExpiry(c, key) + c.logDebug("repo access cache hit", "owner", owner, "repo", repo, "user", username) + cachedPrivate := entry.isPrivate + c.mu.Unlock() + return cachedPrivate, cachedHasPush, nil + } + } + c.mu.Unlock() + c.logDebug("repo access cache miss", "owner", owner, "repo", repo, "user", username) + + isPrivate, hasPush, err := c.queryRepoAccessInfo(ctx, username, owner, repo) + if err != nil { + return false, false, err + } + + c.mu.Lock() + entry = c.ensureEntry(key) + entry.ready = true + entry.isPrivate = isPrivate + entry.knownUsers[userKey] = hasPush + entry.scheduleExpiry(c, key) + c.mu.Unlock() + + return isPrivate, hasPush, nil +} + +func (c *RepoAccessCache) ensureEntry(key string) *repoAccessCacheEntry { + if c.cache == nil { + c.cache = make(map[string]*repoAccessCacheEntry) + } + entry, ok := c.cache[key] + if !ok { + entry = &repoAccessCacheEntry{ + knownUsers: make(map[string]bool), + } + c.cache[key] = entry + } + return entry +} + +func (entry *repoAccessCacheEntry) scheduleExpiry(c *RepoAccessCache, key string) { + if entry.timer != nil { + entry.timer.Stop() + entry.timer = nil + } + + dur := c.ttl + if dur <= 0 { + return + } + + owner, repo := splitKey(key) + entry.timer = time.AfterFunc(dur, func() { + c.mu.Lock() + defer c.mu.Unlock() + + current, ok := c.cache[key] + if !ok || current != entry { + return + } + + delete(c.cache, key) + c.logDebug("repo access cache entry evicted", "owner", owner, "repo", repo) + }) +} + +func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, owner, repo string) (bool, bool, error) { + if c.client == nil { return false, false, fmt.Errorf("nil GraphQL client") } @@ -35,18 +196,15 @@ func GetRepoAccessInfo(ctx context.Context, client *githubv4.Client, username, o "username": githubv4.String(username), } - err := client.Query(ctx, &query, variables) - if err != nil { + if err := c.client.Query(ctx, &query, variables); err != nil { return false, false, fmt.Errorf("failed to query repository access info: %w", err) } - // Check if the user has push access hasPush := false for _, edge := range query.Repository.Collaborators.Edges { login := string(edge.Node.Login) if strings.EqualFold(login, username) { permission := string(edge.Permission) - // WRITE, ADMIN, and MAINTAIN permissions have push access hasPush = permission == "WRITE" || permission == "ADMIN" || permission == "MAINTAIN" break } @@ -54,3 +212,27 @@ func GetRepoAccessInfo(ctx context.Context, client *githubv4.Client, username, o return bool(query.Repository.IsPrivate), hasPush, nil } + +func cacheKey(owner, repo string) string { + return fmt.Sprintf("%s/%s", strings.ToLower(owner), strings.ToLower(repo)) +} + +func splitKey(key string) (string, string) { + owner, rest, found := strings.Cut(key, "/") + if !found { + return key, "" + } + return owner, rest +} + +func (c *RepoAccessCache) logDebug(msg string, args ...any) { + if c != nil && c.logger != nil { + c.logger.Debug(msg, args...) + } +} + +func (c *RepoAccessCache) logInfo(msg string, args ...any) { + if c != nil && c.logger != nil { + c.logger.Info(msg, args...) + } +} diff --git a/pkg/lockdown/lockdown_test.go b/pkg/lockdown/lockdown_test.go new file mode 100644 index 000000000..275924bf9 --- /dev/null +++ b/pkg/lockdown/lockdown_test.go @@ -0,0 +1,149 @@ +package lockdown + +import ( + "context" + "net/http" + "sync" + "testing" + "time" + + "github.com/github/github-mcp-server/internal/githubv4mock" + "github.com/shurcooL/githubv4" + "github.com/stretchr/testify/require" +) + +const ( + testOwner = "octo-org" + testRepo = "octo-repo" + testUser = "octocat" +) + +type repoAccessQuery struct { + Repository struct { + IsPrivate githubv4.Boolean + Collaborators struct { + Edges []struct { + Permission githubv4.String + Node struct { + Login githubv4.String + } + } + } `graphql:"collaborators(query: $username, first: 1)"` + } `graphql:"repository(owner: $owner, name: $name)"` +} + +type countingTransport struct { + mu sync.Mutex + next http.RoundTripper + calls int +} + +func (c *countingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + c.mu.Lock() + c.calls++ + c.mu.Unlock() + return c.next.RoundTrip(req) +} + +func (c *countingTransport) CallCount() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.calls +} + +func newMockRepoAccessCache(t *testing.T, ttl time.Duration) (*RepoAccessCache, *countingTransport) { + t.Helper() + + var query repoAccessQuery + + variables := map[string]any{ + "owner": githubv4.String(testOwner), + "name": githubv4.String(testRepo), + "username": githubv4.String(testUser), + } + + response := githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "isPrivate": false, + "collaborators": map[string]any{ + "edges": []any{ + map[string]any{ + "permission": "WRITE", + "node": map[string]any{ + "login": testUser, + }, + }, + }, + }, + }, + }) + + httpClient := githubv4mock.NewMockedHTTPClient(githubv4mock.NewQueryMatcher(query, variables, response)) + counting := &countingTransport{next: httpClient.Transport} + httpClient.Transport = counting + + gqlClient := githubv4.NewClient(httpClient) + + return NewRepoAccessCache(gqlClient, WithTTL(ttl)), counting +} + +func requireAccess(ctx context.Context, t *testing.T, cache *RepoAccessCache) { + t.Helper() + + isPrivate, hasPush, err := cache.GetRepoAccessInfo(ctx, testUser, testOwner, testRepo) + require.NoError(t, err) + require.False(t, isPrivate) + require.True(t, hasPush) +} + +func TestRepoAccessCacheEvictsAfterTTL(t *testing.T) { + t.Parallel() + + cache, transport := newMockRepoAccessCache(t, 5*time.Millisecond) + ctx := context.Background() + + requireAccess(ctx, t, cache) + requireAccess(ctx, t, cache) + require.EqualValues(t, 1, transport.CallCount()) + + time.Sleep(20 * time.Millisecond) + + requireAccess(ctx, t, cache) + require.EqualValues(t, 2, transport.CallCount()) +} + +func TestRepoAccessCacheTTLDisabled(t *testing.T) { + t.Parallel() + + cache, transport := newMockRepoAccessCache(t, 0) + ctx := context.Background() + + requireAccess(ctx, t, cache) + requireAccess(ctx, t, cache) + require.EqualValues(t, 1, transport.CallCount()) + + time.Sleep(20 * time.Millisecond) + + requireAccess(ctx, t, cache) + require.EqualValues(t, 1, transport.CallCount()) +} + +func TestRepoAccessCacheSetTTLReschedulesExistingEntry(t *testing.T) { + t.Parallel() + + cache, transport := newMockRepoAccessCache(t, 0) + ctx := context.Background() + + requireAccess(ctx, t, cache) + require.EqualValues(t, 1, transport.CallCount()) + + cache.SetTTL(5 * time.Millisecond) + + time.Sleep(20 * time.Millisecond) + + requireAccess(ctx, t, cache) + require.EqualValues(t, 2, transport.CallCount()) + + requireAccess(ctx, t, cache) + require.EqualValues(t, 2, transport.CallCount()) +} From e29a179d1ffad5673c82c4925799cb0afa2199a8 Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Mon, 17 Nov 2025 17:24:25 +0100 Subject: [PATCH 03/19] Unlock in defer --- pkg/lockdown/lockdown.go | 5 +---- pkg/lockdown/lockdown_test.go | 7 +++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index 1749f311d..ddecca16d 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -102,17 +102,16 @@ func (c *RepoAccessCache) GetRepoAccessInfo(ctx context.Context, username, owner key := cacheKey(owner, repo) userKey := strings.ToLower(username) c.mu.Lock() + defer c.mu.Unlock() entry := c.ensureEntry(key) if entry.ready { if cachedHasPush, known := entry.knownUsers[userKey]; known { entry.scheduleExpiry(c, key) c.logDebug("repo access cache hit", "owner", owner, "repo", repo, "user", username) cachedPrivate := entry.isPrivate - c.mu.Unlock() return cachedPrivate, cachedHasPush, nil } } - c.mu.Unlock() c.logDebug("repo access cache miss", "owner", owner, "repo", repo, "user", username) isPrivate, hasPush, err := c.queryRepoAccessInfo(ctx, username, owner, repo) @@ -120,13 +119,11 @@ func (c *RepoAccessCache) GetRepoAccessInfo(ctx context.Context, username, owner return false, false, err } - c.mu.Lock() entry = c.ensureEntry(key) entry.ready = true entry.isPrivate = isPrivate entry.knownUsers[userKey] = hasPush entry.scheduleExpiry(c, key) - c.mu.Unlock() return isPrivate, hasPush, nil } diff --git a/pkg/lockdown/lockdown_test.go b/pkg/lockdown/lockdown_test.go index 275924bf9..2ebfe80ae 100644 --- a/pkg/lockdown/lockdown_test.go +++ b/pkg/lockdown/lockdown_test.go @@ -98,10 +98,9 @@ func requireAccess(ctx context.Context, t *testing.T, cache *RepoAccessCache) { func TestRepoAccessCacheEvictsAfterTTL(t *testing.T) { t.Parallel() + ctx := t.Context() cache, transport := newMockRepoAccessCache(t, 5*time.Millisecond) - ctx := context.Background() - requireAccess(ctx, t, cache) requireAccess(ctx, t, cache) require.EqualValues(t, 1, transport.CallCount()) @@ -113,10 +112,10 @@ func TestRepoAccessCacheEvictsAfterTTL(t *testing.T) { } func TestRepoAccessCacheTTLDisabled(t *testing.T) { + ctx := t.Context() t.Parallel() cache, transport := newMockRepoAccessCache(t, 0) - ctx := context.Background() requireAccess(ctx, t, cache) requireAccess(ctx, t, cache) @@ -129,10 +128,10 @@ func TestRepoAccessCacheTTLDisabled(t *testing.T) { } func TestRepoAccessCacheSetTTLReschedulesExistingEntry(t *testing.T) { + ctx := t.Context() t.Parallel() cache, transport := newMockRepoAccessCache(t, 0) - ctx := context.Background() requireAccess(ctx, t, cache) require.EqualValues(t, 1, transport.CallCount()) From b456547aca8cdbf88e8794c02c1a783f6daefab4 Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Tue, 18 Nov 2025 10:03:50 +0100 Subject: [PATCH 04/19] Add muesli/cache2go --- go.mod | 1 + go.sum | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/go.mod b/go.mod index 02b9ad252..2ad42d9d3 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/invopop/jsonschema v0.13.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect + github.com/muesli/cache2go v0.0.0-20221011235721-518229cd8021 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect diff --git a/go.sum b/go.sum index 1ac8b7606..d2e323cf6 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,4 @@ +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= @@ -63,6 +64,8 @@ github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwX github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA= github.com/migueleliasweb/go-github-mock v1.3.0 h1:2sVP9JEMB2ubQw1IKto3/fzF51oFC6eVWOOFDgQoq88= github.com/migueleliasweb/go-github-mock v1.3.0/go.mod h1:ipQhV8fTcj/G6m7BKzin08GaJ/3B5/SonRAkgrk0zCY= +github.com/muesli/cache2go v0.0.0-20221011235721-518229cd8021 h1:31Y+Yu373ymebRdJN1cWLLooHH8xAr0MhKTEJGV/87g= +github.com/muesli/cache2go v0.0.0-20221011235721-518229cd8021/go.mod h1:WERUkUryfUWlrHnFSO/BEUZ+7Ns8aZy7iVOGewxKzcc= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= @@ -92,6 +95,7 @@ github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3A github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= @@ -106,18 +110,23 @@ github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 h1:BHyfKlQyqbsFN5p3Ifn github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= +golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98= golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From b27f1e26ff0262ea7bae1ff178a702ceab3c3bef Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 18 Nov 2025 10:19:09 +0100 Subject: [PATCH 05/19] [WIP] Replace custom cache in lockdown.go with cache2go struct (#1425) * Initial plan * Replace custom cache with cache2go library - Added github.com/muesli/cache2go dependency - Replaced custom map-based cache with cache2go.CacheTable - Removed manual timer management (scheduleExpiry, ensureEntry methods) - Removed timer field from repoAccessCacheEntry struct - Updated GetRepoAccessInfo to use cache2go's Value() and Add() methods - Updated SetTTL to flush and re-add entries with new TTL - Used unique cache names per instance to avoid test interference - All existing tests pass with the new implementation Co-authored-by: JoannaaKL <67866556+JoannaaKL@users.noreply.github.com> * Final verification complete Co-authored-by: JoannaaKL <67866556+JoannaaKL@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: JoannaaKL <67866556+JoannaaKL@users.noreply.github.com> --- go.mod | 3 +- go.sum | 2 + pkg/lockdown/lockdown.go | 107 ++++++++++++++++++--------------------- 3 files changed, 54 insertions(+), 58 deletions(-) diff --git a/go.mod b/go.mod index 02b9ad252..8d5b1b274 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/mark3labs/mcp-go v0.36.0 github.com/microcosm-cc/bluemonday v1.0.27 github.com/migueleliasweb/go-github-mock v1.3.0 + github.com/muesli/cache2go v0.0.0-20221011235721-518229cd8021 github.com/spf13/cobra v1.10.1 github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 @@ -37,7 +38,7 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 - github.com/google/go-querystring v1.1.0 + github.com/google/go-querystring v1.1.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect diff --git a/go.sum b/go.sum index 1ac8b7606..0ff7b51fa 100644 --- a/go.sum +++ b/go.sum @@ -63,6 +63,8 @@ github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwX github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA= github.com/migueleliasweb/go-github-mock v1.3.0 h1:2sVP9JEMB2ubQw1IKto3/fzF51oFC6eVWOOFDgQoq88= github.com/migueleliasweb/go-github-mock v1.3.0/go.mod h1:ipQhV8fTcj/G6m7BKzin08GaJ/3B5/SonRAkgrk0zCY= +github.com/muesli/cache2go v0.0.0-20221011235721-518229cd8021 h1:31Y+Yu373ymebRdJN1cWLLooHH8xAr0MhKTEJGV/87g= +github.com/muesli/cache2go v0.0.0-20221011235721-518229cd8021/go.mod h1:WERUkUryfUWlrHnFSO/BEUZ+7Ns8aZy7iVOGewxKzcc= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index ddecca16d..e41ba74b7 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -6,17 +6,21 @@ import ( "log/slog" "strings" "sync" + "sync/atomic" "time" + "github.com/muesli/cache2go" "github.com/shurcooL/githubv4" ) +var cacheNameCounter atomic.Uint64 + // RepoAccessCache caches repository metadata related to lockdown checks so that // multiple tools can reuse the same access information safely across goroutines. type RepoAccessCache struct { client *githubv4.Client mu sync.Mutex - cache map[string]*repoAccessCacheEntry + cache *cache2go.CacheTable ttl time.Duration logger *slog.Logger } @@ -25,7 +29,6 @@ type repoAccessCacheEntry struct { isPrivate bool knownUsers map[string]bool // normalized login -> has push access ready bool - timer *time.Timer } const defaultRepoAccessTTL = 5 * time.Minute @@ -51,9 +54,11 @@ func WithLogger(logger *slog.Logger) RepoAccessOption { // NewRepoAccessCache returns a cache bound to the provided GitHub GraphQL // client. The cache is safe for concurrent use. func NewRepoAccessCache(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessCache { + // Use a unique cache name for each instance to avoid sharing state between tests + cacheName := fmt.Sprintf("repoAccess-%d", cacheNameCounter.Add(1)) c := &RepoAccessCache{ client: client, - cache: make(map[string]*repoAccessCacheEntry), + cache: cache2go.Cache(cacheName), ttl: defaultRepoAccessTTL, } for _, opt := range opts { @@ -72,8 +77,19 @@ func (c *RepoAccessCache) SetTTL(ttl time.Duration) { defer c.mu.Unlock() c.ttl = ttl c.logInfo("repo access cache TTL updated", "ttl", ttl) - for key, entry := range c.cache { - entry.scheduleExpiry(c, key) + + // Collect all current entries + entries := make(map[interface{}]*repoAccessCacheEntry) + c.cache.Foreach(func(key interface{}, item *cache2go.CacheItem) { + entries[key] = item.Data().(*repoAccessCacheEntry) + }) + + // Flush the cache + c.cache.Flush() + + // Re-add all entries with the new TTL + for key, entry := range entries { + c.cache.Add(key, ttl, entry) } } @@ -103,69 +119,46 @@ func (c *RepoAccessCache) GetRepoAccessInfo(ctx context.Context, username, owner userKey := strings.ToLower(username) c.mu.Lock() defer c.mu.Unlock() - entry := c.ensureEntry(key) - if entry.ready { - if cachedHasPush, known := entry.knownUsers[userKey]; known { - entry.scheduleExpiry(c, key) - c.logDebug("repo access cache hit", "owner", owner, "repo", repo, "user", username) - cachedPrivate := entry.isPrivate - return cachedPrivate, cachedHasPush, nil + + // Try to get entry from cache - this will keep the item alive if it exists + cacheItem, err := c.cache.Value(key) + if err == nil { + entry := cacheItem.Data().(*repoAccessCacheEntry) + if entry.ready { + if cachedHasPush, known := entry.knownUsers[userKey]; known { + c.logDebug("repo access cache hit", "owner", owner, "repo", repo, "user", username) + return entry.isPrivate, cachedHasPush, nil + } } + // Entry exists but user not in knownUsers, need to query } c.logDebug("repo access cache miss", "owner", owner, "repo", repo, "user", username) - isPrivate, hasPush, err := c.queryRepoAccessInfo(ctx, username, owner, repo) - if err != nil { - return false, false, err + isPrivate, hasPush, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo) + if queryErr != nil { + return false, false, queryErr } - entry = c.ensureEntry(key) - entry.ready = true - entry.isPrivate = isPrivate - entry.knownUsers[userKey] = hasPush - entry.scheduleExpiry(c, key) - - return isPrivate, hasPush, nil -} - -func (c *RepoAccessCache) ensureEntry(key string) *repoAccessCacheEntry { - if c.cache == nil { - c.cache = make(map[string]*repoAccessCacheEntry) - } - entry, ok := c.cache[key] - if !ok { + // Get or create entry - don't use Value() here to avoid keeping alive unnecessarily + var entry *repoAccessCacheEntry + if err == nil && cacheItem != nil { + // Entry already existed, just update it + entry = cacheItem.Data().(*repoAccessCacheEntry) + } else { + // Create new entry entry = &repoAccessCacheEntry{ knownUsers: make(map[string]bool), } - c.cache[key] = entry } - return entry -} - -func (entry *repoAccessCacheEntry) scheduleExpiry(c *RepoAccessCache, key string) { - if entry.timer != nil { - entry.timer.Stop() - entry.timer = nil - } - - dur := c.ttl - if dur <= 0 { - return - } - - owner, repo := splitKey(key) - entry.timer = time.AfterFunc(dur, func() { - c.mu.Lock() - defer c.mu.Unlock() - - current, ok := c.cache[key] - if !ok || current != entry { - return - } + + entry.ready = true + entry.isPrivate = isPrivate + entry.knownUsers[userKey] = hasPush + + // Add or update the entry in cache with TTL + c.cache.Add(key, c.ttl, entry) - delete(c.cache, key) - c.logDebug("repo access cache entry evicted", "owner", owner, "repo", repo) - }) + return isPrivate, hasPush, nil } func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, owner, repo string) (bool, bool, error) { From 215b2db3b5a4bcada4443a688db249006ab691fe Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Tue, 18 Nov 2025 10:44:25 +0100 Subject: [PATCH 06/19] Use muesli for cache --- go.mod | 1 - pkg/github/issues_test.go | 8 +-- pkg/lockdown/lockdown.go | 54 +++++++------------ pkg/lockdown/lockdown_test.go | 7 ++- third-party-licenses.darwin.md | 1 + third-party-licenses.linux.md | 1 + third-party-licenses.windows.md | 1 + .../github.com/muesli/cache2go/LICENSE.txt | 28 ++++++++++ 8 files changed, 56 insertions(+), 45 deletions(-) create mode 100644 third-party/github.com/muesli/cache2go/LICENSE.txt diff --git a/go.mod b/go.mod index 7d94b02d3..8d5b1b274 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,6 @@ require ( github.com/invopop/jsonschema v0.13.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect - github.com/muesli/cache2go v0.0.0-20221011235721-518229cd8021 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index b28e8dd82..2fdff87a7 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -120,8 +120,8 @@ func Test_GetIssue(t *testing.T) { } `graphql:"repository(owner: $owner, name: $name)"` }{}, map[string]any{ - "owner": githubv4.String("owner"), - "name": githubv4.String("repo"), + "owner": githubv4.String("github"), + "name": githubv4.String("github-mcp-server"), "username": githubv4.String("testuser"), }, githubv4mock.DataResponse(map[string]any{ @@ -136,8 +136,8 @@ func Test_GetIssue(t *testing.T) { ), requestArgs: map[string]interface{}{ "method": "get", - "owner": "owner", - "repo": "repo", + "owner": "github", + "repo": "github-mcp-server", "issue_number": float64(42), }, expectedIssue: mockIssue, diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index e41ba74b7..8fb9be08d 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -6,15 +6,12 @@ import ( "log/slog" "strings" "sync" - "sync/atomic" "time" "github.com/muesli/cache2go" "github.com/shurcooL/githubv4" ) -var cacheNameCounter atomic.Uint64 - // RepoAccessCache caches repository metadata related to lockdown checks so that // multiple tools can reuse the same access information safely across goroutines. type RepoAccessCache struct { @@ -28,7 +25,6 @@ type RepoAccessCache struct { type repoAccessCacheEntry struct { isPrivate bool knownUsers map[string]bool // normalized login -> has push access - ready bool } const defaultRepoAccessTTL = 5 * time.Minute @@ -55,7 +51,7 @@ func WithLogger(logger *slog.Logger) RepoAccessOption { // client. The cache is safe for concurrent use. func NewRepoAccessCache(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessCache { // Use a unique cache name for each instance to avoid sharing state between tests - cacheName := fmt.Sprintf("repoAccess-%d", cacheNameCounter.Add(1)) + cacheName := "repo-access-cache" c := &RepoAccessCache{ client: client, cache: cache2go.Cache(cacheName), @@ -77,16 +73,16 @@ func (c *RepoAccessCache) SetTTL(ttl time.Duration) { defer c.mu.Unlock() c.ttl = ttl c.logInfo("repo access cache TTL updated", "ttl", ttl) - + // Collect all current entries entries := make(map[interface{}]*repoAccessCacheEntry) c.cache.Foreach(func(key interface{}, item *cache2go.CacheItem) { entries[key] = item.Data().(*repoAccessCacheEntry) }) - + // Flush the cache c.cache.Flush() - + // Re-add all entries with the new TTL for key, entry := range entries { c.cache.Add(key, ttl, entry) @@ -119,16 +115,14 @@ func (c *RepoAccessCache) GetRepoAccessInfo(ctx context.Context, username, owner userKey := strings.ToLower(username) c.mu.Lock() defer c.mu.Unlock() - + // Try to get entry from cache - this will keep the item alive if it exists cacheItem, err := c.cache.Value(key) if err == nil { entry := cacheItem.Data().(*repoAccessCacheEntry) - if entry.ready { - if cachedHasPush, known := entry.knownUsers[userKey]; known { - c.logDebug("repo access cache hit", "owner", owner, "repo", repo, "user", username) - return entry.isPrivate, cachedHasPush, nil - } + if cachedHasPush, known := entry.knownUsers[userKey]; known { + c.logDebug("repo access cache hit", "owner", owner, "repo", repo, "user", username) + return entry.isPrivate, cachedHasPush, nil } // Entry exists but user not in knownUsers, need to query } @@ -139,26 +133,22 @@ func (c *RepoAccessCache) GetRepoAccessInfo(ctx context.Context, username, owner return false, false, queryErr } - // Get or create entry - don't use Value() here to avoid keeping alive unnecessarily + // Repo access info retrieved, update or create cache entry var entry *repoAccessCacheEntry if err == nil && cacheItem != nil { - // Entry already existed, just update it entry = cacheItem.Data().(*repoAccessCacheEntry) - } else { - // Create new entry - entry = &repoAccessCacheEntry{ - knownUsers: make(map[string]bool), - } + entry.knownUsers[userKey] = hasPush + return entry.isPrivate, entry.knownUsers[userKey], nil + } + + // Create new entry + entry = &repoAccessCacheEntry{ + knownUsers: map[string]bool{userKey: hasPush}, + isPrivate: isPrivate, } - - entry.ready = true - entry.isPrivate = isPrivate - entry.knownUsers[userKey] = hasPush - - // Add or update the entry in cache with TTL c.cache.Add(key, c.ttl, entry) - return isPrivate, hasPush, nil + return entry.isPrivate, entry.knownUsers[userKey], nil } func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, owner, repo string) (bool, bool, error) { @@ -207,14 +197,6 @@ func cacheKey(owner, repo string) string { return fmt.Sprintf("%s/%s", strings.ToLower(owner), strings.ToLower(repo)) } -func splitKey(key string) (string, string) { - owner, rest, found := strings.Cut(key, "/") - if !found { - return key, "" - } - return owner, rest -} - func (c *RepoAccessCache) logDebug(msg string, args ...any) { if c != nil && c.logger != nil { c.logger.Debug(msg, args...) diff --git a/pkg/lockdown/lockdown_test.go b/pkg/lockdown/lockdown_test.go index 2ebfe80ae..8e4ac548c 100644 --- a/pkg/lockdown/lockdown_test.go +++ b/pkg/lockdown/lockdown_test.go @@ -115,14 +115,13 @@ func TestRepoAccessCacheTTLDisabled(t *testing.T) { ctx := t.Context() t.Parallel() - cache, transport := newMockRepoAccessCache(t, 0) + // make sure cache TTL is sufficiently large to avoid evictions during the test + cache, transport := newMockRepoAccessCache(t, 1000*time.Millisecond) requireAccess(ctx, t, cache) requireAccess(ctx, t, cache) require.EqualValues(t, 1, transport.CallCount()) - time.Sleep(20 * time.Millisecond) - requireAccess(ctx, t, cache) require.EqualValues(t, 1, transport.CallCount()) } @@ -131,7 +130,7 @@ func TestRepoAccessCacheSetTTLReschedulesExistingEntry(t *testing.T) { ctx := t.Context() t.Parallel() - cache, transport := newMockRepoAccessCache(t, 0) + cache, transport := newMockRepoAccessCache(t, 10*time.Millisecond) requireAccess(ctx, t, cache) require.EqualValues(t, 1, transport.CallCount()) diff --git a/third-party-licenses.darwin.md b/third-party-licenses.darwin.md index eecc6faa8..68a45fa7a 100644 --- a/third-party-licenses.darwin.md +++ b/third-party-licenses.darwin.md @@ -28,6 +28,7 @@ Some packages may only be included on certain architectures or operating systems - [github.com/mark3labs/mcp-go](https://pkg.go.dev/github.com/mark3labs/mcp-go) ([MIT](https://github.com/mark3labs/mcp-go/blob/v0.36.0/LICENSE)) - [github.com/microcosm-cc/bluemonday](https://pkg.go.dev/github.com/microcosm-cc/bluemonday) ([BSD-3-Clause](https://github.com/microcosm-cc/bluemonday/blob/v1.0.27/LICENSE.md)) - [github.com/migueleliasweb/go-github-mock/src/mock](https://pkg.go.dev/github.com/migueleliasweb/go-github-mock/src/mock) ([MIT](https://github.com/migueleliasweb/go-github-mock/blob/v1.3.0/LICENSE)) + - [github.com/muesli/cache2go](https://pkg.go.dev/github.com/muesli/cache2go) ([BSD-3-Clause](https://github.com/muesli/cache2go/blob/518229cd8021/LICENSE.txt)) - [github.com/pelletier/go-toml/v2](https://pkg.go.dev/github.com/pelletier/go-toml/v2) ([MIT](https://github.com/pelletier/go-toml/blob/v2.2.4/LICENSE)) - [github.com/sagikazarmark/locafero](https://pkg.go.dev/github.com/sagikazarmark/locafero) ([MIT](https://github.com/sagikazarmark/locafero/blob/v0.11.0/LICENSE)) - [github.com/shurcooL/githubv4](https://pkg.go.dev/github.com/shurcooL/githubv4) ([MIT](https://github.com/shurcooL/githubv4/blob/48295856cce7/LICENSE)) diff --git a/third-party-licenses.linux.md b/third-party-licenses.linux.md index eecc6faa8..68a45fa7a 100644 --- a/third-party-licenses.linux.md +++ b/third-party-licenses.linux.md @@ -28,6 +28,7 @@ Some packages may only be included on certain architectures or operating systems - [github.com/mark3labs/mcp-go](https://pkg.go.dev/github.com/mark3labs/mcp-go) ([MIT](https://github.com/mark3labs/mcp-go/blob/v0.36.0/LICENSE)) - [github.com/microcosm-cc/bluemonday](https://pkg.go.dev/github.com/microcosm-cc/bluemonday) ([BSD-3-Clause](https://github.com/microcosm-cc/bluemonday/blob/v1.0.27/LICENSE.md)) - [github.com/migueleliasweb/go-github-mock/src/mock](https://pkg.go.dev/github.com/migueleliasweb/go-github-mock/src/mock) ([MIT](https://github.com/migueleliasweb/go-github-mock/blob/v1.3.0/LICENSE)) + - [github.com/muesli/cache2go](https://pkg.go.dev/github.com/muesli/cache2go) ([BSD-3-Clause](https://github.com/muesli/cache2go/blob/518229cd8021/LICENSE.txt)) - [github.com/pelletier/go-toml/v2](https://pkg.go.dev/github.com/pelletier/go-toml/v2) ([MIT](https://github.com/pelletier/go-toml/blob/v2.2.4/LICENSE)) - [github.com/sagikazarmark/locafero](https://pkg.go.dev/github.com/sagikazarmark/locafero) ([MIT](https://github.com/sagikazarmark/locafero/blob/v0.11.0/LICENSE)) - [github.com/shurcooL/githubv4](https://pkg.go.dev/github.com/shurcooL/githubv4) ([MIT](https://github.com/shurcooL/githubv4/blob/48295856cce7/LICENSE)) diff --git a/third-party-licenses.windows.md b/third-party-licenses.windows.md index 75fe8172a..2d8ef9111 100644 --- a/third-party-licenses.windows.md +++ b/third-party-licenses.windows.md @@ -29,6 +29,7 @@ Some packages may only be included on certain architectures or operating systems - [github.com/mark3labs/mcp-go](https://pkg.go.dev/github.com/mark3labs/mcp-go) ([MIT](https://github.com/mark3labs/mcp-go/blob/v0.36.0/LICENSE)) - [github.com/microcosm-cc/bluemonday](https://pkg.go.dev/github.com/microcosm-cc/bluemonday) ([BSD-3-Clause](https://github.com/microcosm-cc/bluemonday/blob/v1.0.27/LICENSE.md)) - [github.com/migueleliasweb/go-github-mock/src/mock](https://pkg.go.dev/github.com/migueleliasweb/go-github-mock/src/mock) ([MIT](https://github.com/migueleliasweb/go-github-mock/blob/v1.3.0/LICENSE)) + - [github.com/muesli/cache2go](https://pkg.go.dev/github.com/muesli/cache2go) ([BSD-3-Clause](https://github.com/muesli/cache2go/blob/518229cd8021/LICENSE.txt)) - [github.com/pelletier/go-toml/v2](https://pkg.go.dev/github.com/pelletier/go-toml/v2) ([MIT](https://github.com/pelletier/go-toml/blob/v2.2.4/LICENSE)) - [github.com/sagikazarmark/locafero](https://pkg.go.dev/github.com/sagikazarmark/locafero) ([MIT](https://github.com/sagikazarmark/locafero/blob/v0.11.0/LICENSE)) - [github.com/shurcooL/githubv4](https://pkg.go.dev/github.com/shurcooL/githubv4) ([MIT](https://github.com/shurcooL/githubv4/blob/48295856cce7/LICENSE)) diff --git a/third-party/github.com/muesli/cache2go/LICENSE.txt b/third-party/github.com/muesli/cache2go/LICENSE.txt new file mode 100644 index 000000000..3dbf3d932 --- /dev/null +++ b/third-party/github.com/muesli/cache2go/LICENSE.txt @@ -0,0 +1,28 @@ +Copyright (c) 2012, Radu Ioan Fericean + 2013-2017, Christian Muehlhaeuser +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +Redistributions in binary form must reproduce the above copyright notice, this +list of conditions and the following disclaimer in the documentation and/or +other materials provided with the distribution. + +Neither the name of Radu Ioan Fericean nor the names of its contributors may be +used to endorse or promote products derived from this software without specific +prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. From 5bba60add62d8e0b964590ae33ba24a6d0be315e Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 18 Nov 2025 11:15:46 +0100 Subject: [PATCH 07/19] Make RepoAccessCache a singleton (#1426) * Initial plan * Implement RepoAccessCache as a singleton pattern Co-authored-by: JoannaaKL <67866556+JoannaaKL@users.noreply.github.com> * Complete singleton implementation and verification Co-authored-by: JoannaaKL <67866556+JoannaaKL@users.noreply.github.com> * Remove cacheIDCounter as requested Co-authored-by: JoannaaKL <67866556+JoannaaKL@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: JoannaaKL <67866556+JoannaaKL@users.noreply.github.com> --- go.sum | 7 ----- internal/ghmcp/server.go | 2 +- pkg/lockdown/lockdown.go | 45 ++++++++++++++++++++++++++-- pkg/lockdown/lockdown_test.go | 55 +++++++++++++++++++++++++++++++++++ 4 files changed, 98 insertions(+), 11 deletions(-) diff --git a/go.sum b/go.sum index d2e323cf6..0ff7b51fa 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,3 @@ -cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= @@ -95,7 +94,6 @@ github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3A github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= @@ -110,23 +108,18 @@ github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 h1:BHyfKlQyqbsFN5p3Ifn github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= -golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98= golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index f82fa0553..15b1efc10 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -90,7 +90,7 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { } var repoAccessCache *lockdown.RepoAccessCache if cfg.LockdownMode { - repoAccessCache = lockdown.NewRepoAccessCache(gqlClient, repoAccessOpts...) + repoAccessCache = lockdown.GetInstance(gqlClient, repoAccessOpts...) } // When a client send an initialize request, update the user agent to include the client info. diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index 8fb9be08d..368290811 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -29,6 +29,12 @@ type repoAccessCacheEntry struct { const defaultRepoAccessTTL = 5 * time.Minute +var ( + instance *RepoAccessCache + instanceOnce sync.Once + instanceMu sync.RWMutex +) + // RepoAccessOption configures RepoAccessCache at construction time. type RepoAccessOption func(*RepoAccessCache) @@ -47,10 +53,43 @@ func WithLogger(logger *slog.Logger) RepoAccessOption { } } -// NewRepoAccessCache returns a cache bound to the provided GitHub GraphQL -// client. The cache is safe for concurrent use. +// GetInstance returns the singleton instance of RepoAccessCache. +// It initializes the instance on first call with the provided client and options. +// Subsequent calls ignore the client and options parameters and return the existing instance. +// This is the preferred way to access the cache in production code. +func GetInstance(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessCache { + instanceOnce.Do(func() { + instance = newRepoAccessCache(client, opts...) + }) + return instance +} + +// ResetInstance clears the singleton instance. This is primarily for testing purposes. +// It flushes the cache and allows re-initialization with different parameters. +// Note: This should not be called while the instance is in use. +func ResetInstance() { + instanceMu.Lock() + defer instanceMu.Unlock() + if instance != nil { + instance.cache.Flush() + } + instance = nil + instanceOnce = sync.Once{} +} + +// NewRepoAccessCache returns a cache bound to the provided GitHub GraphQL client. +// The cache is safe for concurrent use. +// +// For production code, consider using GetInstance() to ensure singleton behavior and +// consistent configuration across the application. NewRepoAccessCache is appropriate +// for testing scenarios where independent cache instances are needed. func NewRepoAccessCache(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessCache { - // Use a unique cache name for each instance to avoid sharing state between tests + return newRepoAccessCache(client, opts...) +} + +// newRepoAccessCache creates a new cache instance. This is a private helper function +// used by GetInstance. +func newRepoAccessCache(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessCache { cacheName := "repo-access-cache" c := &RepoAccessCache{ client: client, diff --git a/pkg/lockdown/lockdown_test.go b/pkg/lockdown/lockdown_test.go index 8e4ac548c..145d3d629 100644 --- a/pkg/lockdown/lockdown_test.go +++ b/pkg/lockdown/lockdown_test.go @@ -145,3 +145,58 @@ func TestRepoAccessCacheSetTTLReschedulesExistingEntry(t *testing.T) { requireAccess(ctx, t, cache) require.EqualValues(t, 2, transport.CallCount()) } + +func TestGetInstanceReturnsSingleton(t *testing.T) { + // Reset any existing singleton + ResetInstance() + defer ResetInstance() // Clean up after test + + gqlClient := githubv4.NewClient(nil) + + // Get instance twice, should return the same instance + instance1 := GetInstance(gqlClient) + instance2 := GetInstance(gqlClient) + + // Verify they're the same instance (same pointer) + require.Same(t, instance1, instance2, "GetInstance should return the same singleton instance") + + // Verify subsequent calls with different options are ignored + instance3 := GetInstance(gqlClient, WithTTL(1*time.Second)) + require.Same(t, instance1, instance3, "GetInstance should ignore options on subsequent calls") + require.Equal(t, defaultRepoAccessTTL, instance3.ttl, "TTL should remain unchanged after first initialization") +} + +func TestResetInstanceClearsSingleton(t *testing.T) { + // Reset any existing singleton + ResetInstance() + defer ResetInstance() // Clean up after test + + gqlClient := githubv4.NewClient(nil) + + // Get first instance with default TTL + instance1 := GetInstance(gqlClient) + require.Equal(t, defaultRepoAccessTTL, instance1.ttl) + + // Reset the singleton + ResetInstance() + + // Get new instance with custom TTL + customTTL := 10 * time.Second + instance2 := GetInstance(gqlClient, WithTTL(customTTL)) + require.NotSame(t, instance1, instance2, "After reset, GetInstance should return a new instance") + require.Equal(t, customTTL, instance2.ttl, "New instance should have the custom TTL") +} + +func TestNewRepoAccessCacheCreatesIndependentInstances(t *testing.T) { + t.Parallel() + + gqlClient := githubv4.NewClient(nil) + + // NewRepoAccessCache should create independent instances + cache1 := NewRepoAccessCache(gqlClient, WithTTL(1*time.Second)) + cache2 := NewRepoAccessCache(gqlClient, WithTTL(2*time.Second)) + + require.NotSame(t, cache1, cache2, "NewRepoAccessCache should create different instances") + require.Equal(t, 1*time.Second, cache1.ttl) + require.Equal(t, 2*time.Second, cache2.ttl) +} From 2d630e5d5763f8ff044ae65aded08da63ed6c059 Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Tue, 18 Nov 2025 11:28:25 +0100 Subject: [PATCH 08/19] Update mutexes --- cmd/github-mcp-server/generate_docs.go | 4 +- pkg/github/server_test.go | 2 +- pkg/lockdown/lockdown.go | 24 ------ pkg/lockdown/lockdown_test.go | 109 ++----------------------- 4 files changed, 8 insertions(+), 131 deletions(-) diff --git a/cmd/github-mcp-server/generate_docs.go b/cmd/github-mcp-server/generate_docs.go index ee41b8493..2fa81d45a 100644 --- a/cmd/github-mcp-server/generate_docs.go +++ b/cmd/github-mcp-server/generate_docs.go @@ -65,7 +65,7 @@ func generateReadmeDocs(readmePath string) error { t, _ := translations.TranslationHelper() // Create toolset group with mock clients - repoAccessCache := lockdown.NewRepoAccessCache(nil) + repoAccessCache := lockdown.GetInstance(nil) tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{}, repoAccessCache) // Generate toolsets documentation @@ -304,7 +304,7 @@ func generateRemoteToolsetsDoc() string { t, _ := translations.TranslationHelper() // Create toolset group with mock clients - repoAccessCache := lockdown.NewRepoAccessCache(nil) + repoAccessCache := lockdown.GetInstance(nil) tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{}, repoAccessCache) // Generate table header diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 20e20b2bb..446dd2179 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -40,7 +40,7 @@ func stubGetGQLClientFn(client *githubv4.Client) GetGQLClientFn { } func stubRepoAccessCache(client *githubv4.Client) *lockdown.RepoAccessCache { - return lockdown.NewRepoAccessCache(client) + return lockdown.GetInstance(client) } func stubFeatureFlags(enabledFlags map[string]bool) FeatureFlags { diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index 368290811..7644bb7a1 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -32,7 +32,6 @@ const defaultRepoAccessTTL = 5 * time.Minute var ( instance *RepoAccessCache instanceOnce sync.Once - instanceMu sync.RWMutex ) // RepoAccessOption configures RepoAccessCache at construction time. @@ -64,29 +63,6 @@ func GetInstance(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessC return instance } -// ResetInstance clears the singleton instance. This is primarily for testing purposes. -// It flushes the cache and allows re-initialization with different parameters. -// Note: This should not be called while the instance is in use. -func ResetInstance() { - instanceMu.Lock() - defer instanceMu.Unlock() - if instance != nil { - instance.cache.Flush() - } - instance = nil - instanceOnce = sync.Once{} -} - -// NewRepoAccessCache returns a cache bound to the provided GitHub GraphQL client. -// The cache is safe for concurrent use. -// -// For production code, consider using GetInstance() to ensure singleton behavior and -// consistent configuration across the application. NewRepoAccessCache is appropriate -// for testing scenarios where independent cache instances are needed. -func NewRepoAccessCache(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessCache { - return newRepoAccessCache(client, opts...) -} - // newRepoAccessCache creates a new cache instance. This is a private helper function // used by GetInstance. func newRepoAccessCache(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessCache { diff --git a/pkg/lockdown/lockdown_test.go b/pkg/lockdown/lockdown_test.go index 145d3d629..f53c66afa 100644 --- a/pkg/lockdown/lockdown_test.go +++ b/pkg/lockdown/lockdown_test.go @@ -1,7 +1,6 @@ package lockdown import ( - "context" "net/http" "sync" "testing" @@ -84,16 +83,7 @@ func newMockRepoAccessCache(t *testing.T, ttl time.Duration) (*RepoAccessCache, gqlClient := githubv4.NewClient(httpClient) - return NewRepoAccessCache(gqlClient, WithTTL(ttl)), counting -} - -func requireAccess(ctx context.Context, t *testing.T, cache *RepoAccessCache) { - t.Helper() - - isPrivate, hasPush, err := cache.GetRepoAccessInfo(ctx, testUser, testOwner, testRepo) - require.NoError(t, err) - require.False(t, isPrivate) - require.True(t, hasPush) + return GetInstance(gqlClient, WithTTL(ttl)), counting } func TestRepoAccessCacheEvictsAfterTTL(t *testing.T) { @@ -101,102 +91,13 @@ func TestRepoAccessCacheEvictsAfterTTL(t *testing.T) { ctx := t.Context() cache, transport := newMockRepoAccessCache(t, 5*time.Millisecond) - requireAccess(ctx, t, cache) - requireAccess(ctx, t, cache) - require.EqualValues(t, 1, transport.CallCount()) - - time.Sleep(20 * time.Millisecond) - - requireAccess(ctx, t, cache) - require.EqualValues(t, 2, transport.CallCount()) -} - -func TestRepoAccessCacheTTLDisabled(t *testing.T) { - ctx := t.Context() - t.Parallel() - - // make sure cache TTL is sufficiently large to avoid evictions during the test - cache, transport := newMockRepoAccessCache(t, 1000*time.Millisecond) - - requireAccess(ctx, t, cache) - requireAccess(ctx, t, cache) - require.EqualValues(t, 1, transport.CallCount()) - - requireAccess(ctx, t, cache) - require.EqualValues(t, 1, transport.CallCount()) -} - -func TestRepoAccessCacheSetTTLReschedulesExistingEntry(t *testing.T) { - ctx := t.Context() - t.Parallel() - - cache, transport := newMockRepoAccessCache(t, 10*time.Millisecond) - - requireAccess(ctx, t, cache) + _, _, err := cache.GetRepoAccessInfo(ctx, testUser, testOwner, testRepo) + require.NoError(t, err) require.EqualValues(t, 1, transport.CallCount()) - cache.SetTTL(5 * time.Millisecond) - time.Sleep(20 * time.Millisecond) - requireAccess(ctx, t, cache) - require.EqualValues(t, 2, transport.CallCount()) - - requireAccess(ctx, t, cache) + _, _, err = cache.GetRepoAccessInfo(ctx, testUser, testOwner, testRepo) + require.NoError(t, err) require.EqualValues(t, 2, transport.CallCount()) } - -func TestGetInstanceReturnsSingleton(t *testing.T) { - // Reset any existing singleton - ResetInstance() - defer ResetInstance() // Clean up after test - - gqlClient := githubv4.NewClient(nil) - - // Get instance twice, should return the same instance - instance1 := GetInstance(gqlClient) - instance2 := GetInstance(gqlClient) - - // Verify they're the same instance (same pointer) - require.Same(t, instance1, instance2, "GetInstance should return the same singleton instance") - - // Verify subsequent calls with different options are ignored - instance3 := GetInstance(gqlClient, WithTTL(1*time.Second)) - require.Same(t, instance1, instance3, "GetInstance should ignore options on subsequent calls") - require.Equal(t, defaultRepoAccessTTL, instance3.ttl, "TTL should remain unchanged after first initialization") -} - -func TestResetInstanceClearsSingleton(t *testing.T) { - // Reset any existing singleton - ResetInstance() - defer ResetInstance() // Clean up after test - - gqlClient := githubv4.NewClient(nil) - - // Get first instance with default TTL - instance1 := GetInstance(gqlClient) - require.Equal(t, defaultRepoAccessTTL, instance1.ttl) - - // Reset the singleton - ResetInstance() - - // Get new instance with custom TTL - customTTL := 10 * time.Second - instance2 := GetInstance(gqlClient, WithTTL(customTTL)) - require.NotSame(t, instance1, instance2, "After reset, GetInstance should return a new instance") - require.Equal(t, customTTL, instance2.ttl, "New instance should have the custom TTL") -} - -func TestNewRepoAccessCacheCreatesIndependentInstances(t *testing.T) { - t.Parallel() - - gqlClient := githubv4.NewClient(nil) - - // NewRepoAccessCache should create independent instances - cache1 := NewRepoAccessCache(gqlClient, WithTTL(1*time.Second)) - cache2 := NewRepoAccessCache(gqlClient, WithTTL(2*time.Second)) - - require.NotSame(t, cache1, cache2, "NewRepoAccessCache should create different instances") - require.Equal(t, 1*time.Second, cache1.ttl) - require.Equal(t, 2*time.Second, cache2.ttl) -} From 5da1d0a1c668824fc042d0327b1958d5f6be2566 Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Tue, 18 Nov 2025 12:06:45 +0100 Subject: [PATCH 09/19] . --- pkg/github/issues_test.go | 16 ++++++++-------- pkg/github/pullrequests_test.go | 24 ++++++++++++------------ pkg/github/server_test.go | 5 +++-- 3 files changed, 23 insertions(+), 22 deletions(-) diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index 2fdff87a7..436a34207 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -23,7 +23,7 @@ func Test_GetIssue(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) defaultGQLClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(defaultGQLClient), stubRepoAccessCache(defaultGQLClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(defaultGQLClient), stubRepoAccessCache(defaultGQLClient, 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -212,7 +212,7 @@ func Test_GetIssue(t *testing.T) { } flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, flags) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, flags) request := createMCPRequest(tc.requestArgs) result, err := handler(context.Background(), request) @@ -1710,7 +1710,7 @@ func Test_GetIssueComments(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) gqlClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -1816,7 +1816,7 @@ func Test_GetIssueComments(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1853,7 +1853,7 @@ func Test_GetIssueLabels(t *testing.T) { // Verify tool definition mockGQClient := githubv4.NewClient(nil) mockClient := github.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQClient), stubRepoAccessCache(mockGQClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQClient), stubRepoAccessCache(mockGQClient, 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -1928,7 +1928,7 @@ func Test_GetIssueLabels(t *testing.T) { t.Run(tc.name, func(t *testing.T) { gqlClient := githubv4.NewClient(tc.mockedClient) client := github.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) request := createMCPRequest(tc.requestArgs) result, err := handler(context.Background(), request) @@ -2619,7 +2619,7 @@ func Test_GetSubIssues(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) gqlClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -2816,7 +2816,7 @@ func Test_GetSubIssues(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 22dbec5d0..89a06aaec 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -21,7 +21,7 @@ import ( func Test_GetPullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -102,7 +102,7 @@ func Test_GetPullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1133,7 +1133,7 @@ func Test_SearchPullRequests(t *testing.T) { func Test_GetPullRequestFiles(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1236,7 +1236,7 @@ func Test_GetPullRequestFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1277,7 +1277,7 @@ func Test_GetPullRequestFiles(t *testing.T) { func Test_GetPullRequestStatus(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1404,7 +1404,7 @@ func Test_GetPullRequestStatus(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1566,7 +1566,7 @@ func Test_UpdatePullRequestBranch(t *testing.T) { func Test_GetPullRequestComments(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1658,7 +1658,7 @@ func Test_GetPullRequestComments(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1700,7 +1700,7 @@ func Test_GetPullRequestComments(t *testing.T) { func Test_GetPullRequestReviews(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1788,7 +1788,7 @@ func Test_GetPullRequestReviews(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -2789,7 +2789,7 @@ func TestGetPullRequestDiff(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -2847,7 +2847,7 @@ index 5d6e7b2..8a4f5c3 100644 // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil)), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 446dd2179..2a4ae48f4 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "testing" + "time" "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/raw" @@ -39,8 +40,8 @@ func stubGetGQLClientFn(client *githubv4.Client) GetGQLClientFn { } } -func stubRepoAccessCache(client *githubv4.Client) *lockdown.RepoAccessCache { - return lockdown.GetInstance(client) +func stubRepoAccessCache(client *githubv4.Client, ttl time.Duration) *lockdown.RepoAccessCache { + return lockdown.GetInstance(client, lockdown.WithTTL(ttl)) } func stubFeatureFlags(enabledFlags map[string]bool) FeatureFlags { From d9e0e0c9cb7926b6007ff9299c092a14860df786 Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Tue, 18 Nov 2025 13:50:38 +0100 Subject: [PATCH 10/19] Reuse cache --- pkg/github/issues_test.go | 137 +++++++++++++++++--------------------- 1 file changed, 60 insertions(+), 77 deletions(-) diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index 436a34207..208e7874a 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -22,8 +22,64 @@ import ( func Test_GetIssue(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - defaultGQLClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(defaultGQLClient), stubRepoAccessCache(defaultGQLClient, 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + type repoAccessQuery struct { + Repository struct { + IsPrivate githubv4.Boolean + Collaborators struct { + Edges []struct { + Permission githubv4.String + Node struct { + Login githubv4.String + } + } + } `graphql:"collaborators(query: $username, first: 1)"` + } `graphql:"repository(owner: $owner, name: $name)"` + } + + lockdownHTTPClient := githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + repoAccessQuery{}, + map[string]any{ + "owner": githubv4.String("github"), + "name": githubv4.String("github-mcp-server"), + "username": githubv4.String("testuser"), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "isPrivate": true, + "collaborators": map[string]any{ + "edges": []any{}, + }, + }, + }), + ), + githubv4mock.NewQueryMatcher( + repoAccessQuery{}, + map[string]any{ + "owner": githubv4.String("owner"), + "name": githubv4.String("repo"), + "username": githubv4.String("testuser"), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "isPrivate": false, + "collaborators": map[string]any{ + "edges": []any{ + map[string]any{ + "permission": "READ", + "node": map[string]any{ + "login": "testuser", + }, + }, + }, + }, + }, + }), + ), + ) + defaultGQLClient := githubv4.NewClient(lockdownHTTPClient) + repoAccessCache := stubRepoAccessCache(defaultGQLClient, 15*time.Minute) + tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(defaultGQLClient), repoAccessCache, translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -55,7 +111,6 @@ func Test_GetIssue(t *testing.T) { tests := []struct { name string mockedClient *http.Client - gqlHTTPClient *http.Client requestArgs map[string]interface{} expectHandlerError bool expectResultError bool @@ -104,36 +159,6 @@ func Test_GetIssue(t *testing.T) { mockIssue, ), ), - gqlHTTPClient: githubv4mock.NewMockedHTTPClient( - githubv4mock.NewQueryMatcher( - struct { - Repository struct { - IsPrivate githubv4.Boolean - Collaborators struct { - Edges []struct { - Permission githubv4.String - Node struct { - Login githubv4.String - } - } - } `graphql:"collaborators(query: $username, first: 1)"` - } `graphql:"repository(owner: $owner, name: $name)"` - }{}, - map[string]any{ - "owner": githubv4.String("github"), - "name": githubv4.String("github-mcp-server"), - "username": githubv4.String("testuser"), - }, - githubv4mock.DataResponse(map[string]any{ - "repository": map[string]any{ - "isPrivate": true, - "collaborators": map[string]any{ - "edges": []any{}, - }, - }, - }), - ), - ), requestArgs: map[string]interface{}{ "method": "get", "owner": "github", @@ -151,43 +176,6 @@ func Test_GetIssue(t *testing.T) { mockIssue, ), ), - gqlHTTPClient: githubv4mock.NewMockedHTTPClient( - githubv4mock.NewQueryMatcher( - struct { - Repository struct { - IsPrivate githubv4.Boolean - Collaborators struct { - Edges []struct { - Permission githubv4.String - Node struct { - Login githubv4.String - } - } - } `graphql:"collaborators(query: $username, first: 1)"` - } `graphql:"repository(owner: $owner, name: $name)"` - }{}, - map[string]any{ - "owner": githubv4.String("owner"), - "name": githubv4.String("repo"), - "username": githubv4.String("testuser"), - }, - githubv4mock.DataResponse(map[string]any{ - "repository": map[string]any{ - "isPrivate": false, - "collaborators": map[string]any{ - "edges": []any{ - map[string]any{ - "permission": "READ", - "node": map[string]any{ - "login": "testuser", - }, - }, - }, - }, - }, - }), - ), - ), requestArgs: map[string]interface{}{ "method": "get", "owner": "owner", @@ -204,15 +192,10 @@ func Test_GetIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - var gqlClient *githubv4.Client - if tc.gqlHTTPClient != nil { - gqlClient = githubv4.NewClient(tc.gqlHTTPClient) - } else { - gqlClient = defaultGQLClient - } + gqlClient := defaultGQLClient flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, flags) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), repoAccessCache, translations.NullTranslationHelper, flags) request := createMCPRequest(tc.requestArgs) result, err := handler(context.Background(), request) From c46bd2efba123c49d2c1250139bbcdb9811c56bc Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Tue, 18 Nov 2025 15:41:19 +0100 Subject: [PATCH 11/19] . --- pkg/github/issues_test.go | 157 ++++++++++++++++++-------------- pkg/github/pullrequests_test.go | 10 +- pkg/github/server_test.go | 3 +- pkg/lockdown/lockdown.go | 37 ++++++-- 4 files changed, 127 insertions(+), 80 deletions(-) diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index 208e7874a..2f4b584a5 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -11,6 +11,7 @@ import ( "github.com/github/github-mcp-server/internal/githubv4mock" "github.com/github/github-mcp-server/internal/toolsnaps" + "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/migueleliasweb/go-github-mock/src/mock" @@ -19,66 +20,13 @@ import ( "github.com/stretchr/testify/require" ) +var defaultGQLClient *githubv4.Client = githubv4.NewClient(githubv4mock.NewMockedHTTPClient()) +var repoAccessCache *lockdown.RepoAccessCache = stubRepoAccessCache(defaultGQLClient, 15*time.Minute) + func Test_GetIssue(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - type repoAccessQuery struct { - Repository struct { - IsPrivate githubv4.Boolean - Collaborators struct { - Edges []struct { - Permission githubv4.String - Node struct { - Login githubv4.String - } - } - } `graphql:"collaborators(query: $username, first: 1)"` - } `graphql:"repository(owner: $owner, name: $name)"` - } - - lockdownHTTPClient := githubv4mock.NewMockedHTTPClient( - githubv4mock.NewQueryMatcher( - repoAccessQuery{}, - map[string]any{ - "owner": githubv4.String("github"), - "name": githubv4.String("github-mcp-server"), - "username": githubv4.String("testuser"), - }, - githubv4mock.DataResponse(map[string]any{ - "repository": map[string]any{ - "isPrivate": true, - "collaborators": map[string]any{ - "edges": []any{}, - }, - }, - }), - ), - githubv4mock.NewQueryMatcher( - repoAccessQuery{}, - map[string]any{ - "owner": githubv4.String("owner"), - "name": githubv4.String("repo"), - "username": githubv4.String("testuser"), - }, - githubv4mock.DataResponse(map[string]any{ - "repository": map[string]any{ - "isPrivate": false, - "collaborators": map[string]any{ - "edges": []any{ - map[string]any{ - "permission": "READ", - "node": map[string]any{ - "login": "testuser", - }, - }, - }, - }, - }, - }), - ), - ) - defaultGQLClient := githubv4.NewClient(lockdownHTTPClient) - repoAccessCache := stubRepoAccessCache(defaultGQLClient, 15*time.Minute) + defaultGQLClient := githubv4.NewClient(nil) tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(defaultGQLClient), repoAccessCache, translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -111,6 +59,7 @@ func Test_GetIssue(t *testing.T) { tests := []struct { name string mockedClient *http.Client + gqlHTTPClient *http.Client requestArgs map[string]interface{} expectHandlerError bool expectResultError bool @@ -159,10 +108,40 @@ func Test_GetIssue(t *testing.T) { mockIssue, ), ), + gqlHTTPClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + IsPrivate githubv4.Boolean + Collaborators struct { + Edges []struct { + Permission githubv4.String + Node struct { + Login githubv4.String + } + } + } `graphql:"collaborators(query: $username, first: 1)"` + } `graphql:"repository(owner: $owner, name: $name)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "name": githubv4.String("repo"), + "username": githubv4.String("testuser"), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "isPrivate": true, + "collaborators": map[string]any{ + "edges": []any{}, + }, + }, + }), + ), + ), requestArgs: map[string]interface{}{ "method": "get", - "owner": "github", - "repo": "github-mcp-server", + "owner": "owner", + "repo": "repo", "issue_number": float64(42), }, expectedIssue: mockIssue, @@ -176,6 +155,43 @@ func Test_GetIssue(t *testing.T) { mockIssue, ), ), + gqlHTTPClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + IsPrivate githubv4.Boolean + Collaborators struct { + Edges []struct { + Permission githubv4.String + Node struct { + Login githubv4.String + } + } + } `graphql:"collaborators(query: $username, first: 1)"` + } `graphql:"repository(owner: $owner, name: $name)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "name": githubv4.String("repo"), + "username": githubv4.String("testuser"), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "isPrivate": false, + "collaborators": map[string]any{ + "edges": []any{ + map[string]any{ + "permission": "READ", + "node": map[string]any{ + "login": "testuser", + }, + }, + }, + }, + }, + }), + ), + ), requestArgs: map[string]interface{}{ "method": "get", "owner": "owner", @@ -192,10 +208,17 @@ func Test_GetIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - gqlClient := defaultGQLClient + var gqlClient *githubv4.Client + cache := repoAccessCache + if tc.gqlHTTPClient != nil { + gqlClient = githubv4.NewClient(tc.gqlHTTPClient) + cache = stubRepoAccessCache(gqlClient, 15*time.Minute) + } else { + gqlClient = defaultGQLClient + } flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), repoAccessCache, translations.NullTranslationHelper, flags) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), cache, translations.NullTranslationHelper, flags) request := createMCPRequest(tc.requestArgs) result, err := handler(context.Background(), request) @@ -1693,7 +1716,7 @@ func Test_GetIssueComments(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) gqlClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -1799,7 +1822,7 @@ func Test_GetIssueComments(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1836,7 +1859,7 @@ func Test_GetIssueLabels(t *testing.T) { // Verify tool definition mockGQClient := githubv4.NewClient(nil) mockClient := github.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQClient), stubRepoAccessCache(mockGQClient, 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQClient), stubRepoAccessCache(mockGQClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -1911,7 +1934,7 @@ func Test_GetIssueLabels(t *testing.T) { t.Run(tc.name, func(t *testing.T) { gqlClient := githubv4.NewClient(tc.mockedClient) client := github.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) request := createMCPRequest(tc.requestArgs) result, err := handler(context.Background(), request) @@ -2602,7 +2625,7 @@ func Test_GetSubIssues(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) gqlClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -2799,7 +2822,7 @@ func Test_GetSubIssues(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 89a06aaec..185e4ec67 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -21,7 +21,7 @@ import ( func Test_GetPullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -102,7 +102,7 @@ func Test_GetPullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1133,7 +1133,7 @@ func Test_SearchPullRequests(t *testing.T) { func Test_GetPullRequestFiles(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1236,7 +1236,7 @@ func Test_GetPullRequestFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1277,7 +1277,7 @@ func Test_GetPullRequestFiles(t *testing.T) { func Test_GetPullRequestStatus(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 2a4ae48f4..7d5b727b1 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -41,7 +41,8 @@ func stubGetGQLClientFn(client *githubv4.Client) GetGQLClientFn { } func stubRepoAccessCache(client *githubv4.Client, ttl time.Duration) *lockdown.RepoAccessCache { - return lockdown.GetInstance(client, lockdown.WithTTL(ttl)) + cacheName := fmt.Sprintf("repo-access-cache-test-%d", time.Now().UnixNano()) + return lockdown.NewRepoAccessCache(client, lockdown.WithTTL(ttl), lockdown.WithCacheName(cacheName)) } func stubFeatureFlags(enabledFlags map[string]bool) FeatureFlags { diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index 7644bb7a1..b7bca7878 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -27,11 +27,14 @@ type repoAccessCacheEntry struct { knownUsers map[string]bool // normalized login -> has push access } -const defaultRepoAccessTTL = 5 * time.Minute +const ( + defaultRepoAccessTTL = 5 * time.Minute + defaultRepoAccessCacheKey = "repo-access-cache" +) var ( - instance *RepoAccessCache - instanceOnce sync.Once + instance *RepoAccessCache + instanceMu sync.Mutex ) // RepoAccessOption configures RepoAccessCache at construction time. @@ -52,24 +55,41 @@ func WithLogger(logger *slog.Logger) RepoAccessOption { } } +// WithCacheName overrides the cache table name used for storing entries. This option is intended for tests +// that need isolated cache instances. +func WithCacheName(name string) RepoAccessOption { + return func(c *RepoAccessCache) { + if name != "" { + c.cache = cache2go.Cache(name) + } + } +} + // GetInstance returns the singleton instance of RepoAccessCache. // It initializes the instance on first call with the provided client and options. // Subsequent calls ignore the client and options parameters and return the existing instance. // This is the preferred way to access the cache in production code. func GetInstance(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessCache { - instanceOnce.Do(func() { + instanceMu.Lock() + defer instanceMu.Unlock() + if instance == nil { instance = newRepoAccessCache(client, opts...) - }) + } return instance } +// NewRepoAccessCache constructs a repo access cache without mutating the global singleton. +// This helper is useful for tests that need isolated cache instances. +func NewRepoAccessCache(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessCache { + return newRepoAccessCache(client, opts...) +} + // newRepoAccessCache creates a new cache instance. This is a private helper function // used by GetInstance. func newRepoAccessCache(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessCache { - cacheName := "repo-access-cache" c := &RepoAccessCache{ client: client, - cache: cache2go.Cache(cacheName), + cache: cache2go.Cache(defaultRepoAccessCacheKey), ttl: defaultRepoAccessTTL, } for _, opt := range opts { @@ -77,6 +97,9 @@ func newRepoAccessCache(client *githubv4.Client, opts ...RepoAccessOption) *Repo opt(c) } } + if c.cache == nil { + c.cache = cache2go.Cache(defaultRepoAccessCacheKey) + } c.logInfo("repo access cache initialized", "ttl", c.ttl) return c } From c0edac0c0f356fa4d0dc5a43f0a85a0442da38c2 Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Wed, 19 Nov 2025 12:43:16 +0100 Subject: [PATCH 12/19] . --- pkg/lockdown/lockdown.go | 43 ++++++++--------------------------- pkg/lockdown/lockdown_test.go | 1 - 2 files changed, 10 insertions(+), 34 deletions(-) diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index b7bca7878..1507260e6 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -104,29 +104,6 @@ func newRepoAccessCache(client *githubv4.Client, opts ...RepoAccessOption) *Repo return c } -// SetTTL overrides the default time-to-live used for cache entries. A non-positive -// duration disables expiration. -func (c *RepoAccessCache) SetTTL(ttl time.Duration) { - c.mu.Lock() - defer c.mu.Unlock() - c.ttl = ttl - c.logInfo("repo access cache TTL updated", "ttl", ttl) - - // Collect all current entries - entries := make(map[interface{}]*repoAccessCacheEntry) - c.cache.Foreach(func(key interface{}, item *cache2go.CacheItem) { - entries[key] = item.Data().(*repoAccessCacheEntry) - }) - - // Flush the cache - c.cache.Flush() - - // Re-add all entries with the new TTL - for key, entry := range entries { - c.cache.Add(key, ttl, entry) - } -} - // SetLogger updates the logger used for cache diagnostics. func (c *RepoAccessCache) SetLogger(logger *slog.Logger) { c.mu.Lock() @@ -162,8 +139,16 @@ func (c *RepoAccessCache) GetRepoAccessInfo(ctx context.Context, username, owner c.logDebug("repo access cache hit", "owner", owner, "repo", repo, "user", username) return entry.isPrivate, cachedHasPush, nil } - // Entry exists but user not in knownUsers, need to query + c.logDebug("known users cache miss", "owner", owner, "repo", repo, "user", username) + _, hasPush, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo) + if queryErr != nil { + return false, false, queryErr + } + entry.knownUsers[userKey] = hasPush + c.cache.Add(key, c.ttl, entry) + return entry.isPrivate, entry.knownUsers[userKey], nil } + c.logDebug("repo access cache miss", "owner", owner, "repo", repo, "user", username) isPrivate, hasPush, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo) @@ -171,16 +156,8 @@ func (c *RepoAccessCache) GetRepoAccessInfo(ctx context.Context, username, owner return false, false, queryErr } - // Repo access info retrieved, update or create cache entry - var entry *repoAccessCacheEntry - if err == nil && cacheItem != nil { - entry = cacheItem.Data().(*repoAccessCacheEntry) - entry.knownUsers[userKey] = hasPush - return entry.isPrivate, entry.knownUsers[userKey], nil - } - // Create new entry - entry = &repoAccessCacheEntry{ + entry := &repoAccessCacheEntry{ knownUsers: map[string]bool{userKey: hasPush}, isPrivate: isPrivate, } diff --git a/pkg/lockdown/lockdown_test.go b/pkg/lockdown/lockdown_test.go index f53c66afa..3312e1fdf 100644 --- a/pkg/lockdown/lockdown_test.go +++ b/pkg/lockdown/lockdown_test.go @@ -87,7 +87,6 @@ func newMockRepoAccessCache(t *testing.T, ttl time.Duration) (*RepoAccessCache, } func TestRepoAccessCacheEvictsAfterTTL(t *testing.T) { - t.Parallel() ctx := t.Context() cache, transport := newMockRepoAccessCache(t, 5*time.Millisecond) From eda6b289b2fbbc931530332f3e7efc929c30f3d9 Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Wed, 19 Nov 2025 13:44:35 +0100 Subject: [PATCH 13/19] Fix logic after vibe coding --- pkg/github/issues.go | 2 ++ pkg/github/pullrequests.go | 42 ++++++++++++++++++++++++-------------- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 0af68e712..76ea47db1 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -400,6 +400,7 @@ func GetIssueComments(ctx context.Context, client *github.Client, cache *lockdow } // Do not filter content for private repositories if isPrivate { + filteredComments = comments break } if hasPushAccess { @@ -464,6 +465,7 @@ func GetSubIssues(ctx context.Context, client *github.Client, cache *lockdown.Re } // Repo is private, do not filter content if isPrivate { + filteredSubIssues = subIssues break } if hasPushAccess { diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 840ee1668..32f15b754 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -297,18 +297,25 @@ func GetPullRequestReviewComments(ctx context.Context, client *github.Client, ca if cache == nil { return nil, fmt.Errorf("lockdown cache is not configured") } - if len(comments) > 0 { - login := comments[0].GetUser().GetLogin() - if login != "" { - isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) - if err != nil { - return nil, fmt.Errorf("failed to check content removal: %w", err) - } - if !isPrivate && !hasPushAccess { - return mcp.NewToolResultError("access to pull request review comments is restricted by lockdown mode"), nil - } + filteredComments := make([]*github.PullRequestComment, 0, len(comments)) + for _, comment := range comments { + user := comment.GetUser() + if user == nil { + continue + } + isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, user.GetLogin(), owner, repo) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil + } + if isPrivate { + filteredComments = comments + break + } + if hasPushAccess { + filteredComments = append(filteredComments, comment) } } + comments = filteredComments } r, err := json.Marshal(comments) @@ -342,16 +349,21 @@ func GetPullRequestReviews(ctx context.Context, client *github.Client, cache *lo if cache == nil { return nil, fmt.Errorf("lockdown cache is not configured") } - if len(reviews) > 0 { - login := reviews[0].GetUser().GetLogin() + filteredReviews := make([]*github.PullRequestReview, 0, len(reviews)) + for _, review := range reviews { + login := review.GetUser().GetLogin() if login != "" { isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) if err != nil { - return nil, fmt.Errorf("failed to check content removal: %w", err) + return nil, fmt.Errorf("failed to check lockdown mode: %w", err) + } + if isPrivate { + filteredReviews = reviews } - if !isPrivate && !hasPushAccess { - return mcp.NewToolResultError("access to pull request reviews is restricted by lockdown mode"), nil + if hasPushAccess { + filteredReviews = append(filteredReviews, review) } + reviews = filteredReviews } } } From 53c3a25c351365565c320b089ae5a6b61ec99448 Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Wed, 19 Nov 2025 13:59:51 +0100 Subject: [PATCH 14/19] Update docs --- README.md | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7c4884074..a565f1e37 100644 --- a/README.md +++ b/README.md @@ -1264,7 +1264,7 @@ docker run -i --rm \ ## Lockdown Mode -Lockdown mode limits the content that the server will surface from public repositories. When enabled, requests that fetch issue details will return an error if the issue was created by someone who does not have push access to the repository. Private repositories are unaffected, and collaborators can still access their own issues. +Lockdown mode limits the content that the server will surface from public repositories. When enabled, the server checks whether the author of each item has push access to the repository. Private repositories are unaffected, and collaborators keep full access to their own content. ```bash ./github-mcp-server --lockdown-mode @@ -1279,7 +1279,20 @@ docker run -i --rm \ ghcr.io/github/github-mcp-server ``` -At the moment lockdown mode applies to the issue read toolset, but it is designed to extend to additional data surfaces over time. +The behavior of lockdown mode depends on the tool invoked. + +Following tools will return an error when the author lacks the push access: + +- `issue_read:get` +- `pull_request_read:get` + +Following tools will filter out content from users lacking the push access: + +- `issue_read:get_comments` +- `issue_read:get_sub_issues` +- `pull_request_read:get_comments` +- `pull_request_read:get_review_comments` +- `pull_request_read:get_reviews` ## i18n / Overriding Descriptions From 60ce4615b8aa3d8f08ec310fffa9ace281ffbf7d Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Wed, 19 Nov 2025 15:02:36 +0100 Subject: [PATCH 15/19] . --- pkg/github/issues_test.go | 134 ++++++++++++++++++++++++++++++++++---- pkg/github/server_test.go | 2 +- pkg/lockdown/lockdown.go | 43 +++--------- 3 files changed, 134 insertions(+), 45 deletions(-) diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index 2f4b584a5..c73f2bc80 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -1,9 +1,11 @@ package github import ( + "bytes" "context" "encoding/json" "fmt" + "io" "net/http" "strings" "testing" @@ -20,9 +22,103 @@ import ( "github.com/stretchr/testify/require" ) -var defaultGQLClient *githubv4.Client = githubv4.NewClient(githubv4mock.NewMockedHTTPClient()) +var defaultGQLClient *githubv4.Client = githubv4.NewClient(newRepoAccessHTTPClient()) var repoAccessCache *lockdown.RepoAccessCache = stubRepoAccessCache(defaultGQLClient, 15*time.Minute) +type repoAccessKey struct { + owner string + repo string + username string +} + +type repoAccessValue struct { + isPrivate bool + permission string +} + +type repoAccessMockTransport struct { + responses map[repoAccessKey]repoAccessValue +} + +func newRepoAccessHTTPClient() *http.Client { + responses := map[repoAccessKey]repoAccessValue{ + {owner: "owner2", repo: "repo2", username: "testuser2"}: {isPrivate: true}, + {owner: "owner", repo: "repo", username: "testuser"}: {isPrivate: false, permission: "READ"}, + } + + return &http.Client{Transport: &repoAccessMockTransport{responses: responses}} +} + +func (rt *repoAccessMockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if req.Body == nil { + return nil, fmt.Errorf("missing request body") + } + + var payload struct { + Query string `json:"query"` + Variables map[string]any `json:"variables"` + } + + if err := json.NewDecoder(req.Body).Decode(&payload); err != nil { + return nil, err + } + _ = req.Body.Close() + + owner := toString(payload.Variables["owner"]) + repo := toString(payload.Variables["name"]) + username := toString(payload.Variables["username"]) + + value, ok := rt.responses[repoAccessKey{owner: owner, repo: repo, username: username}] + if !ok { + value = repoAccessValue{isPrivate: false, permission: "WRITE"} + } + + edges := []any{} + if value.permission != "" { + edges = append(edges, map[string]any{ + "permission": value.permission, + "node": map[string]any{ + "login": username, + }, + }) + } + + responseBody, err := json.Marshal(map[string]any{ + "data": map[string]any{ + "repository": map[string]any{ + "isPrivate": value.isPrivate, + "collaborators": map[string]any{ + "edges": edges, + }, + }, + }, + }) + if err != nil { + return nil, err + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(responseBody)), + } + resp.Header.Set("Content-Type", "application/json") + return resp, nil +} + +func toString(v any) string { + switch value := v.(type) { + case string: + return value + case fmt.Stringer: + return value.String() + case nil: + return "" + default: + return fmt.Sprintf("%v", value) + } +} + func Test_GetIssue(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) @@ -55,6 +151,22 @@ func Test_GetIssue(t *testing.T) { }, }, } + mockIssue2 := &github.Issue{ + Number: github.Ptr(422), + Title: github.Ptr("Test Issue 2"), + Body: github.Ptr("This is a test issue 2"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/issues/42"), + User: &github.User{ + Login: github.Ptr("testuser2"), + }, + Repository: &github.Repository{ + Name: github.Ptr("repo2"), + Owner: &github.User{ + Login: github.Ptr("owner2"), + }, + }, + } tests := []struct { name string @@ -77,8 +189,8 @@ func Test_GetIssue(t *testing.T) { ), requestArgs: map[string]interface{}{ "method": "get", - "owner": "owner", - "repo": "repo", + "owner": "owner2", + "repo": "repo2", "issue_number": float64(42), }, expectedIssue: mockIssue, @@ -105,7 +217,7 @@ func Test_GetIssue(t *testing.T) { mockedClient: mock.NewMockedHTTPClient( mock.WithRequestMatch( mock.GetReposIssuesByOwnerByRepoByIssueNumber, - mockIssue, + mockIssue2, ), ), gqlHTTPClient: githubv4mock.NewMockedHTTPClient( @@ -124,9 +236,9 @@ func Test_GetIssue(t *testing.T) { } `graphql:"repository(owner: $owner, name: $name)"` }{}, map[string]any{ - "owner": githubv4.String("owner"), - "name": githubv4.String("repo"), - "username": githubv4.String("testuser"), + "owner": githubv4.String("owner2"), + "name": githubv4.String("repo2"), + "username": githubv4.String("testuser2"), }, githubv4mock.DataResponse(map[string]any{ "repository": map[string]any{ @@ -140,11 +252,11 @@ func Test_GetIssue(t *testing.T) { ), requestArgs: map[string]interface{}{ "method": "get", - "owner": "owner", - "repo": "repo", - "issue_number": float64(42), + "owner": "owner2", + "repo": "repo2", + "issue_number": float64(422), }, - expectedIssue: mockIssue, + expectedIssue: mockIssue2, lockdownEnabled: true, }, { diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 7d5b727b1..2e1c42580 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -42,7 +42,7 @@ func stubGetGQLClientFn(client *githubv4.Client) GetGQLClientFn { func stubRepoAccessCache(client *githubv4.Client, ttl time.Duration) *lockdown.RepoAccessCache { cacheName := fmt.Sprintf("repo-access-cache-test-%d", time.Now().UnixNano()) - return lockdown.NewRepoAccessCache(client, lockdown.WithTTL(ttl), lockdown.WithCacheName(cacheName)) + return lockdown.GetInstance(client, lockdown.WithTTL(ttl), lockdown.WithCacheName(cacheName)) } func stubFeatureFlags(enabledFlags map[string]bool) FeatureFlags { diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index 1507260e6..8f57225ea 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -73,35 +73,18 @@ func GetInstance(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessC instanceMu.Lock() defer instanceMu.Unlock() if instance == nil { - instance = newRepoAccessCache(client, opts...) - } - return instance -} - -// NewRepoAccessCache constructs a repo access cache without mutating the global singleton. -// This helper is useful for tests that need isolated cache instances. -func NewRepoAccessCache(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessCache { - return newRepoAccessCache(client, opts...) -} - -// newRepoAccessCache creates a new cache instance. This is a private helper function -// used by GetInstance. -func newRepoAccessCache(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessCache { - c := &RepoAccessCache{ - client: client, - cache: cache2go.Cache(defaultRepoAccessCacheKey), - ttl: defaultRepoAccessTTL, - } - for _, opt := range opts { - if opt != nil { - opt(c) + instance = &RepoAccessCache{ + client: client, + cache: cache2go.Cache(defaultRepoAccessCacheKey), + ttl: defaultRepoAccessTTL, + } + for _, opt := range opts { + if opt != nil { + opt(instance) + } } } - if c.cache == nil { - c.cache = cache2go.Cache(defaultRepoAccessCacheKey) - } - c.logInfo("repo access cache initialized", "ttl", c.ttl) - return c + return instance } // SetLogger updates the logger used for cache diagnostics. @@ -217,9 +200,3 @@ func (c *RepoAccessCache) logDebug(msg string, args ...any) { c.logger.Debug(msg, args...) } } - -func (c *RepoAccessCache) logInfo(msg string, args ...any) { - if c != nil && c.logger != nil { - c.logger.Info(msg, args...) - } -} From 2de28f79a8d961888720ad2aa1fb72b879b76d57 Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Wed, 19 Nov 2025 16:26:10 +0100 Subject: [PATCH 16/19] Refactoring to make the code pretty --- pkg/github/issues.go | 20 +++++----- pkg/github/pullrequests.go | 17 +++++---- pkg/lockdown/lockdown.go | 71 ++++++++++++++++++++++++----------- pkg/lockdown/lockdown_test.go | 14 ++++++- 4 files changed, 81 insertions(+), 41 deletions(-) diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 76ea47db1..f3c902016 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -331,11 +331,11 @@ func GetIssue(ctx context.Context, client *github.Client, cache *lockdown.RepoAc } login := issue.GetUser().GetLogin() if login != "" { - isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) + info, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil } - if !isPrivate && !hasPushAccess { + if info.ViewerLogin != login && !info.IsPrivate && !info.HasPushAccess { return mcp.NewToolResultError("access to issue details is restricted by lockdown mode"), nil } } @@ -394,16 +394,16 @@ func GetIssueComments(ctx context.Context, client *github.Client, cache *lockdow if login == "" { continue } - isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) + info, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil } - // Do not filter content for private repositories - if isPrivate { + // Do not filter content for private repositories or if the comment author is the viewer + if info.IsPrivate || info.ViewerLogin == login { filteredComments = comments break } - if hasPushAccess { + if info.HasPushAccess { filteredComments = append(filteredComments, comment) } } @@ -459,16 +459,16 @@ func GetSubIssues(ctx context.Context, client *github.Client, cache *lockdown.Re if login == "" { continue } - isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) + info, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil } - // Repo is private, do not filter content - if isPrivate { + // Repo is private or the comment author is the viewer, do not filter content + if info.IsPrivate || info.ViewerLogin == login { filteredSubIssues = subIssues break } - if hasPushAccess { + if info.HasPushAccess { filteredSubIssues = append(filteredSubIssues, subIssue) } } diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 32f15b754..c96f4b50a 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -141,12 +141,12 @@ func GetPullRequest(ctx context.Context, client *github.Client, cache *lockdown. } login := pr.GetUser().GetLogin() if login != "" { - isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) + info, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) if err != nil { return nil, fmt.Errorf("failed to check content removal: %w", err) } - if !isPrivate && !hasPushAccess { + if info.ViewerLogin != login && !info.IsPrivate && !info.HasPushAccess { return mcp.NewToolResultError("access to pull request is restricted by lockdown mode"), nil } } @@ -303,15 +303,16 @@ func GetPullRequestReviewComments(ctx context.Context, client *github.Client, ca if user == nil { continue } - isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, user.GetLogin(), owner, repo) + info, err := cache.GetRepoAccessInfo(ctx, user.GetLogin(), owner, repo) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil } - if isPrivate { + // Do not filter content for private repositories or if the comment author is the viewer + if info.IsPrivate || info.ViewerLogin == user.GetLogin() { filteredComments = comments break } - if hasPushAccess { + if info.HasPushAccess { filteredComments = append(filteredComments, comment) } } @@ -353,14 +354,14 @@ func GetPullRequestReviews(ctx context.Context, client *github.Client, cache *lo for _, review := range reviews { login := review.GetUser().GetLogin() if login != "" { - isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) + info, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) if err != nil { return nil, fmt.Errorf("failed to check lockdown mode: %w", err) } - if isPrivate { + if info.IsPrivate || info.ViewerLogin == login { filteredReviews = reviews } - if hasPushAccess { + if info.HasPushAccess { filteredReviews = append(filteredReviews, review) } reviews = filteredReviews diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index 8f57225ea..d7e2bed2b 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -23,8 +23,16 @@ type RepoAccessCache struct { } type repoAccessCacheEntry struct { - isPrivate bool - knownUsers map[string]bool // normalized login -> has push access + isPrivate bool + knownUsers map[string]bool // normalized login -> has push access + viewerLogin string +} + +// RepoAccessInfo captures repository metadata needed for lockdown decisions. +type RepoAccessInfo struct { + IsPrivate bool + HasPushAccess bool + ViewerLogin string } const ( @@ -101,12 +109,11 @@ type CacheStats struct { Evictions int64 } -// GetRepoAccessInfo returns the repository's privacy status and whether the -// specified user has push permissions. Results are cached per repository to -// avoid repeated GraphQL round-trips. -func (c *RepoAccessCache) GetRepoAccessInfo(ctx context.Context, username, owner, repo string) (bool, bool, error) { +// GetRepoAccessInfo returns repository access metadata for the provided user. +// Results are cached per repository to avoid repeated GraphQL round-trips. +func (c *RepoAccessCache) GetRepoAccessInfo(ctx context.Context, username, owner, repo string) (RepoAccessInfo, error) { if c == nil { - return false, false, fmt.Errorf("nil repo access cache") + return RepoAccessInfo{}, fmt.Errorf("nil repo access cache") } key := cacheKey(owner, repo) @@ -120,41 +127,59 @@ func (c *RepoAccessCache) GetRepoAccessInfo(ctx context.Context, username, owner entry := cacheItem.Data().(*repoAccessCacheEntry) if cachedHasPush, known := entry.knownUsers[userKey]; known { c.logDebug("repo access cache hit", "owner", owner, "repo", repo, "user", username) - return entry.isPrivate, cachedHasPush, nil + return RepoAccessInfo{ + IsPrivate: entry.isPrivate, + HasPushAccess: cachedHasPush, + ViewerLogin: entry.viewerLogin, + }, nil } c.logDebug("known users cache miss", "owner", owner, "repo", repo, "user", username) - _, hasPush, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo) + info, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo) if queryErr != nil { - return false, false, queryErr + return RepoAccessInfo{}, queryErr } - entry.knownUsers[userKey] = hasPush + entry.knownUsers[userKey] = info.HasPushAccess + entry.viewerLogin = info.ViewerLogin + entry.isPrivate = info.IsPrivate c.cache.Add(key, c.ttl, entry) - return entry.isPrivate, entry.knownUsers[userKey], nil + return RepoAccessInfo{ + IsPrivate: entry.isPrivate, + HasPushAccess: entry.knownUsers[userKey], + ViewerLogin: entry.viewerLogin, + }, nil } c.logDebug("repo access cache miss", "owner", owner, "repo", repo, "user", username) - isPrivate, hasPush, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo) + info, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo) if queryErr != nil { - return false, false, queryErr + return RepoAccessInfo{}, queryErr } // Create new entry entry := &repoAccessCacheEntry{ - knownUsers: map[string]bool{userKey: hasPush}, - isPrivate: isPrivate, + knownUsers: map[string]bool{userKey: info.HasPushAccess}, + isPrivate: info.IsPrivate, + viewerLogin: info.ViewerLogin, } c.cache.Add(key, c.ttl, entry) - return entry.isPrivate, entry.knownUsers[userKey], nil + return RepoAccessInfo{ + IsPrivate: entry.isPrivate, + HasPushAccess: entry.knownUsers[userKey], + ViewerLogin: entry.viewerLogin, + }, nil } -func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, owner, repo string) (bool, bool, error) { +func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, owner, repo string) (RepoAccessInfo, error) { if c.client == nil { - return false, false, fmt.Errorf("nil GraphQL client") + return RepoAccessInfo{}, fmt.Errorf("nil GraphQL client") } var query struct { + Viewer struct { + Login githubv4.String + } Repository struct { IsPrivate githubv4.Boolean Collaborators struct { @@ -175,7 +200,7 @@ func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, own } if err := c.client.Query(ctx, &query, variables); err != nil { - return false, false, fmt.Errorf("failed to query repository access info: %w", err) + return RepoAccessInfo{}, fmt.Errorf("failed to query repository access info: %w", err) } hasPush := false @@ -188,7 +213,11 @@ func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, own } } - return bool(query.Repository.IsPrivate), hasPush, nil + return RepoAccessInfo{ + IsPrivate: bool(query.Repository.IsPrivate), + HasPushAccess: hasPush, + ViewerLogin: string(query.Viewer.Login), + }, nil } func cacheKey(owner, repo string) string { diff --git a/pkg/lockdown/lockdown_test.go b/pkg/lockdown/lockdown_test.go index 3312e1fdf..65906c0c8 100644 --- a/pkg/lockdown/lockdown_test.go +++ b/pkg/lockdown/lockdown_test.go @@ -18,6 +18,9 @@ const ( ) type repoAccessQuery struct { + Viewer struct { + Login githubv4.String + } Repository struct { IsPrivate githubv4.Boolean Collaborators struct { @@ -62,6 +65,9 @@ func newMockRepoAccessCache(t *testing.T, ttl time.Duration) (*RepoAccessCache, } response := githubv4mock.DataResponse(map[string]any{ + "viewer": map[string]any{ + "login": testUser, + }, "repository": map[string]any{ "isPrivate": false, "collaborators": map[string]any{ @@ -90,13 +96,17 @@ func TestRepoAccessCacheEvictsAfterTTL(t *testing.T) { ctx := t.Context() cache, transport := newMockRepoAccessCache(t, 5*time.Millisecond) - _, _, err := cache.GetRepoAccessInfo(ctx, testUser, testOwner, testRepo) + info, err := cache.GetRepoAccessInfo(ctx, testUser, testOwner, testRepo) require.NoError(t, err) + require.Equal(t, testUser, info.ViewerLogin) + require.True(t, info.HasPushAccess) require.EqualValues(t, 1, transport.CallCount()) time.Sleep(20 * time.Millisecond) - _, _, err = cache.GetRepoAccessInfo(ctx, testUser, testOwner, testRepo) + info, err = cache.GetRepoAccessInfo(ctx, testUser, testOwner, testRepo) require.NoError(t, err) + require.Equal(t, testUser, info.ViewerLogin) + require.True(t, info.HasPushAccess) require.EqualValues(t, 2, transport.CallCount()) } From c8d5b6cfe20c2ece598624903b52b9c6b9711393 Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Thu, 20 Nov 2025 15:05:59 +0100 Subject: [PATCH 17/19] Hide lockdown logic behind shouldFilter function --- pkg/github/issues.go | 22 ++++++---------------- pkg/github/pullrequests.go | 20 ++++++-------------- pkg/lockdown/lockdown.go | 16 +++++++++++++--- pkg/lockdown/lockdown_test.go | 4 ++-- 4 files changed, 27 insertions(+), 35 deletions(-) diff --git a/pkg/github/issues.go b/pkg/github/issues.go index f3c902016..2bdd2b5cb 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -331,11 +331,11 @@ func GetIssue(ctx context.Context, client *github.Client, cache *lockdown.RepoAc } login := issue.GetUser().GetLogin() if login != "" { - info, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) + isSafeContent, err := cache.IsSafeContent(ctx, login, owner, repo) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil } - if info.ViewerLogin != login && !info.IsPrivate && !info.HasPushAccess { + if !isSafeContent { return mcp.NewToolResultError("access to issue details is restricted by lockdown mode"), nil } } @@ -394,16 +394,11 @@ func GetIssueComments(ctx context.Context, client *github.Client, cache *lockdow if login == "" { continue } - info, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) + isSafeContent, err := cache.IsSafeContent(ctx, login, owner, repo) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil } - // Do not filter content for private repositories or if the comment author is the viewer - if info.IsPrivate || info.ViewerLogin == login { - filteredComments = comments - break - } - if info.HasPushAccess { + if !isSafeContent { filteredComments = append(filteredComments, comment) } } @@ -459,16 +454,11 @@ func GetSubIssues(ctx context.Context, client *github.Client, cache *lockdown.Re if login == "" { continue } - info, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) + isSafeContent, err := cache.IsSafeContent(ctx, login, owner, repo) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil } - // Repo is private or the comment author is the viewer, do not filter content - if info.IsPrivate || info.ViewerLogin == login { - filteredSubIssues = subIssues - break - } - if info.HasPushAccess { + if !isSafeContent { filteredSubIssues = append(filteredSubIssues, subIssue) } } diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index c96f4b50a..46d9d064e 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -141,12 +141,12 @@ func GetPullRequest(ctx context.Context, client *github.Client, cache *lockdown. } login := pr.GetUser().GetLogin() if login != "" { - info, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) + isSafeContent, err := cache.IsSafeContent(ctx, login, owner, repo) if err != nil { return nil, fmt.Errorf("failed to check content removal: %w", err) } - if info.ViewerLogin != login && !info.IsPrivate && !info.HasPushAccess { + if !isSafeContent { return mcp.NewToolResultError("access to pull request is restricted by lockdown mode"), nil } } @@ -303,16 +303,11 @@ func GetPullRequestReviewComments(ctx context.Context, client *github.Client, ca if user == nil { continue } - info, err := cache.GetRepoAccessInfo(ctx, user.GetLogin(), owner, repo) + isSafeContent, err := cache.IsSafeContent(ctx, user.GetLogin(), owner, repo) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil } - // Do not filter content for private repositories or if the comment author is the viewer - if info.IsPrivate || info.ViewerLogin == user.GetLogin() { - filteredComments = comments - break - } - if info.HasPushAccess { + if !isSafeContent { filteredComments = append(filteredComments, comment) } } @@ -354,14 +349,11 @@ func GetPullRequestReviews(ctx context.Context, client *github.Client, cache *lo for _, review := range reviews { login := review.GetUser().GetLogin() if login != "" { - info, err := cache.GetRepoAccessInfo(ctx, login, owner, repo) + isSafeContent, err := cache.IsSafeContent(ctx, login, owner, repo) if err != nil { return nil, fmt.Errorf("failed to check lockdown mode: %w", err) } - if info.IsPrivate || info.ViewerLogin == login { - filteredReviews = reviews - } - if info.HasPushAccess { + if !isSafeContent { filteredReviews = append(filteredReviews, review) } reviews = filteredReviews diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index d7e2bed2b..d7e444d77 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -109,9 +109,19 @@ type CacheStats struct { Evictions int64 } -// GetRepoAccessInfo returns repository access metadata for the provided user. -// Results are cached per repository to avoid repeated GraphQL round-trips. -func (c *RepoAccessCache) GetRepoAccessInfo(ctx context.Context, username, owner, repo string) (RepoAccessInfo, error) { +func (c *RepoAccessCache) IsSafeContent(ctx context.Context, username, owner, repo string) (bool, error) { + repoInfo, err := c.getRepoAccessInfo(ctx, username, owner, repo) + if err != nil { + c.logDebug("error checking repo access info for content filtering", "owner", owner, "repo", repo, "user", username, "error", err) + return false, err + } + if repoInfo.IsPrivate || repoInfo.ViewerLogin == username { + return true, nil + } + return repoInfo.HasPushAccess, nil +} + +func (c *RepoAccessCache) getRepoAccessInfo(ctx context.Context, username, owner, repo string) (RepoAccessInfo, error) { if c == nil { return RepoAccessInfo{}, fmt.Errorf("nil repo access cache") } diff --git a/pkg/lockdown/lockdown_test.go b/pkg/lockdown/lockdown_test.go index 65906c0c8..c1cf5e86b 100644 --- a/pkg/lockdown/lockdown_test.go +++ b/pkg/lockdown/lockdown_test.go @@ -96,7 +96,7 @@ func TestRepoAccessCacheEvictsAfterTTL(t *testing.T) { ctx := t.Context() cache, transport := newMockRepoAccessCache(t, 5*time.Millisecond) - info, err := cache.GetRepoAccessInfo(ctx, testUser, testOwner, testRepo) + info, err := cache.getRepoAccessInfo(ctx, testUser, testOwner, testRepo) require.NoError(t, err) require.Equal(t, testUser, info.ViewerLogin) require.True(t, info.HasPushAccess) @@ -104,7 +104,7 @@ func TestRepoAccessCacheEvictsAfterTTL(t *testing.T) { time.Sleep(20 * time.Millisecond) - info, err = cache.GetRepoAccessInfo(ctx, testUser, testOwner, testRepo) + info, err = cache.getRepoAccessInfo(ctx, testUser, testOwner, testRepo) require.NoError(t, err) require.Equal(t, testUser, info.ViewerLogin) require.True(t, info.HasPushAccess) From 447c9022aee7b9333609f5d1364f3a3c58793ebf Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Fri, 21 Nov 2025 10:17:49 +0100 Subject: [PATCH 18/19] . --- pkg/github/issues.go | 4 ++-- pkg/github/pullrequests.go | 4 ++-- pkg/lockdown/lockdown.go | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 2bdd2b5cb..f35168705 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -398,7 +398,7 @@ func GetIssueComments(ctx context.Context, client *github.Client, cache *lockdow if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil } - if !isSafeContent { + if isSafeContent { filteredComments = append(filteredComments, comment) } } @@ -458,7 +458,7 @@ func GetSubIssues(ctx context.Context, client *github.Client, cache *lockdown.Re if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil } - if !isSafeContent { + if isSafeContent { filteredSubIssues = append(filteredSubIssues, subIssue) } } diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 46d9d064e..6fb5ed30b 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -307,7 +307,7 @@ func GetPullRequestReviewComments(ctx context.Context, client *github.Client, ca if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil } - if !isSafeContent { + if isSafeContent { filteredComments = append(filteredComments, comment) } } @@ -353,7 +353,7 @@ func GetPullRequestReviews(ctx context.Context, client *github.Client, cache *lo if err != nil { return nil, fmt.Errorf("failed to check lockdown mode: %w", err) } - if !isSafeContent { + if isSafeContent { filteredReviews = append(filteredReviews, review) } reviews = filteredReviews diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index d7e444d77..4c3500440 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -36,7 +36,7 @@ type RepoAccessInfo struct { } const ( - defaultRepoAccessTTL = 5 * time.Minute + defaultRepoAccessTTL = 20 * time.Minute defaultRepoAccessCacheKey = "repo-access-cache" ) From f40df276aa5e5cff380f802bf07741e27beaace1 Mon Sep 17 00:00:00 2001 From: JoannaaKL Date: Fri, 21 Nov 2025 10:27:11 +0100 Subject: [PATCH 19/19] Tests --- pkg/github/issues_test.go | 58 +++++++++++++-- pkg/github/pullrequests_test.go | 123 ++++++++++++++++++++++++++++---- 2 files changed, 164 insertions(+), 17 deletions(-) diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index c73f2bc80..a05312b91 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -1864,10 +1864,12 @@ func Test_GetIssueComments(t *testing.T) { tests := []struct { name string mockedClient *http.Client + gqlHTTPClient *http.Client requestArgs map[string]interface{} expectError bool expectedComments []*github.IssueComment expectedErrMsg string + lockdownEnabled bool }{ { name: "successful comments retrieval", @@ -1927,14 +1929,57 @@ func Test_GetIssueComments(t *testing.T) { expectError: true, expectedErrMsg: "failed to get issue comments", }, + { + name: "lockdown enabled filters comments without push access", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposIssuesCommentsByOwnerByRepoByIssueNumber, + []*github.IssueComment{ + { + ID: github.Ptr(int64(789)), + Body: github.Ptr("Maintainer comment"), + User: &github.User{Login: github.Ptr("maintainer")}, + }, + { + ID: github.Ptr(int64(790)), + Body: github.Ptr("External user comment"), + User: &github.User{Login: github.Ptr("testuser")}, + }, + }, + ), + ), + gqlHTTPClient: newRepoAccessHTTPClient(), + requestArgs: map[string]interface{}{ + "method": "get_comments", + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + }, + expectError: false, + expectedComments: []*github.IssueComment{ + { + ID: github.Ptr(int64(789)), + Body: github.Ptr("Maintainer comment"), + User: &github.User{Login: github.Ptr("maintainer")}, + }, + }, + lockdownEnabled: true, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - gqlClient := githubv4.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + var gqlClient *githubv4.Client + if tc.gqlHTTPClient != nil { + gqlClient = githubv4.NewClient(tc.gqlHTTPClient) + } else { + gqlClient = githubv4.NewClient(nil) + } + cache := stubRepoAccessCache(gqlClient, 15*time.Minute) + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), cache, translations.NullTranslationHelper, flags) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1957,9 +2002,12 @@ func Test_GetIssueComments(t *testing.T) { err = json.Unmarshal([]byte(textContent.Text), &returnedComments) require.NoError(t, err) assert.Equal(t, len(tc.expectedComments), len(returnedComments)) - if len(returnedComments) > 0 { - assert.Equal(t, *tc.expectedComments[0].Body, *returnedComments[0].Body) - assert.Equal(t, *tc.expectedComments[0].User.Login, *returnedComments[0].User.Login) + for i := range tc.expectedComments { + require.NotNil(t, tc.expectedComments[i].User) + require.NotNil(t, returnedComments[i].User) + assert.Equal(t, tc.expectedComments[i].GetID(), returnedComments[i].GetID()) + assert.Equal(t, tc.expectedComments[i].GetBody(), returnedComments[i].GetBody()) + assert.Equal(t, tc.expectedComments[i].GetUser().GetLogin(), returnedComments[i].GetUser().GetLogin()) } }) } diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index ac2d47ae0..6eac5ce83 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -1610,10 +1610,12 @@ func Test_GetPullRequestComments(t *testing.T) { tests := []struct { name string mockedClient *http.Client + gqlHTTPClient *http.Client requestArgs map[string]interface{} expectError bool expectedComments []*github.PullRequestComment expectedErrMsg string + lockdownEnabled bool }{ { name: "successful comments fetch", @@ -1652,13 +1654,57 @@ func Test_GetPullRequestComments(t *testing.T) { expectError: true, expectedErrMsg: "failed to get pull request review comments", }, + { + name: "lockdown enabled filters review comments without push access", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsCommentsByOwnerByRepoByPullNumber, + []*github.PullRequestComment{ + { + ID: github.Ptr(int64(2010)), + Body: github.Ptr("Maintainer review comment"), + User: &github.User{Login: github.Ptr("maintainer")}, + }, + { + ID: github.Ptr(int64(2011)), + Body: github.Ptr("External review comment"), + User: &github.User{Login: github.Ptr("testuser")}, + }, + }, + ), + ), + gqlHTTPClient: newRepoAccessHTTPClient(), + requestArgs: map[string]interface{}{ + "method": "get_review_comments", + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + }, + expectError: false, + expectedComments: []*github.PullRequestComment{ + { + ID: github.Ptr(int64(2010)), + Body: github.Ptr("Maintainer review comment"), + User: &github.User{Login: github.Ptr("maintainer")}, + }, + }, + lockdownEnabled: true, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + var gqlClient *githubv4.Client + if tc.gqlHTTPClient != nil { + gqlClient = githubv4.NewClient(tc.gqlHTTPClient) + } else { + gqlClient = githubv4.NewClient(nil) + } + cache := stubRepoAccessCache(gqlClient, 5*time.Minute) + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) + _, handler := PullRequestRead(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1687,11 +1733,13 @@ func Test_GetPullRequestComments(t *testing.T) { require.NoError(t, err) assert.Len(t, returnedComments, len(tc.expectedComments)) for i, comment := range returnedComments { - assert.Equal(t, *tc.expectedComments[i].ID, *comment.ID) - assert.Equal(t, *tc.expectedComments[i].Body, *comment.Body) - assert.Equal(t, *tc.expectedComments[i].User.Login, *comment.User.Login) - assert.Equal(t, *tc.expectedComments[i].Path, *comment.Path) - assert.Equal(t, *tc.expectedComments[i].HTMLURL, *comment.HTMLURL) + require.NotNil(t, tc.expectedComments[i].User) + require.NotNil(t, comment.User) + assert.Equal(t, tc.expectedComments[i].GetID(), comment.GetID()) + assert.Equal(t, tc.expectedComments[i].GetBody(), comment.GetBody()) + assert.Equal(t, tc.expectedComments[i].GetUser().GetLogin(), comment.GetUser().GetLogin()) + assert.Equal(t, tc.expectedComments[i].GetPath(), comment.GetPath()) + assert.Equal(t, tc.expectedComments[i].GetHTMLURL(), comment.GetHTMLURL()) } }) } @@ -1740,10 +1788,12 @@ func Test_GetPullRequestReviews(t *testing.T) { tests := []struct { name string mockedClient *http.Client + gqlHTTPClient *http.Client requestArgs map[string]interface{} expectError bool expectedReviews []*github.PullRequestReview expectedErrMsg string + lockdownEnabled bool }{ { name: "successful reviews fetch", @@ -1782,13 +1832,60 @@ func Test_GetPullRequestReviews(t *testing.T) { expectError: true, expectedErrMsg: "failed to get pull request reviews", }, + { + name: "lockdown enabled filters reviews without push access", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsReviewsByOwnerByRepoByPullNumber, + []*github.PullRequestReview{ + { + ID: github.Ptr(int64(2030)), + State: github.Ptr("APPROVED"), + Body: github.Ptr("Maintainer review"), + User: &github.User{Login: github.Ptr("maintainer")}, + }, + { + ID: github.Ptr(int64(2031)), + State: github.Ptr("COMMENTED"), + Body: github.Ptr("External reviewer"), + User: &github.User{Login: github.Ptr("testuser")}, + }, + }, + ), + ), + gqlHTTPClient: newRepoAccessHTTPClient(), + requestArgs: map[string]interface{}{ + "method": "get_reviews", + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + }, + expectError: false, + expectedReviews: []*github.PullRequestReview{ + { + ID: github.Ptr(int64(2030)), + State: github.Ptr("APPROVED"), + Body: github.Ptr("Maintainer review"), + User: &github.User{Login: github.Ptr("maintainer")}, + }, + }, + lockdownEnabled: true, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + var gqlClient *githubv4.Client + if tc.gqlHTTPClient != nil { + gqlClient = githubv4.NewClient(tc.gqlHTTPClient) + } else { + gqlClient = githubv4.NewClient(nil) + } + cache := stubRepoAccessCache(gqlClient, 5*time.Minute) + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) + _, handler := PullRequestRead(stubGetClientFn(client), cache, translations.NullTranslationHelper, flags) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1817,11 +1914,13 @@ func Test_GetPullRequestReviews(t *testing.T) { require.NoError(t, err) assert.Len(t, returnedReviews, len(tc.expectedReviews)) for i, review := range returnedReviews { - assert.Equal(t, *tc.expectedReviews[i].ID, *review.ID) - assert.Equal(t, *tc.expectedReviews[i].State, *review.State) - assert.Equal(t, *tc.expectedReviews[i].Body, *review.Body) - assert.Equal(t, *tc.expectedReviews[i].User.Login, *review.User.Login) - assert.Equal(t, *tc.expectedReviews[i].HTMLURL, *review.HTMLURL) + require.NotNil(t, tc.expectedReviews[i].User) + require.NotNil(t, review.User) + assert.Equal(t, tc.expectedReviews[i].GetID(), review.GetID()) + assert.Equal(t, tc.expectedReviews[i].GetState(), review.GetState()) + assert.Equal(t, tc.expectedReviews[i].GetBody(), review.GetBody()) + assert.Equal(t, tc.expectedReviews[i].GetUser().GetLogin(), review.GetUser().GetLogin()) + assert.Equal(t, tc.expectedReviews[i].GetHTMLURL(), review.GetHTMLURL()) } }) }