Skip to content

Commit 2bcd1ee

Browse files
fix(oauth): only elicit when browser fails to open
PKCE flow now opens browser directly and waits for callback without elicitation. Elicitation is only used when: - Device flow (Docker mode) - to show user code - Browser fails to open - fallback to show URL
1 parent 70c4553 commit 2bcd1ee

File tree

4 files changed

+368
-152
lines changed

4 files changed

+368
-152
lines changed

internal/ghmcp/server.go

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ type MCPServerConfig struct {
3636
// GitHub Token to authenticate with the GitHub API
3737
Token string
3838

39+
// TokenProvider is an optional function to dynamically get the token.
40+
// Used for OAuth flows where the token is obtained after server startup.
41+
// If set, this takes precedence over Token for API requests.
42+
TokenProvider func() string
43+
3944
// PrebuiltInventory is an optional pre-built inventory to avoid double building
4045
// When set, this inventory will be used instead of building a new one
4146
PrebuiltInventory *inventory.Inventory
@@ -92,9 +97,19 @@ type githubClients struct {
9297
}
9398

9499
// createGitHubClients creates all the GitHub API clients needed by the server.
95-
func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients, error) {
96-
// Construct REST client
97-
restClient := gogithub.NewClient(nil).WithAuthToken(cfg.Token)
100+
// If tokenProviderFn is provided, it will be used to get the token dynamically (for OAuth).
101+
// Otherwise, cfg.Token is used as a static token.
102+
func createGitHubClients(cfg MCPServerConfig, apiHost apiHost, tokenProviderFn tokenProvider) (*githubClients, error) {
103+
// Create bearer auth transport that can use dynamic token
104+
restTransport := &bearerAuthTransport{
105+
transport: http.DefaultTransport,
106+
token: cfg.Token,
107+
tokenProvider: tokenProviderFn,
108+
}
109+
110+
// Construct REST client with custom transport
111+
restHTTPClient := &http.Client{Transport: restTransport}
112+
restClient := gogithub.NewClient(restHTTPClient)
98113
restClient.UserAgent = fmt.Sprintf("github-mcp-server/%s", cfg.Version)
99114
restClient.BaseURL = apiHost.baseRESTURL
100115
restClient.UploadURL = apiHost.uploadURL
@@ -106,12 +121,13 @@ func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients,
106121
transport: &github.GraphQLFeaturesTransport{
107122
Transport: http.DefaultTransport,
108123
},
109-
token: cfg.Token,
124+
token: cfg.Token,
125+
tokenProvider: tokenProviderFn,
110126
},
111127
}
112128
gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient)
113129

114-
// Create raw content client (shares REST client's HTTP transport)
130+
// Create raw content client (inherits transport from REST client)
115131
rawClient := raw.NewClient(restClient, apiHost.rawURL)
116132

117133
// Set up repo access cache for lockdown mode
@@ -168,7 +184,7 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) {
168184
return nil, fmt.Errorf("failed to parse API host: %w", err)
169185
}
170186

171-
clients, err := createGitHubClients(cfg, apiHost)
187+
clients, err := createGitHubClients(cfg, apiHost, cfg.TokenProvider)
172188
if err != nil {
173189
return nil, fmt.Errorf("failed to create GitHub clients: %w", err)
174190
}
@@ -338,7 +354,7 @@ type StdioServerConfig struct {
338354
OAuthManager interface {
339355
HasToken() bool
340356
GetAccessToken() string
341-
RequestAuthentication(context.Context) error
357+
RequestAuthentication(context.Context, *mcp.ServerSession) error
342358
}
343359

344360
// OAuthScopes contains the OAuth scopes that were requested
@@ -437,10 +453,22 @@ func RunStdioServer(cfg StdioServerConfig) error {
437453
logger.Debug("skipping scope filtering for non-PAT token")
438454
}
439455

456+
// Create token provider that checks OAuth first, then falls back to static token
457+
var tokenProvider func() string
458+
if cfg.OAuthManager != nil {
459+
tokenProvider = func() string {
460+
if token := cfg.OAuthManager.GetAccessToken(); token != "" {
461+
return token
462+
}
463+
return cfg.Token
464+
}
465+
}
466+
440467
ghServer, err := NewMCPServer(MCPServerConfig{
441468
Version: cfg.Version,
442469
Host: cfg.Host,
443470
Token: cfg.Token,
471+
TokenProvider: tokenProvider,
444472
PrebuiltInventory: cfg.PrebuiltInventory,
445473
EnabledToolsets: cfg.EnabledToolsets,
446474
EnabledTools: cfg.EnabledTools,
@@ -692,14 +720,24 @@ func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error
692720
return t.transport.RoundTrip(req)
693721
}
694722

723+
// tokenProvider is a function that returns the current auth token
724+
type tokenProvider func() string
725+
695726
type bearerAuthTransport struct {
696-
transport http.RoundTripper
697-
token string
727+
transport http.RoundTripper
728+
token string // static token (used if tokenProvider is nil)
729+
tokenProvider tokenProvider // dynamic token provider (takes precedence)
698730
}
699731

700732
func (t *bearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
701733
req = req.Clone(req.Context())
702-
req.Header.Set("Authorization", "Bearer "+t.token)
734+
token := t.token
735+
if t.tokenProvider != nil {
736+
token = t.tokenProvider()
737+
}
738+
if token != "" {
739+
req.Header.Set("Authorization", "Bearer "+token)
740+
}
703741
return t.transport.RoundTrip(req)
704742
}
705743

@@ -763,7 +801,7 @@ func fetchTokenScopesForHost(ctx context.Context, token, host string) ([]string,
763801
func createOAuthMiddleware(oauthMgr interface {
764802
HasToken() bool
765803
GetAccessToken() string
766-
RequestAuthentication(context.Context) error
804+
RequestAuthentication(context.Context, *mcp.ServerSession) error
767805
}, logger *slog.Logger) func(mcp.MethodHandler) mcp.MethodHandler {
768806
return func(next mcp.MethodHandler) mcp.MethodHandler {
769807
return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) {
@@ -775,14 +813,22 @@ func createOAuthMiddleware(oauthMgr interface {
775813
// Check if we have a token
776814
if !oauthMgr.HasToken() {
777815
logger.Info("no authentication token available, triggering OAuth flow")
778-
// Trigger OAuth authentication
779-
// This will return an MCP URL elicitation error that the client will handle
780-
if err := oauthMgr.RequestAuthentication(ctx); err != nil {
781-
// Return the error (which should be a URL elicitation error)
816+
817+
// Get the session for elicitation
818+
var session *mcp.ServerSession
819+
if sess := req.GetSession(); sess != nil {
820+
// Type assert to ServerSession
821+
if ss, ok := sess.(*mcp.ServerSession); ok {
822+
session = ss
823+
}
824+
}
825+
826+
// Trigger OAuth authentication (blocks until complete)
827+
if err := oauthMgr.RequestAuthentication(ctx, session); err != nil {
782828
return nil, err
783829
}
784-
// If we get here without error, OAuth completed immediately
785-
// Fall through to execute the tool with the new token
830+
// OAuth completed successfully - fall through to execute the tool
831+
logger.Info("OAuth authentication completed successfully")
786832
}
787833

788834
// Execute the tool with authentication

0 commit comments

Comments
 (0)