diff --git a/go.mod b/go.mod index 661778fc3..686558535 100644 --- a/go.mod +++ b/go.mod @@ -56,3 +56,5 @@ require ( gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/modelcontextprotocol/go-sdk => github.com/SamMorrowDrums/go-sdk v0.0.0-20251204132411-f66cde03f0bc diff --git a/go.sum b/go.sum index e422a548c..875ce37c2 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/SamMorrowDrums/go-sdk v0.0.0-20251204132411-f66cde03f0bc h1:GbuI2fLul69iqi2/f/OhWBiWXmZkP3R7h+ijwtZnqzY= +github.com/SamMorrowDrums/go-sdk v0.0.0-20251204132411-f66cde03f0bc/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= @@ -55,8 +57,6 @@ github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwX github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA= github.com/migueleliasweb/go-github-mock v1.3.0 h1:2sVP9JEMB2ubQw1IKto3/fzF51oFC6eVWOOFDgQoq88= github.com/migueleliasweb/go-github-mock v1.3.0/go.mod h1:ipQhV8fTcj/G6m7BKzin08GaJ/3B5/SonRAkgrk0zCY= -github.com/modelcontextprotocol/go-sdk v1.1.0 h1:Qjayg53dnKC4UZ+792W21e4BpwEZBzwgRW6LrjLWSwA= -github.com/modelcontextprotocol/go-sdk v1.1.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/muesli/cache2go v0.0.0-20221011235721-518229cd8021 h1:31Y+Yu373ymebRdJN1cWLLooHH8xAr0MhKTEJGV/87g= github.com/muesli/cache2go v0.0.0-20221011235721-518229cd8021/go.mod h1:WERUkUryfUWlrHnFSO/BEUZ+7Ns8aZy7iVOGewxKzcc= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go index 9c55ba841..60a7d1a00 100644 --- a/pkg/github/helper_test.go +++ b/pkg/github/helper_test.go @@ -10,6 +10,18 @@ import ( "github.com/stretchr/testify/require" ) +// mapToTypedInput converts a map[string]interface{} to a typed struct using JSON marshaling. +// This is useful for tests that need to pass typed input to handlers. +func mapToTypedInput[T any](t *testing.T, m map[string]interface{}) T { + t.Helper() + var result T + jsonBytes, err := json.Marshal(m) + require.NoError(t, err, "failed to marshal map to JSON") + err = json.Unmarshal(jsonBytes, &result) + require.NoError(t, err, "failed to unmarshal JSON to typed input") + return result +} + type expectations struct { path string queryParams map[string]string diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 46111a4d6..8caac212c 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -229,38 +229,7 @@ func fragmentToIssue(fragment IssueFragment) *github.Issue { } // IssueRead creates a tool to get details of a specific issue in a GitHub repository. -func IssueRead(getClient GetClientFn, getGQLClient GetGQLClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - schema := &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "method": { - Type: "string", - Description: `The read operation to perform on a single issue. -Options are: -1. get - Get details of a specific issue. -2. get_comments - Get issue comments. -3. get_sub_issues - Get sub-issues of the issue. -4. get_labels - Get labels assigned to the issue. -`, - Enum: []any{"get", "get_comments", "get_sub_issues", "get_labels"}, - }, - "owner": { - Type: "string", - Description: "The owner of the repository", - }, - "repo": { - Type: "string", - Description: "The name of the repository", - }, - "issue_number": { - Type: "number", - Description: "The number of the issue", - }, - }, - Required: []string{"method", "owner", "repo", "issue_number"}, - } - WithPagination(schema) - +func IssueRead(getClient GetClientFn, getGQLClient GetGQLClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, mcp.ToolHandlerFor[IssueReadInput, any]) { return mcp.Tool{ Name: "issue_read", Description: t("TOOL_ISSUE_READ_DESCRIPTION", "Get information about a specific issue in a GitHub repository."), @@ -268,30 +237,21 @@ Options are: Title: t("TOOL_ISSUE_READ_USER_TITLE", "Get issue details"), ReadOnlyHint: true, }, - InputSchema: schema, + InputSchema: IssueReadInput{}.MCPSchema(), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - method, err := RequiredParam[string](args, "method") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + func(ctx context.Context, _ *mcp.CallToolRequest, input IssueReadInput) (*mcp.CallToolResult, any, error) { + // Set pagination defaults + page := input.Page + if page == 0 { + page = 1 } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + perPage := input.PerPage + if perPage == 0 { + perPage = 30 } - issueNumber, err := RequiredInt(args, "issue_number") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + pagination := PaginationParams{ + Page: page, + PerPage: perPage, } client, err := getClient(ctx) @@ -304,21 +264,21 @@ Options are: return utils.NewToolResultErrorFromErr("failed to get GitHub graphql client", err), nil, nil } - switch method { + switch input.Method { case "get": - result, err := GetIssue(ctx, client, cache, owner, repo, issueNumber, flags) + result, err := GetIssue(ctx, client, cache, input.Owner, input.Repo, input.IssueNumber, flags) return result, nil, err case "get_comments": - result, err := GetIssueComments(ctx, client, cache, owner, repo, issueNumber, pagination, flags) + result, err := GetIssueComments(ctx, client, cache, input.Owner, input.Repo, input.IssueNumber, pagination, flags) return result, nil, err case "get_sub_issues": - result, err := GetSubIssues(ctx, client, cache, owner, repo, issueNumber, pagination, flags) + result, err := GetSubIssues(ctx, client, cache, input.Owner, input.Repo, input.IssueNumber, pagination, flags) return result, nil, err case "get_labels": - result, err := GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber) + result, err := GetIssueLabels(ctx, gqlClient, input.Owner, input.Repo, input.IssueNumber) return result, nil, err default: - return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil + return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", input.Method)), nil, nil } } } @@ -1313,49 +1273,7 @@ func UpdateIssue(ctx context.Context, client *github.Client, gqlClient *githubv4 } // ListIssues creates a tool to list and filter repository issues -func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - schema := &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "state": { - Type: "string", - Description: "Filter by state, by default both open and closed issues are returned when not provided", - Enum: []any{"OPEN", "CLOSED"}, - }, - "labels": { - Type: "array", - Description: "Filter by labels", - Items: &jsonschema.Schema{ - Type: "string", - }, - }, - "orderBy": { - Type: "string", - Description: "Order issues by field. If provided, the 'direction' also needs to be provided.", - Enum: []any{"CREATED_AT", "UPDATED_AT", "COMMENTS"}, - }, - "direction": { - Type: "string", - Description: "Order direction. If provided, the 'orderBy' also needs to be provided.", - Enum: []any{"ASC", "DESC"}, - }, - "since": { - Type: "string", - Description: "Filter by date (ISO 8601 timestamp)", - }, - }, - Required: []string{"owner", "repo"}, - } - WithCursorPagination(schema) - +func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[ListIssuesInput, any]) { return mcp.Tool{ Name: "list_issues", Description: t("TOOL_LIST_ISSUES_DESCRIPTION", "List issues in a GitHub repository. For pagination, use the 'endCursor' from the previous response's 'pageInfo' in the 'after' parameter."), @@ -1363,47 +1281,20 @@ func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFun Title: t("TOOL_LIST_ISSUES_USER_TITLE", "List issues"), ReadOnlyHint: true, }, - InputSchema: schema, + InputSchema: ListIssuesInput{}.MCPSchema(), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - // Set optional parameters if provided - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - + func(ctx context.Context, _ *mcp.CallToolRequest, input ListIssuesInput) (*mcp.CallToolResult, any, error) { // If the state has a value, cast into an array of strings var states []githubv4.IssueState - if state != "" { - states = append(states, githubv4.IssueState(state)) + if input.State != "" { + states = append(states, githubv4.IssueState(input.State)) } else { states = []githubv4.IssueState{githubv4.IssueStateOpen, githubv4.IssueStateClosed} } - // Get labels - labels, err := OptionalStringArrayParam(args, "labels") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - orderBy, err := OptionalParam[string](args, "orderBy") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - direction, err := OptionalParam[string](args, "direction") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + labels := input.Labels + orderBy := input.OrderBy + direction := input.Direction // These variables are required for the GraphQL query to be set by default // If orderBy is empty, default to CREATED_AT @@ -1415,16 +1306,12 @@ func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFun direction = "DESC" } - since, err := OptionalParam[string](args, "since") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - // There are two optional parameters: since and labels. var sinceTime time.Time var hasSince bool - if since != "" { - sinceTime, err = parseISOTimestamp(since) + if input.Since != "" { + var err error + sinceTime, err = parseISOTimestamp(input.Since) if err != nil { return utils.NewToolResultError(fmt.Sprintf("failed to list issues: %s", err.Error())), nil, nil } @@ -1433,39 +1320,28 @@ func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFun hasLabels := len(labels) > 0 // Get pagination parameters and convert to GraphQL format - pagination, err := OptionalCursorPaginationParams(args) - if err != nil { - return nil, nil, err + perPage := input.PerPage + if perPage == 0 { + perPage = 30 } - - // Check if someone tried to use page-based pagination instead of cursor-based - if _, pageProvided := args["page"]; pageProvided { - return utils.NewToolResultError("This tool uses cursor-based pagination. Use the 'after' parameter with the 'endCursor' value from the previous response instead of 'page'."), nil, nil + pagination := CursorPaginationParams{ + PerPage: perPage, + After: input.After, } - // Check if pagination parameters were explicitly provided - _, perPageProvided := args["perPage"] - paginationExplicit := perPageProvided - paginationParams, err := pagination.ToGraphQLParams() if err != nil { return nil, nil, err } - // Use default of 30 if pagination was not explicitly provided - if !paginationExplicit { - defaultFirst := int32(DefaultGraphQLPageSize) - paginationParams.First = &defaultFirst - } - client, err := getGQLClient(ctx) if err != nil { return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil } vars := map[string]interface{}{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), + "owner": githubv4.String(input.Owner), + "repo": githubv4.String(input.Repo), "states": states, "orderBy": githubv4.IssueOrderField(orderBy), "direction": githubv4.OrderDirection(direction), diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index 48901ccdc..a268d6aef 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -334,7 +334,8 @@ func Test_GetIssue(t *testing.T) { _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), cache, translations.NullTranslationHelper, flags) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + typedInput := mapToTypedInput[IssueReadInput](t, tc.requestArgs) + result, _, err := handler(context.Background(), &request, typedInput) if tc.expectHandlerError { require.Error(t, err) @@ -1244,7 +1245,8 @@ func Test_ListIssues(t *testing.T) { _, handler := ListIssues(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) req := createMCPRequest(tc.reqParams) - res, _, err := handler(context.Background(), &req, tc.reqParams) + typedInput := mapToTypedInput[ListIssuesInput](t, tc.reqParams) + res, _, err := handler(context.Background(), &req, typedInput) text := getTextResult(t, res).Text if tc.expectError { @@ -1988,9 +1990,10 @@ func Test_GetIssueComments(t *testing.T) { // Create call request request := createMCPRequest(tc.requestArgs) + typedInput := mapToTypedInput[IssueReadInput](t, tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, _, err := handler(context.Background(), &request, typedInput) // Verify results if tc.expectError { @@ -2102,7 +2105,8 @@ func Test_GetIssueLabels(t *testing.T) { _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + typedInput := mapToTypedInput[IssueReadInput](t, tc.requestArgs) + result, _, err := handler(context.Background(), &request, typedInput) require.NoError(t, err) assert.NotNil(t, result) @@ -2991,9 +2995,10 @@ func Test_GetSubIssues(t *testing.T) { // Create call request request := createMCPRequest(tc.requestArgs) + typedInput := mapToTypedInput[IssueReadInput](t, tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, _, err := handler(context.Background(), &request, typedInput) // Verify results if tc.expectError { diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 661384529..1bfaa78dc 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -21,41 +21,7 @@ import ( ) // PullRequestRead creates a tool to get details of a specific pull request. -func PullRequestRead(getClient GetClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - schema := &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "method": { - Type: "string", - Description: `Action to specify what pull request data needs to be retrieved from GitHub. -Possible options: - 1. get - Get details of a specific pull request. - 2. get_diff - Get the diff of a pull request. - 3. get_status - Get status of a head commit in a pull request. This reflects status of builds and checks. - 4. get_files - Get the list of files changed in a pull request. Use with pagination parameters to control the number of results returned. - 5. get_review_comments - Get the review comments on a pull request. They are comments made on a portion of the unified diff during a pull request review. Use with pagination parameters to control the number of results returned. - 6. get_reviews - Get the reviews on a pull request. When asked for review comments, use get_review_comments method. - 7. get_comments - Get comments on a pull request. Use this if user doesn't specifically want review comments. Use with pagination parameters to control the number of results returned. -`, - Enum: []any{"get", "get_diff", "get_status", "get_files", "get_review_comments", "get_reviews", "get_comments"}, - }, - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "pullNumber": { - Type: "number", - Description: "Pull request number", - }, - }, - Required: []string{"method", "owner", "repo", "pullNumber"}, - } - WithPagination(schema) - +func PullRequestRead(getClient GetClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, mcp.ToolHandlerFor[PullRequestReadInput, any]) { return mcp.Tool{ Name: "pull_request_read", Description: t("TOOL_PULL_REQUEST_READ_DESCRIPTION", "Get information on a specific pull request in GitHub repository."), @@ -63,29 +29,21 @@ Possible options: Title: t("TOOL_GET_PULL_REQUEST_USER_TITLE", "Get details for a single pull request"), ReadOnlyHint: true, }, - InputSchema: schema, + InputSchema: PullRequestReadInput{}.MCPSchema(), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - method, err := RequiredParam[string](args, "method") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + func(ctx context.Context, _ *mcp.CallToolRequest, input PullRequestReadInput) (*mcp.CallToolResult, any, error) { + // Set pagination defaults + page := input.Page + if page == 0 { + page = 1 } - pullNumber, err := RequiredInt(args, "pullNumber") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + perPage := input.PerPage + if perPage == 0 { + perPage = 30 } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + pagination := PaginationParams{ + Page: page, + PerPage: perPage, } client, err := getClient(ctx) @@ -93,30 +51,30 @@ Possible options: return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - switch method { + switch input.Method { case "get": - result, err := GetPullRequest(ctx, client, cache, owner, repo, pullNumber, flags) + result, err := GetPullRequest(ctx, client, cache, input.Owner, input.Repo, input.PullNumber, flags) return result, nil, err case "get_diff": - result, err := GetPullRequestDiff(ctx, client, owner, repo, pullNumber) + result, err := GetPullRequestDiff(ctx, client, input.Owner, input.Repo, input.PullNumber) return result, nil, err case "get_status": - result, err := GetPullRequestStatus(ctx, client, owner, repo, pullNumber) + result, err := GetPullRequestStatus(ctx, client, input.Owner, input.Repo, input.PullNumber) return result, nil, err case "get_files": - result, err := GetPullRequestFiles(ctx, client, owner, repo, pullNumber, pagination) + result, err := GetPullRequestFiles(ctx, client, input.Owner, input.Repo, input.PullNumber, pagination) return result, nil, err case "get_review_comments": - result, err := GetPullRequestReviewComments(ctx, client, cache, owner, repo, pullNumber, pagination, flags) + result, err := GetPullRequestReviewComments(ctx, client, cache, input.Owner, input.Repo, input.PullNumber, pagination, flags) return result, nil, err case "get_reviews": - result, err := GetPullRequestReviews(ctx, client, cache, owner, repo, pullNumber, flags) + result, err := GetPullRequestReviews(ctx, client, cache, input.Owner, input.Repo, input.PullNumber, flags) return result, nil, err case "get_comments": - result, err := GetIssueComments(ctx, client, cache, owner, repo, pullNumber, pagination, flags) + result, err := GetIssueComments(ctx, client, cache, input.Owner, input.Repo, input.PullNumber, pagination, flags) return result, nil, err default: - return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil + return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", input.Method)), nil, nil } } } @@ -813,46 +771,7 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra } // ListPullRequests creates a tool to list and filter repository pull requests. -func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - schema := &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "state": { - Type: "string", - Description: "Filter by state", - Enum: []any{"open", "closed", "all"}, - }, - "head": { - Type: "string", - Description: "Filter by head user/org and branch", - }, - "base": { - Type: "string", - Description: "Filter by base branch", - }, - "sort": { - Type: "string", - Description: "Sort by", - Enum: []any{"created", "updated", "popularity", "long-running"}, - }, - "direction": { - Type: "string", - Description: "Sort direction", - Enum: []any{"asc", "desc"}, - }, - }, - Required: []string{"owner", "repo"}, - } - WithPagination(schema) - +func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[ListPullRequestsInput, any]) { return mcp.Tool{ Name: "list_pull_requests", Description: t("TOOL_LIST_PULL_REQUESTS_DESCRIPTION", "List pull requests in a GitHub repository. If the user specifies an author, then DO NOT use this tool and use the search_pull_requests tool instead."), @@ -860,51 +779,28 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun Title: t("TOOL_LIST_PULL_REQUESTS_USER_TITLE", "List pull requests"), ReadOnlyHint: true, }, - InputSchema: schema, + InputSchema: ListPullRequestsInput{}.MCPSchema(), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - head, err := OptionalParam[string](args, "head") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - base, err := OptionalParam[string](args, "base") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - direction, err := OptionalParam[string](args, "direction") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + func(ctx context.Context, _ *mcp.CallToolRequest, input ListPullRequestsInput) (*mcp.CallToolResult, any, error) { + // Set pagination defaults + page := input.Page + if page == 0 { + page = 1 } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + perPage := input.PerPage + if perPage == 0 { + perPage = 30 } opts := &github.PullRequestListOptions{ - State: state, - Head: head, - Base: base, - Sort: sort, - Direction: direction, + State: input.State, + Head: input.Head, + Base: input.Base, + Sort: input.Sort, + Direction: input.Direction, ListOptions: github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, + PerPage: perPage, + Page: page, }, } @@ -912,7 +808,7 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts) + prs, resp, err := client.PullRequests.List(ctx, input.Owner, input.Repo, opts) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list pull requests", diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 94313d4e3..4374e7361 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -27,7 +27,7 @@ func Test_GetPullRequest(t *testing.T) { assert.Equal(t, "pull_request_read", tool.Name) assert.NotEmpty(t, tool.Description) - schema := tool.InputSchema.(*jsonschema.Schema) + schema := tool.InputSchema assert.Contains(t, schema.Properties, "method") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -109,6 +109,7 @@ func Test_GetPullRequest(t *testing.T) { // Create call request request := createMCPRequest(tc.requestArgs) + typedInput := mapToTypedInput[PullRequestReadInput](t, tc.requestArgs) // Call handler result, _, err := handler(context.Background(), &request, tc.requestArgs) diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index dbf24e8e3..eca2e403d 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -18,7 +18,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[GetCommitInput, any]) { tool := mcp.Tool{ Name: "get_commit", Description: t("TOOL_GET_COMMITS_DESCRIPTION", "Get details for a commit from a GitHub repository"), @@ -26,66 +26,39 @@ func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp Title: t("TOOL_GET_COMMITS_USER_TITLE", "Get commit details"), ReadOnlyHint: true, }, - InputSchema: WithPagination(&jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "sha": { - Type: "string", - Description: "Commit SHA, branch name, or tag name", - }, - "include_diff": { - Type: "boolean", - Description: "Whether to include file diffs and stats in the response. Default is true.", - Default: json.RawMessage(`true`), - }, - }, - Required: []string{"owner", "repo", "sha"}, - }), + InputSchema: GetCommitInput{}.MCPSchema(), } - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + handler := mcp.ToolHandlerFor[GetCommitInput, any](func(ctx context.Context, _ *mcp.CallToolRequest, input GetCommitInput) (*mcp.CallToolResult, any, error) { + // Set pagination defaults + page := input.Page + if page == 0 { + page = 1 } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sha, err := RequiredParam[string](args, "sha") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - includeDiff, err := OptionalBoolParamWithDefault(args, "include_diff", true) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + perPage := input.PerPage + if perPage == 0 { + perPage = 30 } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + + // Get includeDiff with default value of true + includeDiff := true + if input.IncludeDiff != nil { + includeDiff = *input.IncludeDiff } opts := &github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, + Page: page, + PerPage: perPage, } client, err := getClient(ctx) if err != nil { return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) } - commit, resp, err := client.Repositories.GetCommit(ctx, owner, repo, sha, opts) + commit, resp, err := client.Repositories.GetCommit(ctx, input.Owner, input.Repo, input.SHA, opts) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to get commit: %s", sha), + fmt.Sprintf("failed to get commit: %s", input.SHA), resp, err, ), nil, nil @@ -115,7 +88,7 @@ func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp } // ListCommits creates a tool to get commits of a branch in a repository. -func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[ListCommitsInput, any]) { tool := mcp.Tool{ Name: "list_commits", Description: t("TOOL_LIST_COMMITS_DESCRIPTION", "Get list of commits of a branch in a GitHub repository. Returns at least 30 results per page by default, but can return more if specified using the perPage parameter (up to 100)."), @@ -123,61 +96,25 @@ func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (m Title: t("TOOL_LIST_COMMITS_USER_TITLE", "List commits"), ReadOnlyHint: true, }, - InputSchema: WithPagination(&jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "sha": { - Type: "string", - Description: "Commit SHA, branch or tag name to list commits of. If not provided, uses the default branch of the repository. If a commit SHA is provided, will list commits up to that SHA.", - }, - "author": { - Type: "string", - Description: "Author username or email address to filter commits by", - }, - }, - Required: []string{"owner", "repo"}, - }), + InputSchema: ListCommitsInput{}.MCPSchema(), } - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sha, err := OptionalParam[string](args, "sha") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - author, err := OptionalParam[string](args, "author") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + handler := mcp.ToolHandlerFor[ListCommitsInput, any](func(ctx context.Context, _ *mcp.CallToolRequest, input ListCommitsInput) (*mcp.CallToolResult, any, error) { + // Set pagination defaults + page := input.Page + if page == 0 { + page = 1 } // Set default perPage to 30 if not provided - perPage := pagination.PerPage + perPage := input.PerPage if perPage == 0 { perPage = 30 } opts := &github.CommitsListOptions{ - SHA: sha, - Author: author, + SHA: input.SHA, + Author: input.Author, ListOptions: github.ListOptions{ - Page: pagination.Page, + Page: page, PerPage: perPage, }, } @@ -186,10 +123,10 @@ func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (m if err != nil { return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) } - commits, resp, err := client.Repositories.ListCommits(ctx, owner, repo, opts) + commits, resp, err := client.Repositories.ListCommits(ctx, input.Owner, input.Repo, opts) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to list commits: %s", sha), + fmt.Sprintf("failed to list commits: %s", input.SHA), resp, err, ), nil, nil @@ -537,7 +474,7 @@ func CreateRepository(getClient GetClientFn, t translations.TranslationHelperFun } // GetFileContents creates a tool to get the contents of a file or directory from a GitHub repository. -func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[GetFileContentsInput, any]) { tool := mcp.Tool{ Name: "get_file_contents", Description: t("TOOL_GET_FILE_CONTENTS_DESCRIPTION", "Get the contents of a file or directory from a GitHub repository"), @@ -545,55 +482,19 @@ func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t t Title: t("TOOL_GET_FILE_CONTENTS_USER_TITLE", "Get file or directory contents"), ReadOnlyHint: true, }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner (username or organization)", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "path": { - Type: "string", - Description: "Path to file/directory (directories must end with a slash '/')", - Default: json.RawMessage(`"/"`), - }, - "ref": { - Type: "string", - Description: "Accepts optional git refs such as `refs/tags/{tag}`, `refs/heads/{branch}` or `refs/pull/{pr_number}/head`", - }, - "sha": { - Type: "string", - Description: "Accepts optional commit SHA. If specified, it will be used instead of ref", - }, - }, - Required: []string{"owner", "repo"}, - }, + InputSchema: GetFileContentsInput{}.MCPSchema(), } - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - path, err := RequiredParam[string](args, "path") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - ref, err := OptionalParam[string](args, "ref") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sha, err := OptionalParam[string](args, "sha") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + handler := mcp.ToolHandlerFor[GetFileContentsInput, any](func(ctx context.Context, _ *mcp.CallToolRequest, input GetFileContentsInput) (*mcp.CallToolResult, any, error) { + owner := input.Owner + repo := input.Repo + path := input.Path + ref := input.Ref + sha := input.SHA + + // Apply default for path if empty + if path == "" { + path = "/" } client, err := getClient(ctx) diff --git a/pkg/github/schema_providers.go b/pkg/github/schema_providers.go new file mode 100644 index 000000000..d2ffe46f9 --- /dev/null +++ b/pkg/github/schema_providers.go @@ -0,0 +1,690 @@ +package github + +import ( + "encoding/json" + "sync" + + "github.com/google/jsonschema-go/jsonschema" +) + +// This file contains typed input structs that implement SchemaProvider and ResolvedSchemaProvider +// interfaces for high-traffic MCP tools. This provides maximum performance by: +// 1. Avoiding reflection for schema generation +// 2. Pre-resolving schemas to skip the resolution step entirely +// +// Each input struct provides: +// - MCPSchema() - returns the pre-computed JSON schema +// - MCPResolvedSchema() - returns the pre-resolved schema ready for validation + +// schemaCache provides thread-safe lazy initialization of resolved schemas +type schemaCache struct { + once sync.Once + resolved *jsonschema.Resolved +} + +func (c *schemaCache) get(schema *jsonschema.Schema) *jsonschema.Resolved { + c.once.Do(func() { + var err error + c.resolved, err = schema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) + if err != nil { + // This should never happen with well-formed schemas + panic("failed to resolve schema: " + err.Error()) + } + }) + return c.resolved +} + +// ============================================================================ +// SearchRepositoriesInput - for search_repositories tool +// ============================================================================ + +// SearchRepositoriesInput is the typed input for the search_repositories tool. +type SearchRepositoriesInput struct { + Query string `json:"query"` + Sort string `json:"sort,omitempty"` + Order string `json:"order,omitempty"` + MinimalOutput *bool `json:"minimal_output,omitempty"` + Page int `json:"page,omitempty"` + PerPage int `json:"perPage,omitempty"` +} + +var ( + searchRepositoriesSchema = &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "query": { + Type: "string", + Description: "Repository search query. Examples: 'machine learning in:name stars:>1000 language:python', 'topic:react', 'user:facebook'. Supports advanced search syntax for precise filtering.", + }, + "sort": { + Type: "string", + Description: "Sort repositories by field, defaults to best match", + Enum: []any{"stars", "forks", "help-wanted-issues", "updated"}, + }, + "order": { + Type: "string", + Description: "Sort order", + Enum: []any{"asc", "desc"}, + }, + "minimal_output": { + Type: "boolean", + Description: "Return minimal repository information (default: true). When false, returns full GitHub API repository objects.", + Default: json.RawMessage(`true`), + }, + "page": { + Type: "number", + Description: "Page number for pagination (min 1)", + Minimum: jsonschema.Ptr(1.0), + }, + "perPage": { + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + }, + }, + Required: []string{"query"}, + } + searchRepositoriesResolvedCache schemaCache +) + +// MCPSchema returns the JSON schema for SearchRepositoriesInput. +func (SearchRepositoriesInput) MCPSchema() *jsonschema.Schema { + return searchRepositoriesSchema +} + +// MCPResolvedSchema returns the pre-resolved schema for SearchRepositoriesInput. +func (SearchRepositoriesInput) MCPResolvedSchema() *jsonschema.Resolved { + return searchRepositoriesResolvedCache.get(searchRepositoriesSchema) +} + +// ============================================================================ +// SearchCodeInput - for search_code tool +// ============================================================================ + +// SearchCodeInput is the typed input for the search_code tool. +type SearchCodeInput struct { + Query string `json:"query"` + Sort string `json:"sort,omitempty"` + Order string `json:"order,omitempty"` + Page int `json:"page,omitempty"` + PerPage int `json:"perPage,omitempty"` +} + +var ( + searchCodeSchema = &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "query": { + Type: "string", + Description: "Search query using GitHub's powerful code search syntax. Examples: 'content:Skill language:Java org:github', 'NOT is:archived language:Python OR language:go', 'repo:github/github-mcp-server'. Supports exact matching, language filters, path filters, and more.", + }, + "sort": { + Type: "string", + Description: "Sort field ('indexed' only)", + }, + "order": { + Type: "string", + Description: "Sort order for results", + Enum: []any{"asc", "desc"}, + }, + "page": { + Type: "number", + Description: "Page number for pagination (min 1)", + Minimum: jsonschema.Ptr(1.0), + }, + "perPage": { + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + }, + }, + Required: []string{"query"}, + } + searchCodeResolvedCache schemaCache +) + +// MCPSchema returns the JSON schema for SearchCodeInput. +func (SearchCodeInput) MCPSchema() *jsonschema.Schema { + return searchCodeSchema +} + +// MCPResolvedSchema returns the pre-resolved schema for SearchCodeInput. +func (SearchCodeInput) MCPResolvedSchema() *jsonschema.Resolved { + return searchCodeResolvedCache.get(searchCodeSchema) +} + +// ============================================================================ +// SearchUsersInput - for search_users tool +// ============================================================================ + +// SearchUsersInput is the typed input for the search_users tool. +type SearchUsersInput struct { + Query string `json:"query"` + Sort string `json:"sort,omitempty"` + Order string `json:"order,omitempty"` + Page int `json:"page,omitempty"` + PerPage int `json:"perPage,omitempty"` +} + +var ( + searchUsersSchema = &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "query": { + Type: "string", + Description: "User search query. Examples: 'john smith', 'location:seattle', 'followers:>100'. Search is automatically scoped to type:user.", + }, + "sort": { + Type: "string", + Description: "Sort users by number of followers or repositories, or when the person joined GitHub.", + Enum: []any{"followers", "repositories", "joined"}, + }, + "order": { + Type: "string", + Description: "Sort order", + Enum: []any{"asc", "desc"}, + }, + "page": { + Type: "number", + Description: "Page number for pagination (min 1)", + Minimum: jsonschema.Ptr(1.0), + }, + "perPage": { + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + }, + }, + Required: []string{"query"}, + } + searchUsersResolvedCache schemaCache +) + +// MCPSchema returns the JSON schema for SearchUsersInput. +func (SearchUsersInput) MCPSchema() *jsonschema.Schema { + return searchUsersSchema +} + +// MCPResolvedSchema returns the pre-resolved schema for SearchUsersInput. +func (SearchUsersInput) MCPResolvedSchema() *jsonschema.Resolved { + return searchUsersResolvedCache.get(searchUsersSchema) +} + +// ============================================================================ +// GetFileContentsInput - for get_file_contents tool +// ============================================================================ + +// GetFileContentsInput is the typed input for the get_file_contents tool. +type GetFileContentsInput struct { + Owner string `json:"owner"` + Repo string `json:"repo"` + Path string `json:"path,omitempty"` + Ref string `json:"ref,omitempty"` + SHA string `json:"sha,omitempty"` +} + +var ( + getFileContentsSchema = &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner (username or organization)", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "path": { + Type: "string", + Description: "Path to file/directory (directories must end with a slash '/')", + Default: json.RawMessage(`"/"`), + }, + "ref": { + Type: "string", + Description: "Accepts optional git refs such as `refs/tags/{tag}`, `refs/heads/{branch}` or `refs/pull/{pr_number}/head`", + }, + "sha": { + Type: "string", + Description: "Accepts optional commit SHA. If specified, it will be used instead of ref", + }, + }, + Required: []string{"owner", "repo"}, + } + getFileContentsResolvedCache schemaCache +) + +// MCPSchema returns the JSON schema for GetFileContentsInput. +func (GetFileContentsInput) MCPSchema() *jsonschema.Schema { + return getFileContentsSchema +} + +// MCPResolvedSchema returns the pre-resolved schema for GetFileContentsInput. +func (GetFileContentsInput) MCPResolvedSchema() *jsonschema.Resolved { + return getFileContentsResolvedCache.get(getFileContentsSchema) +} + +// ============================================================================ +// ListCommitsInput - for list_commits tool +// ============================================================================ + +// ListCommitsInput is the typed input for the list_commits tool. +type ListCommitsInput struct { + Owner string `json:"owner"` + Repo string `json:"repo"` + SHA string `json:"sha,omitempty"` + Author string `json:"author,omitempty"` + Page int `json:"page,omitempty"` + PerPage int `json:"perPage,omitempty"` +} + +var ( + listCommitsSchema = &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "sha": { + Type: "string", + Description: "Commit SHA, branch or tag name to list commits of. If not provided, uses the default branch of the repository. If a commit SHA is provided, will list commits up to that SHA.", + }, + "author": { + Type: "string", + Description: "Author username or email address to filter commits by", + }, + "page": { + Type: "number", + Description: "Page number for pagination (min 1)", + Minimum: jsonschema.Ptr(1.0), + }, + "perPage": { + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + }, + }, + Required: []string{"owner", "repo"}, + } + listCommitsResolvedCache schemaCache +) + +// MCPSchema returns the JSON schema for ListCommitsInput. +func (ListCommitsInput) MCPSchema() *jsonschema.Schema { + return listCommitsSchema +} + +// MCPResolvedSchema returns the pre-resolved schema for ListCommitsInput. +func (ListCommitsInput) MCPResolvedSchema() *jsonschema.Resolved { + return listCommitsResolvedCache.get(listCommitsSchema) +} + +// ============================================================================ +// GetCommitInput - for get_commit tool +// ============================================================================ + +// GetCommitInput is the typed input for the get_commit tool. +type GetCommitInput struct { + Owner string `json:"owner"` + Repo string `json:"repo"` + SHA string `json:"sha"` + IncludeDiff *bool `json:"include_diff,omitempty"` + Page int `json:"page,omitempty"` + PerPage int `json:"perPage,omitempty"` +} + +var ( + getCommitSchema = &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "sha": { + Type: "string", + Description: "Commit SHA, branch name, or tag name", + }, + "include_diff": { + Type: "boolean", + Description: "Whether to include file diffs and stats in the response. Default is true.", + Default: json.RawMessage(`true`), + }, + "page": { + Type: "number", + Description: "Page number for pagination (min 1)", + Minimum: jsonschema.Ptr(1.0), + }, + "perPage": { + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + }, + }, + Required: []string{"owner", "repo", "sha"}, + } + getCommitResolvedCache schemaCache +) + +// MCPSchema returns the JSON schema for GetCommitInput. +func (GetCommitInput) MCPSchema() *jsonschema.Schema { + return getCommitSchema +} + +// MCPResolvedSchema returns the pre-resolved schema for GetCommitInput. +func (GetCommitInput) MCPResolvedSchema() *jsonschema.Resolved { + return getCommitResolvedCache.get(getCommitSchema) +} + +// ============================================================================ +// ListIssuesInput - for list_issues tool +// ============================================================================ + +// ListIssuesInput is the typed input for the list_issues tool. +type ListIssuesInput struct { + Owner string `json:"owner"` + Repo string `json:"repo"` + State string `json:"state,omitempty"` + Labels []string `json:"labels,omitempty"` + OrderBy string `json:"orderBy,omitempty"` + Direction string `json:"direction,omitempty"` + Since string `json:"since,omitempty"` + PerPage int `json:"perPage,omitempty"` + After string `json:"after,omitempty"` +} + +var ( + listIssuesSchema = &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "state": { + Type: "string", + Description: "Filter by state, by default both open and closed issues are returned when not provided", + Enum: []any{"OPEN", "CLOSED"}, + }, + "labels": { + Type: "array", + Description: "Filter by labels", + Items: &jsonschema.Schema{ + Type: "string", + }, + }, + "orderBy": { + Type: "string", + Description: "Order issues by field. If provided, the 'direction' also needs to be provided.", + Enum: []any{"CREATED_AT", "UPDATED_AT", "COMMENTS"}, + }, + "direction": { + Type: "string", + Description: "Order direction. If provided, the 'orderBy' also needs to be provided.", + Enum: []any{"ASC", "DESC"}, + }, + "since": { + Type: "string", + Description: "Filter by date (ISO 8601 timestamp)", + }, + "perPage": { + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + }, + "after": { + Type: "string", + Description: "Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs.", + }, + }, + Required: []string{"owner", "repo"}, + } + listIssuesResolvedCache schemaCache +) + +// MCPSchema returns the JSON schema for ListIssuesInput. +func (ListIssuesInput) MCPSchema() *jsonschema.Schema { + return listIssuesSchema +} + +// MCPResolvedSchema returns the pre-resolved schema for ListIssuesInput. +func (ListIssuesInput) MCPResolvedSchema() *jsonschema.Resolved { + return listIssuesResolvedCache.get(listIssuesSchema) +} + +// ============================================================================ +// IssueReadInput - for issue_read tool +// ============================================================================ + +// IssueReadInput is the typed input for the issue_read tool. +type IssueReadInput struct { + Method string `json:"method"` + Owner string `json:"owner"` + Repo string `json:"repo"` + IssueNumber int `json:"issue_number"` + Page int `json:"page,omitempty"` + PerPage int `json:"perPage,omitempty"` +} + +var ( + issueReadSchema = &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "method": { + Type: "string", + Description: `The read operation to perform on a single issue. +Options are: +1. get - Get details of a specific issue. +2. get_comments - Get issue comments. +3. get_sub_issues - Get sub-issues of the issue. +4. get_labels - Get labels assigned to the issue. +`, + Enum: []any{"get", "get_comments", "get_sub_issues", "get_labels"}, + }, + "owner": { + Type: "string", + Description: "The owner of the repository", + }, + "repo": { + Type: "string", + Description: "The name of the repository", + }, + "issue_number": { + Type: "number", + Description: "The number of the issue", + }, + "page": { + Type: "number", + Description: "Page number for pagination (min 1)", + Minimum: jsonschema.Ptr(1.0), + }, + "perPage": { + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + }, + }, + Required: []string{"method", "owner", "repo", "issue_number"}, + } + issueReadResolvedCache schemaCache +) + +// MCPSchema returns the JSON schema for IssueReadInput. +func (IssueReadInput) MCPSchema() *jsonschema.Schema { + return issueReadSchema +} + +// MCPResolvedSchema returns the pre-resolved schema for IssueReadInput. +func (IssueReadInput) MCPResolvedSchema() *jsonschema.Resolved { + return issueReadResolvedCache.get(issueReadSchema) +} + +// ============================================================================ +// ListPullRequestsInput - for list_pull_requests tool +// ============================================================================ + +// ListPullRequestsInput is the typed input for the list_pull_requests tool. +type ListPullRequestsInput struct { + Owner string `json:"owner"` + Repo string `json:"repo"` + State string `json:"state,omitempty"` + Head string `json:"head,omitempty"` + Base string `json:"base,omitempty"` + Sort string `json:"sort,omitempty"` + Direction string `json:"direction,omitempty"` + Page int `json:"page,omitempty"` + PerPage int `json:"perPage,omitempty"` +} + +var ( + listPullRequestsSchema = &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "state": { + Type: "string", + Description: "Filter by state", + Enum: []any{"open", "closed", "all"}, + }, + "head": { + Type: "string", + Description: "Filter by head user/org and branch", + }, + "base": { + Type: "string", + Description: "Filter by base branch", + }, + "sort": { + Type: "string", + Description: "Sort by", + Enum: []any{"created", "updated", "popularity", "long-running"}, + }, + "direction": { + Type: "string", + Description: "Sort direction", + Enum: []any{"asc", "desc"}, + }, + "page": { + Type: "number", + Description: "Page number for pagination (min 1)", + Minimum: jsonschema.Ptr(1.0), + }, + "perPage": { + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + }, + }, + Required: []string{"owner", "repo"}, + } + listPullRequestsResolvedCache schemaCache +) + +// MCPSchema returns the JSON schema for ListPullRequestsInput. +func (ListPullRequestsInput) MCPSchema() *jsonschema.Schema { + return listPullRequestsSchema +} + +// MCPResolvedSchema returns the pre-resolved schema for ListPullRequestsInput. +func (ListPullRequestsInput) MCPResolvedSchema() *jsonschema.Resolved { + return listPullRequestsResolvedCache.get(listPullRequestsSchema) +} + +// ============================================================================ +// PullRequestReadInput - for pull_request_read tool +// ============================================================================ + +// PullRequestReadInput is the typed input for the pull_request_read tool. +type PullRequestReadInput struct { + Method string `json:"method"` + Owner string `json:"owner"` + Repo string `json:"repo"` + PullNumber int `json:"pullNumber"` + Page int `json:"page,omitempty"` + PerPage int `json:"perPage,omitempty"` +} + +var ( + pullRequestReadSchema = &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "method": { + Type: "string", + Description: `Action to specify what pull request data needs to be retrieved from GitHub. +Possible options: + 1. get - Get details of a specific pull request. + 2. get_diff - Get the diff of a pull request. + 3. get_status - Get status of a head commit in a pull request. This reflects status of builds and checks. + 4. get_files - Get the list of files changed in a pull request. Use with pagination parameters to control the number of results returned. + 5. get_review_comments - Get the review comments on a pull request. They are comments made on a portion of the unified diff during a pull request review. Use with pagination parameters to control the number of results returned. + 6. get_reviews - Get the reviews on a pull request. When asked for review comments, use get_review_comments method. + 7. get_comments - Get comments on a pull request. Use this if user doesn't specifically want review comments. Use with pagination parameters to control the number of results returned. +`, + Enum: []any{"get", "get_diff", "get_status", "get_files", "get_review_comments", "get_reviews", "get_comments"}, + }, + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "pullNumber": { + Type: "number", + Description: "Pull request number", + }, + "page": { + Type: "number", + Description: "Page number for pagination (min 1)", + Minimum: jsonschema.Ptr(1.0), + }, + "perPage": { + Type: "number", + Description: "Results per page for pagination (min 1, max 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + }, + }, + Required: []string{"method", "owner", "repo", "pullNumber"}, + } + pullRequestReadResolvedCache schemaCache +) + +// MCPSchema returns the JSON schema for PullRequestReadInput. +func (PullRequestReadInput) MCPSchema() *jsonschema.Schema { + return pullRequestReadSchema +} + +// MCPResolvedSchema returns the pre-resolved schema for PullRequestReadInput. +func (PullRequestReadInput) MCPResolvedSchema() *jsonschema.Resolved { + return pullRequestReadResolvedCache.get(pullRequestReadSchema) +} diff --git a/pkg/github/search.go b/pkg/github/search.go index cffd0bf15..12069f80c 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -16,34 +16,7 @@ import ( ) // SearchRepositories creates a tool to search for GitHub repositories. -func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - schema := &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "query": { - Type: "string", - Description: "Repository search query. Examples: 'machine learning in:name stars:>1000 language:python', 'topic:react', 'user:facebook'. Supports advanced search syntax for precise filtering.", - }, - "sort": { - Type: "string", - Description: "Sort repositories by field, defaults to best match", - Enum: []any{"stars", "forks", "help-wanted-issues", "updated"}, - }, - "order": { - Type: "string", - Description: "Sort order", - Enum: []any{"asc", "desc"}, - }, - "minimal_output": { - Type: "boolean", - Description: "Return minimal repository information (default: true). When false, returns full GitHub API repository objects.", - Default: json.RawMessage(`true`), - }, - }, - Required: []string{"query"}, - } - WithPagination(schema) - +func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[SearchRepositoriesInput, any]) { return mcp.Tool{ Name: "search_repositories", Description: t("TOOL_SEARCH_REPOSITORIES_DESCRIPTION", "Find GitHub repositories by name, description, readme, topics, or other metadata. Perfect for discovering projects, finding examples, or locating specific repositories across GitHub."), @@ -51,35 +24,31 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF Title: t("TOOL_SEARCH_REPOSITORIES_USER_TITLE", "Search repositories"), ReadOnlyHint: true, }, - InputSchema: schema, + InputSchema: SearchRepositoriesInput{}.MCPSchema(), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - query, err := RequiredParam[string](args, "query") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + func(ctx context.Context, _ *mcp.CallToolRequest, input SearchRepositoriesInput) (*mcp.CallToolResult, any, error) { + // Get minimalOutput with default value of true + minimalOutput := true + if input.MinimalOutput != nil { + minimalOutput = *input.MinimalOutput } - order, err := OptionalParam[string](args, "order") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + + // Set pagination defaults + page := input.Page + if page == 0 { + page = 1 } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - minimalOutput, err := OptionalBoolParamWithDefault(args, "minimal_output", true) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + perPage := input.PerPage + if perPage == 0 { + perPage = 30 } + opts := &github.SearchOptions{ - Sort: sort, - Order: order, + Sort: input.Sort, + Order: input.Order, ListOptions: github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, + Page: page, + PerPage: perPage, }, } @@ -87,10 +56,10 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - result, resp, err := client.Search.Repositories(ctx, query, opts) + result, resp, err := client.Search.Repositories(ctx, input.Query, opts) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to search repositories with query '%s'", query), + fmt.Sprintf("failed to search repositories with query '%s'", input.Query), resp, err, ), nil, nil @@ -161,28 +130,7 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF } // SearchCode creates a tool to search for code across GitHub repositories. -func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - schema := &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "query": { - Type: "string", - Description: "Search query using GitHub's powerful code search syntax. Examples: 'content:Skill language:Java org:github', 'NOT is:archived language:Python OR language:go', 'repo:github/github-mcp-server'. Supports exact matching, language filters, path filters, and more.", - }, - "sort": { - Type: "string", - Description: "Sort field ('indexed' only)", - }, - "order": { - Type: "string", - Description: "Sort order for results", - Enum: []any{"asc", "desc"}, - }, - }, - Required: []string{"query"}, - } - WithPagination(schema) - +func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[SearchCodeInput, any]) { return mcp.Tool{ Name: "search_code", Description: t("TOOL_SEARCH_CODE_DESCRIPTION", "Fast and precise code search across ALL GitHub repositories using GitHub's native search engine. Best for finding exact symbols, functions, classes, or specific code patterns."), @@ -190,32 +138,25 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mc Title: t("TOOL_SEARCH_CODE_USER_TITLE", "Search code"), ReadOnlyHint: true, }, - InputSchema: schema, + InputSchema: SearchCodeInput{}.MCPSchema(), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - query, err := RequiredParam[string](args, "query") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - order, err := OptionalParam[string](args, "order") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + func(ctx context.Context, _ *mcp.CallToolRequest, input SearchCodeInput) (*mcp.CallToolResult, any, error) { + // Set pagination defaults + page := input.Page + if page == 0 { + page = 1 } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil + perPage := input.PerPage + if perPage == 0 { + perPage = 30 } opts := &github.SearchOptions{ - Sort: sort, - Order: order, + Sort: input.Sort, + Order: input.Order, ListOptions: github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, + PerPage: perPage, + Page: page, }, } @@ -224,10 +165,10 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mc return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - result, resp, err := client.Search.Code(ctx, query, opts) + result, resp, err := client.Search.Code(ctx, input.Query, opts) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to search code with query '%s'", query), + fmt.Sprintf("failed to search code with query '%s'", input.Query), resp, err, ), nil, nil @@ -251,6 +192,87 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mc } } +func typedUserOrOrgHandler(accountType string, getClient GetClientFn) mcp.ToolHandlerFor[SearchUsersInput, any] { + return func(ctx context.Context, _ *mcp.CallToolRequest, input SearchUsersInput) (*mcp.CallToolResult, any, error) { + // Set pagination defaults + page := input.Page + if page == 0 { + page = 1 + } + perPage := input.PerPage + if perPage == 0 { + perPage = 30 + } + + opts := &github.SearchOptions{ + Sort: input.Sort, + Order: input.Order, + ListOptions: github.ListOptions{ + PerPage: perPage, + Page: page, + }, + } + + client, err := getClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + searchQuery := input.Query + if !hasTypeFilter(input.Query) { + searchQuery = "type:" + accountType + " " + input.Query + } + result, resp, err := client.Search.Users(ctx, searchQuery, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to search %ss with query '%s'", accountType, input.Query), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return utils.NewToolResultError(fmt.Sprintf("failed to search %ss: %s", accountType, string(body))), nil, nil + } + + minimalUsers := make([]MinimalUser, 0, len(result.Users)) + + for _, user := range result.Users { + if user.Login != nil { + mu := MinimalUser{ + Login: user.GetLogin(), + ID: user.GetID(), + ProfileURL: user.GetHTMLURL(), + AvatarURL: user.GetAvatarURL(), + } + minimalUsers = append(minimalUsers, mu) + } + } + minimalResp := &MinimalSearchUsersResult{ + TotalCount: result.GetTotal(), + IncompleteResults: result.GetIncompleteResults(), + Items: minimalUsers, + } + if result.Total != nil { + minimalResp.TotalCount = *result.Total + } + if result.IncompleteResults != nil { + minimalResp.IncompleteResults = *result.IncompleteResults + } + + r, err := json.Marshal(minimalResp) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + return utils.NewToolResultText(string(r)), nil, nil + } +} + func userOrOrgHandler(accountType string, getClient GetClientFn) mcp.ToolHandlerFor[map[string]any, any] { return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { query, err := RequiredParam[string](args, "query") @@ -340,29 +362,7 @@ func userOrOrgHandler(accountType string, getClient GetClientFn) mcp.ToolHandler } // SearchUsers creates a tool to search for GitHub users. -func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - schema := &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "query": { - Type: "string", - Description: "User search query. Examples: 'john smith', 'location:seattle', 'followers:>100'. Search is automatically scoped to type:user.", - }, - "sort": { - Type: "string", - Description: "Sort users by number of followers or repositories, or when the person joined GitHub.", - Enum: []any{"followers", "repositories", "joined"}, - }, - "order": { - Type: "string", - Description: "Sort order", - Enum: []any{"asc", "desc"}, - }, - }, - Required: []string{"query"}, - } - WithPagination(schema) - +func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[SearchUsersInput, any]) { return mcp.Tool{ Name: "search_users", Description: t("TOOL_SEARCH_USERS_DESCRIPTION", "Find GitHub users by username, real name, or other profile information. Useful for locating developers, contributors, or team members."), @@ -370,8 +370,8 @@ func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (m Title: t("TOOL_SEARCH_USERS_USER_TITLE", "Search users"), ReadOnlyHint: true, }, - InputSchema: schema, - }, userOrOrgHandler("user", getClient) + InputSchema: SearchUsersInput{}.MCPSchema(), + }, typedUserOrOrgHandler("user", getClient) } // SearchOrgs creates a tool to search for GitHub organizations.