Skip to content

Commit e0cde71

Browse files
committed
WIP: List notifications tool
1 parent b72a591 commit e0cde71

File tree

4 files changed

+239
-0
lines changed

4 files changed

+239
-0
lines changed

pkg/github/notifications.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package github
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
10+
"github.com/github/github-mcp-server/pkg/translations"
11+
"github.com/google/go-github/v69/github"
12+
"github.com/mark3labs/mcp-go/mcp"
13+
"github.com/mark3labs/mcp-go/server"
14+
)
15+
16+
// ListNotifications creates a tool to list notifications for a GitHub user.
17+
func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
18+
return mcp.NewTool("list_notifications",
19+
mcp.WithDescription(t("TOOL_LIST_NOTIFICATIONS_DESCRIPTION", "List notifications for a GitHub user")),
20+
mcp.WithNumber("page",
21+
mcp.Description("Page number"),
22+
),
23+
mcp.WithNumber("per_page",
24+
mcp.Description("Number of records per page"),
25+
),
26+
mcp.WithBoolean("all",
27+
mcp.Description("Whether to fetch all notifications, including read ones"),
28+
),
29+
),
30+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
31+
page, err := OptionalIntParamWithDefault(request, "page", 1)
32+
if err != nil {
33+
return mcp.NewToolResultError(err.Error()), nil
34+
}
35+
perPage, err := OptionalIntParamWithDefault(request, "per_page", 30)
36+
if err != nil {
37+
return mcp.NewToolResultError(err.Error()), nil
38+
}
39+
all, err := OptionalBoolParamWithDefault(request, "all", false) // Default to false unless specified
40+
if err != nil {
41+
return mcp.NewToolResultError(err.Error()), nil
42+
}
43+
44+
if request.Params.Arguments["all"] == true {
45+
all = true // Set to true if user explicitly asks for all notifications
46+
}
47+
48+
opts := &github.NotificationListOptions{
49+
ListOptions: github.ListOptions{
50+
Page: page,
51+
PerPage: perPage,
52+
},
53+
All: all, // Include all notifications, even those already read.
54+
}
55+
56+
client, err := getClient(ctx)
57+
if err != nil {
58+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
59+
}
60+
notifications, resp, err := client.Activity.ListNotifications(ctx, opts)
61+
if err != nil {
62+
return nil, fmt.Errorf("failed to list notifications: %w", err)
63+
}
64+
defer func() { _ = resp.Body.Close() }()
65+
66+
if resp.StatusCode != http.StatusOK {
67+
body, err := io.ReadAll(resp.Body)
68+
if err != nil {
69+
return nil, fmt.Errorf("failed to read response body: %w", err)
70+
}
71+
return mcp.NewToolResultError(fmt.Sprintf("failed to list notifications: %s", string(body))), nil
72+
}
73+
74+
// Extract the notification title in addition to reason, url, and timestamp.
75+
var extractedNotifications []map[string]interface{}
76+
for _, notification := range notifications {
77+
extractedNotifications = append(extractedNotifications, map[string]interface{}{
78+
"title": notification.GetSubject().GetTitle(),
79+
"reason": notification.GetReason(),
80+
"url": notification.GetURL(),
81+
"timestamp": notification.GetUpdatedAt(),
82+
})
83+
}
84+
85+
r, err := json.Marshal(extractedNotifications)
86+
if err != nil {
87+
return nil, fmt.Errorf("failed to marshal notifications: %w", err)
88+
}
89+
90+
return mcp.NewToolResultText(string(r)), nil
91+
}
92+
}

