Skip to content
243 changes: 243 additions & 0 deletions pkg/github/notifications.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
package github

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"time"

"github.com/github/github-mcp-server/pkg/translations"
"github.com/google/go-github/v69/github"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)

// getNotifications creates a tool to list notifications for the current user.
func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("get_notifications",
mcp.WithDescription(t("TOOL_GET_NOTIFICATIONS_DESCRIPTION", "Get notifications for the authenticated GitHub user")),
mcp.WithBoolean("all",
mcp.Description("If true, show notifications marked as read. Default: false"),
),
mcp.WithBoolean("participating",
mcp.Description("If true, only shows notifications in which the user is directly participating or mentioned. Default: false"),
),
mcp.WithString("since",
mcp.Description("Only show notifications updated after the given time (ISO 8601 format)"),
),
mcp.WithString("before",
mcp.Description("Only show notifications updated before the given time (ISO 8601 format)"),
),
mcp.WithNumber("per_page",
mcp.Description("Results per page (max 100). Default: 30"),
),
mcp.WithNumber("page",
mcp.Description("Page number of the results to fetch. Default: 1"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}

// Extract optional parameters with defaults
all, err := OptionalBoolParamWithDefault(request, "all", false)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

participating, err := OptionalBoolParamWithDefault(request, "participating", false)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

since, err := OptionalStringParamWithDefault(request, "since", "")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

before, err := OptionalStringParam(request, "before")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

perPage, err := OptionalIntParamWithDefault(request, "per_page", 30)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

page, err := OptionalIntParamWithDefault(request, "page", 1)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

// Build options
opts := &github.NotificationListOptions{
All: all,
Participating: participating,
ListOptions: github.ListOptions{
Page: page,
PerPage: perPage,
},
}

// Parse time parameters if provided
if since != "" {
sinceTime, err := time.Parse(time.RFC3339, since)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("invalid since time format, should be RFC3339/ISO8601: %v", err)), nil
}
opts.Since = sinceTime
}

if before != "" {
beforeTime, err := time.Parse(time.RFC3339, before)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("invalid before time format, should be RFC3339/ISO8601: %v", err)), nil
}
opts.Before = beforeTime
}

// Call GitHub API
notifications, resp, err := client.Activity.ListNotifications(ctx, opts)
if err != nil {
return nil, fmt.Errorf("failed to get notifications: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to get notifications: %s", string(body))), nil
}

// Marshal response to JSON
r, err := json.Marshal(notifications)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
}
}

// ManageNotifications creates a tool to manage notifications (mark as read, mark all as read, or mark as done).
func ManageNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("manage_notifications",
mcp.WithDescription(t("TOOL_MANAGE_NOTIFICATIONS_DESCRIPTION", "Manage notifications (mark as read, mark all as read, or mark as done)")),
mcp.WithString("action",
mcp.Required(),
mcp.Description("The action to perform: 'mark_read', 'mark_all_read', or 'mark_done'"),
),
mcp.WithString("threadID",
mcp.Description("The ID of the notification thread (required for 'mark_read' and 'mark_done')"),
),
mcp.WithString("lastReadAt",
mcp.Description("Describes the last point that notifications were checked (optional, for 'mark_all_read'). Default: Now"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}

action, err := requiredParam[string](request, "action")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

switch action {
case "mark_read":
Copy link
Preview

Copilot AI Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 'mark_read' action does not call any API to mark a notification as read; it directly returns a success message. Please add the appropriate API call to actually mark the notification as read.

Copilot uses AI. Check for mistakes.

threadID, err := requiredParam[string](request, "threadID")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

resp, err := client.Activity.MarkThreadRead(ctx, threadID)
if err != nil {
return nil, fmt.Errorf("failed to mark notification as read: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil
}

return mcp.NewToolResultText("Notification marked as read"), nil

case "mark_done":
threadIDStr, err := requiredParam[string](request, "threadID")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

threadID, err := strconv.ParseInt(threadIDStr, 10, 64)
if err != nil {
return mcp.NewToolResultError("Invalid threadID: must be a numeric value"), nil
}

resp, err := client.Activity.MarkThreadDone(ctx, threadID)
if err != nil {
return nil, fmt.Errorf("failed to mark notification as done: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil
}

return mcp.NewToolResultText("Notification marked as done"), nil

case "mark_all_read":
lastReadAt, err := OptionalStringParam(request, "lastReadAt")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

var markReadOptions github.Timestamp
if lastReadAt != "" {
lastReadTime, err := time.Parse(time.RFC3339, lastReadAt)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil
}
markReadOptions = github.Timestamp{
Time: lastReadTime,
}
}

resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions)
if err != nil {
return nil, fmt.Errorf("failed to mark all notifications as read: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil
}

return mcp.NewToolResultText("All notifications marked as read"), nil

default:
return mcp.NewToolResultError("Invalid action: must be 'mark_read', 'mark_all_read', or 'mark_done'"), nil
}
}
}
41 changes: 41 additions & 0 deletions pkg/github/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,47 @@ func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, e
return v, nil
}

// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
// similar to optionalParam, but it also takes a default value.
func OptionalBoolParamWithDefault(r mcp.CallToolRequest, p string, d bool) (bool, error) {
v, err := OptionalParam[bool](r, p)
if err != nil {
return false, err
}
if !v {
return d, nil
}
return v, nil
}

// OptionalStringParam is a helper function that can be used to fetch a requested parameter from the request.
// It does the following checks:
// 1. Checks if the parameter is present in the request, if not, it returns its zero-value
// 2. If it is present, it checks if the parameter is of the expected type and returns it
func OptionalStringParam(r mcp.CallToolRequest, p string) (string, error) {
v, err := OptionalParam[string](r, p)
if err != nil {
return "", err
}
if v == "" {
return "", nil
}
return v, nil
}

// OptionalStringParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
// similar to optionalParam, but it also takes a default value.
func OptionalStringParamWithDefault(r mcp.CallToolRequest, p string, d string) (string, error) {
v, err := OptionalParam[string](r, p)
if err != nil {
return "", err
}
if v == "" {
return d, nil
}
return v, nil
}

// OptionalStringArrayParam is a helper function that can be used to fetch a requested parameter from the request.
// It does the following checks:
// 1. Checks if the parameter is present in the request, if not, it returns its zero-value
Expand Down
10 changes: 10 additions & 0 deletions pkg/github/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
toolsets.NewServerTool(GetSecretScanningAlert(getClient, t)),
toolsets.NewServerTool(ListSecretScanningAlerts(getClient, t)),
)

notifications := toolsets.NewToolset("notifications", "GitHub Notifications related tools").
AddReadTools(
toolsets.NewServerTool(GetNotifications(getClient, t)),
).
AddWriteTools(
toolsets.NewServerTool(ManageNotifications(getClient, t)),
)

// Keep experiments alive so the system doesn't error out when it's always enabled
experiments := toolsets.NewToolset("experiments", "Experimental features that are not considered stable yet")

Expand All @@ -88,6 +97,7 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
tsg.AddToolset(pullRequests)
tsg.AddToolset(codeSecurity)
tsg.AddToolset(secretProtection)
tsg.AddToolset(notifications)
tsg.AddToolset(experiments)
// Enable the requested features

Expand Down