Skip to content

Commit 6dfdaf5

Browse files
committed
Use lockdown in issues and pullrequests tools
1 parent a63d5d1 commit 6dfdaf5

File tree

6 files changed

+197
-29
lines changed

6 files changed

+197
-29
lines changed

pkg/github/helper_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
package github
22

33
import (
4+
"bytes"
45
"encoding/json"
6+
"fmt"
7+
"io"
58
"net/http"
9+
"strings"
610
"testing"
711

812
"github.com/mark3labs/mcp-go/mcp"
@@ -108,6 +112,74 @@ func mockResponse(t *testing.T, code int, body interface{}) http.HandlerFunc {
108112
}
109113
}
110114

115+
type roundTripperFunc func(*http.Request) (*http.Response, error)
116+
117+
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
118+
return f(req)
119+
}
120+
121+
func newLockdownMockGQLHTTPClient(t *testing.T, isPrivate bool, permissions map[string]string) *http.Client {
122+
t.Helper()
123+
124+
lowerPermissions := make(map[string]string, len(permissions))
125+
for user, perm := range permissions {
126+
lowerPermissions[strings.ToLower(user)] = perm
127+
}
128+
129+
return &http.Client{
130+
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
131+
require.Equal(t, "/graphql", req.URL.Path)
132+
133+
bodyBytes, err := io.ReadAll(req.Body)
134+
require.NoError(t, err)
135+
_ = req.Body.Close()
136+
137+
var gqlRequest struct {
138+
Variables map[string]any `json:"variables"`
139+
}
140+
require.NoError(t, json.Unmarshal(bodyBytes, &gqlRequest))
141+
142+
rawUsername, ok := gqlRequest.Variables["username"]
143+
require.True(t, ok, "expected username variable in GraphQL request")
144+
145+
username := fmt.Sprint(rawUsername)
146+
permission := lowerPermissions[strings.ToLower(username)]
147+
148+
edges := []any{}
149+
if permission != "" {
150+
edges = append(edges, map[string]any{
151+
"permission": permission,
152+
"node": map[string]any{
153+
"login": username,
154+
},
155+
})
156+
}
157+
158+
response := map[string]any{
159+
"data": map[string]any{
160+
"repository": map[string]any{
161+
"isPrivate": isPrivate,
162+
"collaborators": map[string]any{
163+
"edges": edges,
164+
},
165+
},
166+
},
167+
}
168+
169+
respBytes, err := json.Marshal(response)
170+
require.NoError(t, err)
171+
172+
res := &http.Response{
173+
StatusCode: http.StatusOK,
174+
Header: make(http.Header),
175+
Body: io.NopCloser(bytes.NewReader(respBytes)),
176+
}
177+
res.Header.Set("Content-Type", "application/json")
178+
return res, nil
179+
}),
180+
}
181+
}
182+
111183
// createMCPRequest is a helper function to create a MCP request with the given arguments.
112184
func createMCPRequest(args any) mcp.CallToolRequest {
113185
return mcp.CallToolRequest{

pkg/github/issues.go

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,9 @@ Options are:
299299
case "get":
300300
return GetIssue(ctx, client, gqlClient, owner, repo, issueNumber, flags)
301301
case "get_comments":
302-
return GetIssueComments(ctx, client, owner, repo, issueNumber, pagination, flags)
302+
return GetIssueComments(ctx, client, gqlClient, owner, repo, issueNumber, pagination, flags)
303303
case "get_sub_issues":
304-
return GetSubIssues(ctx, client, owner, repo, issueNumber, pagination, flags)
304+
return GetSubIssues(ctx, client, gqlClient, owner, repo, issueNumber, pagination, flags)
305305
case "get_labels":
306306
return GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber, flags)
307307
default:
@@ -355,7 +355,7 @@ func GetIssue(ctx context.Context, client *github.Client, gqlClient *githubv4.Cl
355355
return mcp.NewToolResultText(string(r)), nil
356356
}
357357

358-
func GetIssueComments(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int, pagination PaginationParams, _ FeatureFlags) (*mcp.CallToolResult, error) {
358+
func GetIssueComments(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner string, repo string, issueNumber int, pagination PaginationParams, flags FeatureFlags) (*mcp.CallToolResult, error) {
359359
opts := &github.IssueListCommentsOptions{
360360
ListOptions: github.ListOptions{
361361
Page: pagination.Page,
@@ -377,6 +377,24 @@ func GetIssueComments(ctx context.Context, client *github.Client, owner string,
377377
return mcp.NewToolResultError(fmt.Sprintf("failed to get issue comments: %s", string(body))), nil
378378
}
379379

380+
if flags.LockdownMode {
381+
filtered := make([]*github.IssueComment, 0, len(comments))
382+
for _, comment := range comments {
383+
if comment == nil || comment.User == nil || comment.User.Login == nil {
384+
continue
385+
}
386+
shouldRemove, err := lockdown.ShouldRemoveContent(ctx, gqlClient, comment.User.GetLogin(), owner, repo)
387+
if err != nil {
388+
return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil
389+
}
390+
if shouldRemove {
391+
continue
392+
}
393+
filtered = append(filtered, comment)
394+
}
395+
comments = filtered
396+
}
397+
380398
r, err := json.Marshal(comments)
381399
if err != nil {
382400
return nil, fmt.Errorf("failed to marshal response: %w", err)
@@ -385,7 +403,7 @@ func GetIssueComments(ctx context.Context, client *github.Client, owner string,
385403
return mcp.NewToolResultText(string(r)), nil
386404
}
387405

388-
func GetSubIssues(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int, pagination PaginationParams, _ FeatureFlags) (*mcp.CallToolResult, error) {
406+
func GetSubIssues(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner string, repo string, issueNumber int, pagination PaginationParams, flags FeatureFlags) (*mcp.CallToolResult, error) {
389407
opts := &github.IssueListOptions{
390408
ListOptions: github.ListOptions{
391409
Page: pagination.Page,
@@ -412,6 +430,24 @@ func GetSubIssues(ctx context.Context, client *github.Client, owner string, repo
412430
return mcp.NewToolResultError(fmt.Sprintf("failed to list sub-issues: %s", string(body))), nil
413431
}
414432

433+
if flags.LockdownMode {
434+
filtered := make([]*github.SubIssue, 0, len(subIssues))
435+
for _, subIssue := range subIssues {
436+
if subIssue == nil || subIssue.User == nil || subIssue.User.Login == nil {
437+
continue
438+
}
439+
shouldRemove, err := lockdown.ShouldRemoveContent(ctx, gqlClient, subIssue.User.GetLogin(), owner, repo)
440+
if err != nil {
441+
return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil
442+
}
443+
if shouldRemove {
444+
continue
445+
}
446+
filtered = append(filtered, subIssue)
447+
}
448+
subIssues = filtered
449+
}
450+
415451
r, err := json.Marshal(subIssues)
416452
if err != nil {
417453
return nil, fmt.Errorf("failed to marshal response: %w", err)

pkg/github/issues_test.go

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1763,10 +1763,12 @@ func Test_GetIssueComments(t *testing.T) {
17631763
tests := []struct {
17641764
name string
17651765
mockedClient *http.Client
1766+
gqlHTTPClient *http.Client
17661767
requestArgs map[string]interface{}
17671768
expectError bool
17681769
expectedComments []*github.IssueComment
17691770
expectedErrMsg string
1771+
lockdownEnabled bool
17701772
}{
17711773
{
17721774
name: "successful comments retrieval",
@@ -1782,7 +1784,6 @@ func Test_GetIssueComments(t *testing.T) {
17821784
"repo": "repo",
17831785
"issue_number": float64(42),
17841786
},
1785-
expectError: false,
17861787
expectedComments: mockComments,
17871788
},
17881789
{
@@ -1809,6 +1810,27 @@ func Test_GetIssueComments(t *testing.T) {
18091810
expectError: false,
18101811
expectedComments: mockComments,
18111812
},
1813+
{
1814+
name: "lockdown enabled removes comments without push access",
1815+
mockedClient: mock.NewMockedHTTPClient(
1816+
mock.WithRequestMatch(
1817+
mock.GetReposIssuesCommentsByOwnerByRepoByIssueNumber,
1818+
mockComments,
1819+
),
1820+
),
1821+
gqlHTTPClient: newLockdownMockGQLHTTPClient(t, false, map[string]string{
1822+
"user1": "WRITE",
1823+
"user2": "READ",
1824+
}),
1825+
requestArgs: map[string]interface{}{
1826+
"method": "get_comments",
1827+
"owner": "owner",
1828+
"repo": "repo",
1829+
"issue_number": float64(42),
1830+
},
1831+
expectedComments: []*github.IssueComment{mockComments[0]},
1832+
lockdownEnabled: true,
1833+
},
18121834
{
18131835
name: "issue not found",
18141836
mockedClient: mock.NewMockedHTTPClient(
@@ -1832,8 +1854,14 @@ func Test_GetIssueComments(t *testing.T) {
18321854
t.Run(tc.name, func(t *testing.T) {
18331855
// Setup client with mock
18341856
client := github.NewClient(tc.mockedClient)
1835-
gqlClient := githubv4.NewClient(nil)
1836-
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))
1857+
var gqlClient *githubv4.Client
1858+
if tc.gqlHTTPClient != nil {
1859+
gqlClient = githubv4.NewClient(tc.gqlHTTPClient)
1860+
} else {
1861+
gqlClient = githubv4.NewClient(nil)
1862+
}
1863+
flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled})
1864+
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, flags)
18371865

18381866
// Create call request
18391867
request := createMCPRequest(tc.requestArgs)
@@ -1856,9 +1884,9 @@ func Test_GetIssueComments(t *testing.T) {
18561884
err = json.Unmarshal([]byte(textContent.Text), &returnedComments)
18571885
require.NoError(t, err)
18581886
assert.Equal(t, len(tc.expectedComments), len(returnedComments))
1859-
if len(returnedComments) > 0 {
1860-
assert.Equal(t, *tc.expectedComments[0].Body, *returnedComments[0].Body)
1861-
assert.Equal(t, *tc.expectedComments[0].User.Login, *returnedComments[0].User.Login)
1887+
for i := range returnedComments {
1888+
assert.Equal(t, *tc.expectedComments[i].Body, *returnedComments[i].Body)
1889+
assert.Equal(t, *tc.expectedComments[i].User.Login, *returnedComments[i].User.Login)
18621890
}
18631891
})
18641892
}
@@ -2686,10 +2714,12 @@ func Test_GetSubIssues(t *testing.T) {
26862714
tests := []struct {
26872715
name string
26882716
mockedClient *http.Client
2717+
gqlHTTPClient *http.Client
26892718
requestArgs map[string]interface{}
26902719
expectError bool
26912720
expectedSubIssues []*github.Issue
26922721
expectedErrMsg string
2722+
lockdownEnabled bool
26932723
}{
26942724
{
26952725
name: "successful sub-issues listing with minimal parameters",
@@ -2729,7 +2759,6 @@ func Test_GetSubIssues(t *testing.T) {
27292759
"page": float64(2),
27302760
"perPage": float64(10),
27312761
},
2732-
expectError: false,
27332762
expectedSubIssues: mockSubIssues,
27342763
},
27352764
{
@@ -2746,9 +2775,29 @@ func Test_GetSubIssues(t *testing.T) {
27462775
"repo": "repo",
27472776
"issue_number": float64(42),
27482777
},
2749-
expectError: false,
27502778
expectedSubIssues: []*github.Issue{},
27512779
},
2780+
{
2781+
name: "lockdown enabled filters sub-issues without push access",
2782+
mockedClient: mock.NewMockedHTTPClient(
2783+
mock.WithRequestMatch(
2784+
mock.GetReposIssuesSubIssuesByOwnerByRepoByIssueNumber,
2785+
mockSubIssues,
2786+
),
2787+
),
2788+
gqlHTTPClient: newLockdownMockGQLHTTPClient(t, false, map[string]string{
2789+
"user1": "WRITE",
2790+
"user2": "READ",
2791+
}),
2792+
requestArgs: map[string]interface{}{
2793+
"method": "get_sub_issues",
2794+
"owner": "owner",
2795+
"repo": "repo",
2796+
"issue_number": float64(42),
2797+
},
2798+
expectedSubIssues: []*github.Issue{mockSubIssues[0]},
2799+
lockdownEnabled: true,
2800+
},
27522801
{
27532802
name: "parent issue not found",
27542803
mockedClient: mock.NewMockedHTTPClient(
@@ -2832,8 +2881,14 @@ func Test_GetSubIssues(t *testing.T) {
28322881
t.Run(tc.name, func(t *testing.T) {
28332882
// Setup client with mock
28342883
client := github.NewClient(tc.mockedClient)
2835-
gqlClient := githubv4.NewClient(nil)
2836-
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))
2884+
var gqlClient *githubv4.Client
2885+
if tc.gqlHTTPClient != nil {
2886+
gqlClient = githubv4.NewClient(tc.gqlHTTPClient)
2887+
} else {
2888+
gqlClient = githubv4.NewClient(nil)
2889+
}
2890+
flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled})
2891+
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, flags)
28372892

28382893
// Create call request
28392894
request := createMCPRequest(tc.requestArgs)

pkg/github/pullrequests.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919
)
2020

2121
// GetPullRequest creates a tool to get details of a specific pull request.
22-
func PullRequestRead(getClient GetClientFn, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, server.ToolHandlerFunc) {
22+
func PullRequestRead(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, server.ToolHandlerFunc) {
2323
return mcp.NewTool("pull_request_read",
2424
mcp.WithDescription(t("TOOL_PULL_REQUEST_READ_DESCRIPTION", "Get information on a specific pull request in GitHub repository.")),
2525
mcp.WithToolAnnotation(mcp.ToolAnnotation{
@@ -98,7 +98,11 @@ Possible options:
9898
case "get_reviews":
9999
return GetPullRequestReviews(ctx, client, owner, repo, pullNumber)
100100
case "get_comments":
101-
return GetIssueComments(ctx, client, owner, repo, pullNumber, pagination, flags)
101+
gqlClient, err := getGQLClient(ctx)
102+
if err != nil {
103+
return nil, fmt.Errorf("failed to get GitHub graphql client: %w", err)
104+
}
105+
return GetIssueComments(ctx, client, gqlClient, owner, repo, pullNumber, pagination, flags)
102106
default:
103107
return nil, fmt.Errorf("unknown method: %s", method)
104108
}

0 commit comments

Comments
 (0)