Skip to content

Commit 1126d3c

Browse files
sridharavinashSamMorrowDrums
authored andcommitted
feat: add GitHub notifications tools for managing user notifications
1 parent 2f8c287 commit 1126d3c

File tree

2 files changed

+256
-0
lines changed

2 files changed

+256
-0
lines changed

pkg/github/notifications.go

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
package github
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"time"
10+
11+
"github.com/github/github-mcp-server/pkg/translations"
12+
"github.com/google/go-github/v69/github"
13+
"github.com/mark3labs/mcp-go/mcp"
14+
"github.com/mark3labs/mcp-go/server"
15+
)
16+
17+
// getNotifications creates a tool to list notifications for the current user.
18+
func getNotifications(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
19+
return mcp.NewTool("get_notifications",
20+
mcp.WithDescription(t("TOOL_GET_NOTIFICATIONS_DESCRIPTION", "Get notifications for the authenticated GitHub user")),
21+
mcp.WithBoolean("all",
22+
mcp.Description("If true, show notifications marked as read. Default: false"),
23+
),
24+
mcp.WithBoolean("participating",
25+
mcp.Description("If true, only shows notifications in which the user is directly participating or mentioned. Default: false"),
26+
),
27+
mcp.WithString("since",
28+
mcp.Description("Only show notifications updated after the given time (ISO 8601 format)"),
29+
),
30+
mcp.WithString("before",
31+
mcp.Description("Only show notifications updated before the given time (ISO 8601 format)"),
32+
),
33+
mcp.WithNumber("per_page",
34+
mcp.Description("Results per page (max 100). Default: 30"),
35+
),
36+
mcp.WithNumber("page",
37+
mcp.Description("Page number of the results to fetch. Default: 1"),
38+
),
39+
),
40+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
41+
// Extract optional parameters with defaults
42+
all, err := OptionalParamWithDefault[bool](request, "all", false)
43+
if err != nil {
44+
return mcp.NewToolResultError(err.Error()), nil
45+
}
46+
47+
participating, err := OptionalParamWithDefault[bool](request, "participating", false)
48+
if err != nil {
49+
return mcp.NewToolResultError(err.Error()), nil
50+
}
51+
52+
since, err := OptionalParam[string](request, "since")
53+
if err != nil {
54+
return mcp.NewToolResultError(err.Error()), nil
55+
}
56+
57+
before, err := OptionalParam[string](request, "before")
58+
if err != nil {
59+
return mcp.NewToolResultError(err.Error()), nil
60+
}
61+
62+
perPage, err := OptionalIntParamWithDefault(request, "per_page", 30)
63+
if err != nil {
64+
return mcp.NewToolResultError(err.Error()), nil
65+
}
66+
67+
page, err := OptionalIntParamWithDefault(request, "page", 1)
68+
if err != nil {
69+
return mcp.NewToolResultError(err.Error()), nil
70+
}
71+
72+
// Build options
73+
opts := &github.NotificationListOptions{
74+
All: all,
75+
Participating: participating,
76+
ListOptions: github.ListOptions{
77+
Page: page,
78+
PerPage: perPage,
79+
},
80+
}
81+
82+
// Parse time parameters if provided
83+
if since != "" {
84+
sinceTime, err := time.Parse(time.RFC3339, since)
85+
if err != nil {
86+
return mcp.NewToolResultError(fmt.Sprintf("invalid since time format, should be RFC3339/ISO8601: %v", err)), nil
87+
}
88+
opts.Since = sinceTime
89+
}
90+
91+
if before != "" {
92+
beforeTime, err := time.Parse(time.RFC3339, before)
93+
if err != nil {
94+
return mcp.NewToolResultError(fmt.Sprintf("invalid before time format, should be RFC3339/ISO8601: %v", err)), nil
95+
}
96+
opts.Before = beforeTime
97+
}
98+
99+
// Call GitHub API
100+
notifications, resp, err := client.Activity.ListNotifications(ctx, opts)
101+
if err != nil {
102+
return nil, fmt.Errorf("failed to get notifications: %w", err)
103+
}
104+
defer func() { _ = resp.Body.Close() }()
105+
106+
if resp.StatusCode != http.StatusOK {
107+
body, err := io.ReadAll(resp.Body)
108+
if err != nil {
109+
return nil, fmt.Errorf("failed to read response body: %w", err)
110+
}
111+
return mcp.NewToolResultError(fmt.Sprintf("failed to get notifications: %s", string(body))), nil
112+
}
113+
114+
// Marshal response to JSON
115+
r, err := json.Marshal(notifications)
116+
if err != nil {
117+
return nil, fmt.Errorf("failed to marshal response: %w", err)
118+
}
119+
120+
return mcp.NewToolResultText(string(r)), nil
121+
}
122+
}
123+
124+
// markNotificationRead creates a tool to mark a notification as read.
125+
func markNotificationRead(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
126+
return mcp.NewTool("mark_notification_read",
127+
mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_READ_DESCRIPTION", "Mark a notification as read")),
128+
mcp.WithString("threadID",
129+
mcp.Required(),
130+
mcp.Description("The ID of the notification thread"),
131+
),
132+
),
133+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
134+
threadID, err := requiredParam[string](request, "threadID")
135+
if err != nil {
136+
return mcp.NewToolResultError(err.Error()), nil
137+
}
138+
139+
resp, err := client.Activity.MarkThreadRead(ctx, threadID)
140+
if err != nil {
141+
return nil, fmt.Errorf("failed to mark notification as read: %w", err)
142+
}
143+
defer func() { _ = resp.Body.Close() }()
144+
145+
if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
146+
body, err := io.ReadAll(resp.Body)
147+
if err != nil {
148+
return nil, fmt.Errorf("failed to read response body: %w", err)
149+
}
150+
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil
151+
}
152+
153+
return mcp.NewToolResultText("Notification marked as read"), nil
154+
}
155+
}
156+
157+
// markAllNotificationsRead creates a tool to mark all notifications as read.
158+
func markAllNotificationsRead(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
159+
return mcp.NewTool("mark_all_notifications_read",
160+
mcp.WithDescription(t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read")),
161+
mcp.WithString("lastReadAt",
162+
mcp.Description("Describes the last point that notifications were checked (optional). Default: Now"),
163+
),
164+
),
165+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
166+
lastReadAt, err := OptionalParam[string](request, "lastReadAt")
167+
if err != nil {
168+
return mcp.NewToolResultError(err.Error()), nil
169+
}
170+
171+
var markReadOptions github.Timestamp
172+
if lastReadAt != "" {
173+
lastReadTime, err := time.Parse(time.RFC3339, lastReadAt)
174+
if err != nil {
175+
return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil
176+
}
177+
markReadOptions = github.Timestamp{
178+
Time: lastReadTime,
179+
}
180+
}
181+
182+
resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions)
183+
if err != nil {
184+
return nil, fmt.Errorf("failed to mark all notifications as read: %w", err)
185+
}
186+
defer func() { _ = resp.Body.Close() }()
187+
188+
if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
189+
body, err := io.ReadAll(resp.Body)
190+
if err != nil {
191+
return nil, fmt.Errorf("failed to read response body: %w", err)
192+
}
193+
return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil
194+
}
195+
196+
return mcp.NewToolResultText("All notifications marked as read"), nil
197+
}
198+
}
199+
200+
// getNotificationThread creates a tool to get a specific notification thread.
201+
func getNotificationThread(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
202+
return mcp.NewTool("get_notification_thread",
203+
mcp.WithDescription(t("TOOL_GET_NOTIFICATION_THREAD_DESCRIPTION", "Get a specific notification thread")),
204+
mcp.WithString("threadID",
205+
mcp.Required(),
206+
mcp.Description("The ID of the notification thread"),
207+
),
208+
),
209+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
210+
threadID, err := requiredParam[string](request, "threadID")
211+
if err != nil {
212+
return mcp.NewToolResultError(err.Error()), nil
213+
}
214+
215+
thread, resp, err := client.Activity.GetThread(ctx, threadID)
216+
if err != nil {
217+
return nil, fmt.Errorf("failed to get notification thread: %w", err)
218+
}
219+
defer func() { _ = resp.Body.Close() }()
220+
221+
if resp.StatusCode != http.StatusOK {
222+
body, err := io.ReadAll(resp.Body)
223+
if err != nil {
224+
return nil, fmt.Errorf("failed to read response body: %w", err)
225+
}
226+
return mcp.NewToolResultError(fmt.Sprintf("failed to get notification thread: %s", string(body))), nil
227+
}
228+
229+
r, err := json.Marshal(thread)
230+
if err != nil {
231+
return nil, fmt.Errorf("failed to marshal response: %w", err)
232+
}
233+
234+
return mcp.NewToolResultText(string(r)), nil
235+
}
236+
}

pkg/github/server.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,26 @@ func OptionalParam[T any](r mcp.CallToolRequest, p string) (T, error) {
118118
return r.Params.Arguments[p].(T), nil
119119
}
120120

121+
// OptionalParam is a helper function that can be used to fetch a requested parameter from the request.
122+
// It does the following checks:
123+
// 1. Checks if the parameter is present in the request, if not, it returns its zero-value
124+
// 2. If it is present, it checks if the parameter is of the expected type and returns it
125+
func OptionalParam[T any](r mcp.CallToolRequest, p string) (T, error) {
126+
var zero T
127+
128+
// Check if the parameter is present in the request
129+
if _, ok := r.Params.Arguments[p]; !ok {
130+
return zero, nil
131+
}
132+
133+
// Check if the parameter is of the expected type
134+
if _, ok := r.Params.Arguments[p].(T); !ok {
135+
return zero, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, r.Params.Arguments[p])
136+
}
137+
138+
return r.Params.Arguments[p].(T), nil
139+
}
140+
121141
// OptionalIntParam is a helper function that can be used to fetch a requested parameter from the request.
122142
// It does the following checks:
123143
// 1. Checks if the parameter is present in the request, if not, it returns its zero-value

0 commit comments

Comments
 (0)