pkg/github/notifications_test.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
package github
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"net/http"
7+
"testing"
8+
"time"
9+
10+
"github.com/github/github-mcp-server/pkg/translations"
11+
"github.com/google/go-github/v69/github"
12+
"github.com/migueleliasweb/go-github-mock/src/mock"
13+
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
func Test_ListNotifications(t *testing.T) {
18+
// Verify tool definition
19+
mockClient := github.NewClient(nil)
20+
tool, _ := ListNotifications(stubGetClientFn(mockClient), translations.NullTranslationHelper)
21+
22+
assert.Equal(t, "list_notifications", tool.Name)
23+
assert.NotEmpty(t, tool.Description)
24+
assert.Contains(t, tool.InputSchema.Properties, "page")
25+
assert.Contains(t, tool.InputSchema.Properties, "per_page")
26+
assert.Contains(t, tool.InputSchema.Properties, "all")
27+
28+
// Setup mock notifications
29+
mockNotifications := []*github.Notification{
30+
{
31+
ID: github.String("1"),
32+
Reason: github.String("mention"),
33+
Subject: &github.NotificationSubject{
34+
Title: github.String("Test Notification 1"),
35+
},
36+
UpdatedAt: &github.Timestamp{Time: time.Now()},
37+
URL: github.String("https://example.com/notifications/threads/1"),
38+
},
39+
{
40+
ID: github.String("2"),
41+
Reason: github.String("team_mention"),
42+
Subject: &github.NotificationSubject{
43+
Title: github.String("Test Notification 2"),
44+
},
45+
UpdatedAt: &github.Timestamp{Time: time.Now()},
46+
URL: github.String("https://example.com/notifications/threads/1"),
47+
},
48+
}
49+
50+
tests := []struct {
51+
name string
52+
mockedClient *http.Client
53+
requestArgs map[string]interface{}
54+
expectError bool
55+
expectedResponse []*github.Notification
56+
expectedErrMsg string
57+
}{
58+
{
59+
name: "list all notifications",
60+
mockedClient: mock.NewMockedHTTPClient(
61+
mock.WithRequestMatch(
62+
mock.GetNotifications,
63+
mockNotifications,
64+
),
65+
),
66+
requestArgs: map[string]interface{}{
67+
"all": true,
68+
},
69+
expectError: false,
70+
expectedResponse: mockNotifications,
71+
},
72+
{
73+
name: "list unread notifications",
74+
mockedClient: mock.NewMockedHTTPClient(
75+
mock.WithRequestMatch(
76+
mock.GetNotifications,
77+
mockNotifications[:1], // Only the first notification
78+
),
79+
),
80+
requestArgs: map[string]interface{}{
81+
"all": false,
82+
},
83+
expectError: false,
84+
expectedResponse: mockNotifications[:1],
85+
},
86+
}
87+
88+
for _, tc := range tests {
89+
t.Run(tc.name, func(t *testing.T) {
90+
// Setup client with mock
91+
client := github.NewClient(tc.mockedClient)
92+
_, handler := ListNotifications(stubGetClientFn(client), translations.NullTranslationHelper)
93+
94+
// Create call request
95+
request := createMCPRequest(tc.requestArgs)
96+
// Call handler
97+
result, err := handler(context.Background(), request)
98+
99+
// Verify results
100+
if tc.expectError {
101+
require.Error(t, err)
102+
assert.Contains(t, err.Error(), tc.expectedErrMsg)
103+
return
104+
}
105+
106+
require.NoError(t, err)
107+
textContent := getTextResult(t, result)
108+
109+
// Unmarshal and verify the result
110+
var returnedNotifications []*github.Notification
111+
err = json.Unmarshal([]byte(textContent.Text), &returnedNotifications)
112+
require.NoError(t, err)
113+
assert.Equal(t, len(tc.expectedResponse), len(returnedNotifications))
114+
for i, notification := range returnedNotifications {
115+
// Ensure all required fields are mocked
116+
assert.NotNil(t, notification.Subject, "Subject should not be nil")
117+
assert.NotNil(t, notification.Subject.Title, "Title should not be nil")
118+
assert.NotNil(t, notification.Reason, "Reason should not be nil")
119+
assert.NotNil(t, notification.URL, "URL should not be nil")
120+
assert.NotNil(t, notification.UpdatedAt, "UpdatedAt should not be nil")
121+
// assert.Equal(t, *tc.expectedResponse[i].ID, *notification.ID)
122+
assert.Equal(t, *tc.expectedResponse[i].Reason, *notification.Reason)
123+
// assert.Equal(t, *tc.expectedResponse[i].Subject.Title, *notification.Subject.Title)
124+
}
125+
})
126+
}
127+
}

pkg/github/server.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,20 @@ func OptionalIntParam(r mcp.CallToolRequest, p string) (int, error) {
130130
return int(v), nil
131131
}
132132

133+
// OptionalBoolParamWithDefault is a helper function that retrieves a boolean parameter from the request.
134+
// If the parameter is not present, it returns the provided default value. If the parameter is present,
135+
// it validates its type and returns the value.
136+
func OptionalBoolParamWithDefault(request mcp.CallToolRequest, s string, b bool) (bool, error) {
137+
v, err := OptionalParam[bool](request, s)
138+
if err != nil {
139+
return false, err
140+
}
141+
if b == false {
142+
return b, nil
143+
}
144+
return v, nil
145+
}
146+
133147
// OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
134148
// similar to optionalIntParam, but it also takes a default value.
135149
func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) {

pkg/github/tools.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,19 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
7676
// Keep experiments alive so the system doesn't error out when it's always enabled
7777
experiments := toolsets.NewToolset("experiments", "Experimental features that are not considered stable yet")
7878

79+
notifications := toolsets.NewToolset("notifications", "GitHub Notifications related tools").
80+
AddReadTools(
81+
toolsets.NewServerTool(ListNotifications(getClient, t)),
82+
)
83+
7984
// Add toolsets to the group
8085
tsg.AddToolset(repos)
8186
tsg.AddToolset(issues)
8287
tsg.AddToolset(users)
8388
tsg.AddToolset(pullRequests)
8489
tsg.AddToolset(codeSecurity)
8590
tsg.AddToolset(experiments)
91+
tsg.AddToolset(notifications)
8692
// Enable the requested features
8793

8894
if err := tsg.EnableToolsets(passedToolsets); err != nil {

0 commit comments

Comments
 (0)