diff --git a/README.md b/README.md index 6ed566086..17beb1df3 100644 --- a/README.md +++ b/README.md @@ -867,6 +867,13 @@ The following sets of tools are available (all are on by default): - `repo`: Repository name (string, required) - `tag`: Tag name (e.g., 'v1.0.0') (string, required) +- **get_repository_tree** - Get repository tree + - `owner`: Repository owner (username or organization) (string, required) + - `path_filter`: Optional path prefix to filter the tree results (e.g., 'src/' to only show files in the src directory) (string, optional) + - `recursive`: Setting this parameter to true returns the objects or subtrees referenced by the tree. Default is false (boolean, optional) + - `repo`: Repository name (string, required) + - `tree_sha`: The SHA1 value or ref (branch or tag) name of the tree. Defaults to the repository's default branch (string, optional) + - **get_tag** - Get tag details - `owner`: Repository owner (string, required) - `repo`: Repository name (string, required) diff --git a/pkg/github/__toolsnaps__/get_repository_tree.snap b/pkg/github/__toolsnaps__/get_repository_tree.snap new file mode 100644 index 000000000..0645bf241 --- /dev/null +++ b/pkg/github/__toolsnaps__/get_repository_tree.snap @@ -0,0 +1,38 @@ +{ + "annotations": { + "title": "Get repository tree", + "readOnlyHint": true + }, + "description": "Get the tree structure (files and directories) of a GitHub repository at a specific ref or SHA", + "inputSchema": { + "properties": { + "owner": { + "description": "Repository owner (username or organization)", + "type": "string" + }, + "path_filter": { + "description": "Optional path prefix to filter the tree results (e.g., 'src/' to only show files in the src directory)", + "type": "string" + }, + "recursive": { + "default": false, + "description": "Setting this parameter to true returns the objects or subtrees referenced by the tree. Default is false", + "type": "boolean" + }, + "repo": { + "description": "Repository name", + "type": "string" + }, + "tree_sha": { + "description": "The SHA1 value or ref (branch or tag) name of the tree. Defaults to the repository's default branch", + "type": "string" + } + }, + "required": [ + "owner", + "repo" + ], + "type": "object" + }, + "name": "get_repository_tree" +} \ No newline at end of file diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 0622f3101..e16f0f2b9 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -677,6 +677,146 @@ func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t t } } +// GetRepositoryTree creates a tool to get the tree structure of a GitHub repository. +func GetRepositoryTree(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("get_repository_tree", + mcp.WithDescription(t("TOOL_GET_REPOSITORY_TREE_DESCRIPTION", "Get the tree structure (files and directories) of a GitHub repository at a specific ref or SHA")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_GET_REPOSITORY_TREE_USER_TITLE", "Get repository tree"), + ReadOnlyHint: ToBoolPtr(true), + }), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner (username or organization)"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithString("tree_sha", + mcp.Description("The SHA1 value or ref (branch or tag) name of the tree. Defaults to the repository's default branch"), + ), + mcp.WithBoolean("recursive", + mcp.Description("Setting this parameter to true returns the objects or subtrees referenced by the tree. Default is false"), + mcp.DefaultBool(false), + ), + mcp.WithString("path_filter", + mcp.Description("Optional path prefix to filter the tree results (e.g., 'src/' to only show files in the src directory)"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := RequiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := RequiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + treeSHA, err := OptionalParam[string](request, "tree_sha") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + recursive, err := OptionalBoolParamWithDefault(request, "recursive", false) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pathFilter, err := OptionalParam[string](request, "path_filter") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + client, err := getClient(ctx) + if err != nil { + return mcp.NewToolResultError("failed to get GitHub client"), nil + } + + // If no tree_sha is provided, use the repository's default branch + if treeSHA == "" { + repoInfo, _, err := client.Repositories.Get(ctx, owner, repo) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to get repository info: %s", err)), nil + } + treeSHA = *repoInfo.DefaultBranch + } + + // Get the tree using the GitHub Git Tree API + tree, resp, err := client.Git.GetTree(ctx, owner, repo, treeSHA, recursive) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get repository tree", + resp, + err, + ), nil + } + defer func() { _ = resp.Body.Close() }() + + // Filter tree entries if path_filter is provided + var filteredEntries []*github.TreeEntry + if pathFilter != "" { + for _, entry := range tree.Entries { + if strings.HasPrefix(entry.GetPath(), pathFilter) { + filteredEntries = append(filteredEntries, entry) + } + } + } else { + filteredEntries = tree.Entries + } + + type TreeEntryResponse struct { + Path string `json:"path"` + Type string `json:"type"` + Size *int `json:"size,omitempty"` + Mode string `json:"mode"` + SHA string `json:"sha"` + URL string `json:"url"` + } + + type TreeResponse struct { + SHA string `json:"sha"` + Truncated bool `json:"truncated"` + Tree []TreeEntryResponse `json:"tree"` + TreeSHA string `json:"tree_sha"` + Owner string `json:"owner"` + Repo string `json:"repo"` + Recursive bool `json:"recursive"` + Count int `json:"count"` + } + + treeEntries := make([]TreeEntryResponse, len(filteredEntries)) + for i, entry := range filteredEntries { + treeEntries[i] = TreeEntryResponse{ + Path: entry.GetPath(), + Type: entry.GetType(), + Mode: entry.GetMode(), + SHA: entry.GetSHA(), + URL: entry.GetURL(), + } + if entry.Size != nil { + treeEntries[i].Size = entry.Size + } + } + + response := TreeResponse{ + SHA: *tree.SHA, + Truncated: *tree.Truncated, + Tree: treeEntries, + TreeSHA: treeSHA, + Owner: owner, + Repo: repo, + Recursive: recursive, + Count: len(filteredEntries), + } + + r, err := json.Marshal(response) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + // ForkRepository creates a tool to fork a repository. func ForkRepository(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("fork_repository", diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index 11f11493c..1ab08031d 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -3192,3 +3192,178 @@ func Test_UnstarRepository(t *testing.T) { }) } } + +func Test_GetRepositoryTree(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := GetRepositoryTree(stubGetClientFn(mockClient), translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(tool.Name, tool)) + + assert.Equal(t, "get_repository_tree", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "tree_sha") + assert.Contains(t, tool.InputSchema.Properties, "recursive") + assert.Contains(t, tool.InputSchema.Properties, "path_filter") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"}) + + // Setup mock data + mockRepo := &github.Repository{ + DefaultBranch: github.Ptr("main"), + } + mockTree := &github.Tree{ + SHA: github.Ptr("abc123"), + Truncated: github.Ptr(false), + Entries: []*github.TreeEntry{ + { + Path: github.Ptr("README.md"), + Mode: github.Ptr("100644"), + Type: github.Ptr("blob"), + SHA: github.Ptr("file1sha"), + Size: github.Ptr(123), + URL: github.Ptr("https://api.github.com/repos/owner/repo/git/blobs/file1sha"), + }, + { + Path: github.Ptr("src/main.go"), + Mode: github.Ptr("100644"), + Type: github.Ptr("blob"), + SHA: github.Ptr("file2sha"), + Size: github.Ptr(456), + URL: github.Ptr("https://api.github.com/repos/owner/repo/git/blobs/file2sha"), + }, + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedErrMsg string + }{ + { + name: "successfully get repository tree", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposByOwnerByRepo, + mockResponse(t, http.StatusOK, mockRepo), + ), + mock.WithRequestMatchHandler( + mock.GetReposGitTreesByOwnerByRepoByTreeSha, + mockResponse(t, http.StatusOK, mockTree), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + }, + { + name: "successfully get repository tree with path filter", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposByOwnerByRepo, + mockResponse(t, http.StatusOK, mockRepo), + ), + mock.WithRequestMatchHandler( + mock.GetReposGitTreesByOwnerByRepoByTreeSha, + mockResponse(t, http.StatusOK, mockTree), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path_filter": "src/", + }, + }, + { + name: "repository not found", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "nonexistent", + }, + expectError: true, + expectedErrMsg: "failed to get repository info", + }, + { + name: "tree not found", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposByOwnerByRepo, + mockResponse(t, http.StatusOK, mockRepo), + ), + mock.WithRequestMatchHandler( + mock.GetReposGitTreesByOwnerByRepoByTreeSha, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + expectedErrMsg: "failed to get repository tree", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, handler := GetRepositoryTree(stubGetClientFromHTTPFn(tc.mockedClient), translations.NullTranslationHelper) + + // Create the tool request + request := createMCPRequest(tc.requestArgs) + + result, err := handler(context.Background(), request) + + if tc.expectError { + require.NoError(t, err) + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) + } else { + require.NoError(t, err) + require.False(t, result.IsError) + + // Parse the result and get the text content + textContent := getTextResult(t, result) + + // Parse the JSON response + var treeResponse map[string]interface{} + err := json.Unmarshal([]byte(textContent.Text), &treeResponse) + require.NoError(t, err) + + // Verify response structure + assert.Equal(t, "owner", treeResponse["owner"]) + assert.Equal(t, "repo", treeResponse["repo"]) + assert.Contains(t, treeResponse, "tree") + assert.Contains(t, treeResponse, "count") + assert.Contains(t, treeResponse, "sha") + assert.Contains(t, treeResponse, "truncated") + + // Check filtering if path_filter was provided + if pathFilter, exists := tc.requestArgs["path_filter"]; exists { + tree := treeResponse["tree"].([]interface{}) + for _, entry := range tree { + entryMap := entry.(map[string]interface{}) + path := entryMap["path"].(string) + assert.True(t, strings.HasPrefix(path, pathFilter.(string)), + "Path %s should start with filter %s", path, pathFilter) + } + } + } + }) + } +} diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 7fb5332aa..a170b4827 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -25,6 +25,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG AddReadTools( toolsets.NewServerTool(SearchRepositories(getClient, t)), toolsets.NewServerTool(GetFileContents(getClient, getRawClient, t)), + toolsets.NewServerTool(GetRepositoryTree(getClient, t)), toolsets.NewServerTool(ListCommits(getClient, t)), toolsets.NewServerTool(SearchCode(getClient, t)), toolsets.NewServerTool(GetCommit(getClient, t)),