diff --git a/README.md b/README.md index 891e63a81..a0f63e6eb 100644 --- a/README.md +++ b/README.md @@ -486,6 +486,9 @@ The following sets of tools are available (all are on by default): - `filename`: Filename for simple single-file gist creation (string, required) - `public`: Whether the gist is public (boolean, optional) +- **get_gist** - Get Gist Content + - `gist_id`: unique ID for the gist (string, required) + - **list_gists** - List Gists - `page`: Page number for pagination (min 1) (number, optional) - `perPage`: Results per page for pagination (min 1, max 100) (number, optional) diff --git a/pkg/github/gists.go b/pkg/github/gists.go index 53e85d5ba..d0831dcf2 100644 --- a/pkg/github/gists.go +++ b/pkg/github/gists.go @@ -89,6 +89,53 @@ func ListGists(getClient GetClientFn, t translations.TranslationHelperFunc) (too } } +// GetGist creates a tool to get the content of a gist +func GetGist(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_gist", + mcp.WithDescription(t("TOOL_GET_GIST_DESCRIPTION", "Get gist content of a particular gist ID")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_GET_GIST", "Get Gist Content"), + ReadOnlyHint: ToBoolPtr(true), + }), + mcp.WithString("gist_id", + mcp.Required(), + mcp.Description("Gist ID of a particular gist"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + gistID, err := RequiredParam[string](request, "gist_id") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + gist, resp, err := client.Gists.Get(ctx, gistID) + if err != nil { + return nil, fmt.Errorf("failed to get gist: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get gist: %s", string(body))), nil + } + + r, err := json.Marshal(gist) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + // CreateGist creates a tool to create a new gist func CreateGist(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("create_gist", diff --git a/pkg/github/gists_test.go b/pkg/github/gists_test.go index 9b8b4eb6e..e4c1e26a9 100644 --- a/pkg/github/gists_test.go +++ b/pkg/github/gists_test.go @@ -192,6 +192,115 @@ func Test_ListGists(t *testing.T) { } } +func Test_GetGist(t *testing.T) { + // Verify tool definition + mockClient := github.NewClient(nil) + tool, _ := GetGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "get_gist", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "gist_id") + + assert.Contains(t, tool.InputSchema.Required, "gist_id") + + // Setup mock gist for success case + mockGist := github.Gist{ + ID: github.Ptr("gist1"), + Description: github.Ptr("First Gist"), + HTMLURL: github.Ptr("https://gist.github.com/user/gist1"), + Public: github.Ptr(true), + CreatedAt: &github.Timestamp{Time: time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)}, + Owner: &github.User{Login: github.Ptr("user")}, + Files: map[github.GistFilename]github.GistFile{ + github.GistFilename("file1.txt"): { + Filename: github.Ptr("file1.txt"), + Content: github.Ptr("content of file 1"), + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedGists github.Gist + expectedErrMsg string + }{ + { + name: "Successful fetching different gist", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetGistsByGistId, + mockResponse(t, http.StatusOK, mockGist), + ), + ), + requestArgs: map[string]interface{}{ + "gist_id": "gist1", + }, + expectError: false, + expectedGists: mockGist, + }, + { + name: "gist_id parameter missing", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetGistsByGistId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Invalid Request"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{}, + expectError: true, + expectedErrMsg: "missing required parameter: gist_id", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := GetGist(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + if err != nil { + assert.Contains(t, err.Error(), tc.expectedErrMsg) + } else { + // For errors returned as part of the result, not as an error + assert.NotNil(t, result) + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + } + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedGists github.Gist + err = json.Unmarshal([]byte(textContent.Text), &returnedGists) + require.NoError(t, err) + + assert.Equal(t, *tc.expectedGists.ID, *returnedGists.ID) + assert.Equal(t, *tc.expectedGists.Description, *returnedGists.Description) + assert.Equal(t, *tc.expectedGists.HTMLURL, *returnedGists.HTMLURL) + assert.Equal(t, *tc.expectedGists.Public, *returnedGists.Public) + }) + } +} + func Test_CreateGist(t *testing.T) { // Verify tool definition mockClient := github.NewClient(nil) diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 0f294cef6..84c2ed9a8 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -184,6 +184,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG gists := toolsets.NewToolset("gists", "GitHub Gist related tools"). AddReadTools( toolsets.NewServerTool(ListGists(getClient, t)), + toolsets.NewServerTool(GetGist(getClient, t)), ). AddWriteTools( toolsets.NewServerTool(CreateGist(getClient, t)),