Skip to content

Commit ee6618e

Browse files
authored
Combine tools
1 parent bc897a1 commit ee6618e

File tree

1 file changed

+79
-136
lines changed

1 file changed

+79
-136
lines changed

pkg/github/notifications.go

Lines changed: 79 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -127,50 +127,19 @@ func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFun
127127
}
128128
}
129129

130-
// markNotificationRead creates a tool to mark a notification as read.
131-
func MarkNotificationRead(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
132-
return mcp.NewTool("mark_notification_read",
133-
mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_READ_DESCRIPTION", "Mark a notification as read")),
134-
mcp.WithString("threadID",
130+
// ManageNotifications creates a tool to manage notifications (mark as read, mark all as read, or mark as done).
131+
func ManageNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
132+
return mcp.NewTool("manage_notifications",
133+
mcp.WithDescription(t("TOOL_MANAGE_NOTIFICATIONS_DESCRIPTION", "Manage notifications (mark as read, mark all as read, or mark as done)")),
134+
mcp.WithString("action",
135135
mcp.Required(),
136-
mcp.Description("The ID of the notification thread"),
136+
mcp.Description("The action to perform: 'mark_read', 'mark_all_read', or 'mark_done'"),
137+
),
138+
mcp.WithString("threadID",
139+
mcp.Description("The ID of the notification thread (required for 'mark_read' and 'mark_done')"),
137140
),
138-
),
139-
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
140-
client, err := getclient(ctx)
141-
if err != nil {
142-
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
143-
}
144-
145-
threadID, err := requiredParam[string](request, "threadID")
146-
if err != nil {
147-
return mcp.NewToolResultError(err.Error()), nil
148-
}
149-
150-
resp, err := client.Activity.MarkThreadRead(ctx, threadID)
151-
if err != nil {
152-
return nil, fmt.Errorf("failed to mark notification as read: %w", err)
153-
}
154-
defer func() { _ = resp.Body.Close() }()
155-
156-
if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
157-
body, err := io.ReadAll(resp.Body)
158-
if err != nil {
159-
return nil, fmt.Errorf("failed to read response body: %w", err)
160-
}
161-
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil
162-
}
163-
164-
return mcp.NewToolResultText("Notification marked as read"), nil
165-
}
166-
}
167-
168-
// MarkAllNotificationsRead creates a tool to mark all notifications as read.
169-
func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
170-
return mcp.NewTool("mark_all_notifications_read",
171-
mcp.WithDescription(t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read")),
172141
mcp.WithString("lastReadAt",
173-
mcp.Description("Describes the last point that notifications were checked (optional). Default: Now"),
142+
mcp.Description("Describes the last point that notifications were checked (optional, for 'mark_all_read'). Default: Now"),
174143
),
175144
),
176145
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
@@ -179,122 +148,96 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH
179148
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
180149
}
181150

182-
lastReadAt, err := OptionalStringParam(request, "lastReadAt")
151+
action, err := requiredParam[string](request, "action")
183152
if err != nil {
184153
return mcp.NewToolResultError(err.Error()), nil
185154
}
186155

187-
var markReadOptions github.Timestamp
188-
if lastReadAt != "" {
189-
lastReadTime, err := time.Parse(time.RFC3339, lastReadAt)
156+
switch action {
157+
case "mark_read":
158+
threadID, err := requiredParam[string](request, "threadID")
190159
if err != nil {
191-
return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil
192-
}
193-
markReadOptions = github.Timestamp{
194-
Time: lastReadTime,
160+
return mcp.NewToolResultError(err.Error()), nil
195161
}
196-
}
197-
198-
resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions)
199-
if err != nil {
200-
return nil, fmt.Errorf("failed to mark all notifications as read: %w", err)
201-
}
202-
defer func() { _ = resp.Body.Close() }()
203162

204-
if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
205-
body, err := io.ReadAll(resp.Body)
163+
resp, err := client.Activity.MarkThreadRead(ctx, threadID)
206164
if err != nil {
207-
return nil, fmt.Errorf("failed to read response body: %w", err)
165+
return nil, fmt.Errorf("failed to mark notification as read: %w", err)
166+
}
167+
defer func() { _ = resp.Body.Close() }()
168+
169+
if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
170+
body, err := io.ReadAll(resp.Body)
171+
if err != nil {
172+
return nil, fmt.Errorf("failed to read response body: %w", err)
173+
}
174+
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil
208175
}
209-
return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil
210-
}
211-
212-
return mcp.NewToolResultText("All notifications marked as read"), nil
213-
}
214-
}
215-
216-
// GetNotificationThread creates a tool to get a specific notification thread.
217-
func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
218-
return mcp.NewTool("get_notification_thread",
219-
mcp.WithDescription(t("TOOL_GET_NOTIFICATION_THREAD_DESCRIPTION", "Get a specific notification thread")),
220-
mcp.WithString("threadID",
221-
mcp.Required(),
222-
mcp.Description("The ID of the notification thread"),
223-
),
224-
),
225-
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
226-
client, err := getClient(ctx)
227-
if err != nil {
228-
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
229-
}
230-
231-
threadID, err := requiredParam[string](request, "threadID")
232-
if err != nil {
233-
return mcp.NewToolResultError(err.Error()), nil
234-
}
235176

236-
thread, resp, err := client.Activity.GetThread(ctx, threadID)
237-
if err != nil {
238-
return nil, fmt.Errorf("failed to get notification thread: %w", err)
239-
}
240-
defer func() { _ = resp.Body.Close() }()
177+
return mcp.NewToolResultText("Notification marked as read"), nil
241178

242-
if resp.StatusCode != http.StatusOK {
243-
body, err := io.ReadAll(resp.Body)
179+
case "mark_done":
180+
threadIDStr, err := requiredParam[string](request, "threadID")
244181
if err != nil {
245-
return nil, fmt.Errorf("failed to read response body: %w", err)
182+
return mcp.NewToolResultError(err.Error()), nil
246183
}
247-
return mcp.NewToolResultError(fmt.Sprintf("failed to get notification thread: %s", string(body))), nil
248-
}
249-
250-
r, err := json.Marshal(thread)
251-
if err != nil {
252-
return nil, fmt.Errorf("failed to marshal response: %w", err)
253-
}
254184

255-
return mcp.NewToolResultText(string(r)), nil
256-
}
257-
}
185+
threadID, err := strconv.ParseInt(threadIDStr, 10, 64)
186+
if err != nil {
187+
return mcp.NewToolResultError("Invalid threadID: must be a numeric value"), nil
188+
}
258189

259-
// markNotificationDone creates a tool to mark a notification as done.
260-
func MarkNotificationDone(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
261-
return mcp.NewTool("mark_notification_done",
262-
mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_DONE_DESCRIPTION", "Mark a notification as done")),
263-
mcp.WithString("threadID",
264-
mcp.Required(),
265-
mcp.Description("The ID of the notification thread"),
266-
),
267-
),
268-
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
269-
client, err := getclient(ctx)
270-
if err != nil {
271-
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
272-
}
190+
resp, err := client.Activity.MarkThreadDone(ctx, threadID)
191+
if err != nil {
192+
return nil, fmt.Errorf("failed to mark notification as done: %w", err)
193+
}
194+
defer func() { _ = resp.Body.Close() }()
195+
196+
if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
197+
body, err := io.ReadAll(resp.Body)
198+
if err != nil {
199+
return nil, fmt.Errorf("failed to read response body: %w", err)
200+
}
201+
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as done: %s", string(body))), nil
202+
}
273203

274-
threadIDStr, err := requiredParam[string](request, "threadID")
275-
if err != nil {
276-
return mcp.NewToolResultError(err.Error()), nil
277-
}
204+
return mcp.NewToolResultText("Notification marked as done"), nil
278205

279-
threadID, err := strconv.ParseInt(threadIDStr, 10, 64)
280-
if err != nil {
281-
return mcp.NewToolResultError("Invalid threadID: must be a numeric value"), nil
282-
}
206+
case "mark_all_read":
207+
lastReadAt, err := OptionalStringParam(request, "lastReadAt")
208+
if err != nil {
209+
return mcp.NewToolResultError(err.Error()), nil
210+
}
283211

284-
resp, err := client.Activity.MarkThreadDone(ctx, threadID)
285-
if err != nil {
286-
return nil, fmt.Errorf("failed to mark notification as done: %w", err)
287-
}
288-
defer func() { _ = resp.Body.Close() }()
212+
var markReadOptions github.Timestamp
213+
if lastReadAt != "" {
214+
lastReadTime, err := time.Parse(time.RFC3339, lastReadAt)
215+
if err != nil {
216+
return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil
217+
}
218+
markReadOptions = github.Timestamp{
219+
Time: lastReadTime,
220+
}
221+
}
289222

290-
if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
291-
body, err := io.ReadAll(resp.Body)
223+
resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions)
292224
if err != nil {
293-
return nil, fmt.Errorf("failed to read response body: %w", err)
225+
return nil, fmt.Errorf("failed to mark all notifications as read: %w", err)
294226
}
295-
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as done: %s", string(body))), nil
296-
}
227+
defer func() { _ = resp.Body.Close() }()
228+
229+
if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
230+
body, err := io.ReadAll(resp.Body)
231+
if err != nil {
232+
return nil, fmt.Errorf("failed to read response body: %w", err)
233+
}
234+
return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil
235+
}
236+
237+
return mcp.NewToolResultText("All notifications marked as read"), nil
297238

298-
return mcp.NewToolResultText("Notification marked as done"), nil
239+
default:
240+
return mcp.NewToolResultError("Invalid action: must be 'mark_read', 'mark_all_read', or 'mark_done'"), nil
241+
}
299242
}
300243
}

0 commit comments

Comments
 (0)