Skip to content

Commit 928dbf7

Browse files
committed
Add OAuth flow chaining and token injection for stdio sessions
After Google OAuth, automatically redirect to server OAuth if the return URL specifies a server that requires it. Also inject user tokens into stdio sessions during creation so they can authenticate with backend services.
1 parent 1ca3c8f commit 928dbf7

File tree

10 files changed

+119
-43
lines changed

10 files changed

+119
-43
lines changed

integration/oauth_test.go

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ func TestClientRegistration(t *testing.T) {
157157
})
158158
defer stopServer(mcpCmd)
159159

160-
if !waitForHealthCheck(t, 30) {
160+
if !waitForHealthCheck(30) {
161161
t.Fatal("OAuth server failed to start")
162162
}
163163

@@ -358,7 +358,7 @@ func TestUserTokenFlow(t *testing.T) {
358358
mcpCmd := startOAuthServerWithTokenConfig(t)
359359
defer stopServer(mcpCmd)
360360

361-
if !waitForHealthCheck(t, 30) {
361+
if !waitForHealthCheck(30) {
362362
t.Fatal("Server failed to start")
363363
}
364364

@@ -532,7 +532,7 @@ func TestStateParameterHandling(t *testing.T) {
532532
})
533533
defer stopServer(mcpCmd)
534534

535-
if !waitForHealthCheck(t, 10) {
535+
if !waitForHealthCheck(10) {
536536
t.Fatal("Server failed to start")
537537
}
538538

@@ -602,7 +602,7 @@ func TestEnvironmentModes(t *testing.T) {
602602
})
603603
defer stopServer(mcpCmd)
604604

605-
if !waitForHealthCheck(t, 30) {
605+
if !waitForHealthCheck(30) {
606606
t.Fatal("Server failed to start")
607607
}
608608

@@ -643,7 +643,7 @@ func TestEnvironmentModes(t *testing.T) {
643643
})
644644
defer stopServer(mcpCmd)
645645

646-
if !waitForHealthCheck(t, 30) {
646+
if !waitForHealthCheck(30) {
647647
t.Fatal("Server failed to start")
648648
}
649649

@@ -693,7 +693,7 @@ func TestOAuthEndpoints(t *testing.T) {
693693
})
694694
defer stopServer(mcpCmd)
695695

696-
if !waitForHealthCheck(t, 10) {
696+
if !waitForHealthCheck(10) {
697697
t.Fatal("Server failed to start")
698698
}
699699

@@ -761,7 +761,7 @@ func TestCORSHeaders(t *testing.T) {
761761
})
762762
defer stopServer(mcpCmd)
763763

764-
if !waitForHealthCheck(t, 10) {
764+
if !waitForHealthCheck(10) {
765765
t.Fatal("Server failed to start")
766766
}
767767

@@ -813,7 +813,7 @@ func TestToolAdvertisementWithUserTokens(t *testing.T) {
813813
"LOG_LEVEL=debug",
814814
)
815815

816-
if !waitForHealthCheck(t, 30) {
816+
if !waitForHealthCheck(30) {
817817
t.Fatal("Server failed to start")
818818
}
819819

@@ -980,13 +980,14 @@ func TestToolAdvertisementWithUserTokens(t *testing.T) {
980980
defer resp.Body.Close()
981981

982982
// Check the response - it might be 200 if following redirects
983-
if resp.StatusCode == 200 {
983+
switch resp.StatusCode {
984+
case 200:
984985
// That's fine, it means the token was set and we got the page back
985986
t.Log("Token set successfully, got page response")
986-
} else if resp.StatusCode == 302 || resp.StatusCode == 303 {
987+
case 302, 303:
987988
// Also fine, redirect means success
988989
t.Log("Token set successfully, got redirect")
989-
} else {
990+
default:
990991
body, _ := io.ReadAll(resp.Body)
991992
t.Fatalf("Unexpected response setting token: status=%d, body=%s", resp.StatusCode, string(body))
992993
}
@@ -1119,7 +1120,7 @@ func stopServer(cmd *exec.Cmd) {
11191120
}
11201121
}
11211122

1122-
func waitForHealthCheck(t *testing.T, seconds int) bool {
1123+
func waitForHealthCheck(seconds int) bool {
11231124
for i := 0; i < seconds; i++ {
11241125
if checkHealth() {
11251126
return true

integration/security_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,11 @@ func TestSecurityScenarios(t *testing.T) {
221221
}
222222
resp.Body.Close()
223223

224-
if resp.StatusCode == 200 {
224+
switch resp.StatusCode {
225+
case 200:
225226
t.Errorf("CRITICAL: Auth bypass! 'test-token' without Bearer returned 200")
226-
} else if resp.StatusCode == 401 {
227-
} else {
227+
case 401:
228+
default:
228229
t.Logf("Unexpected status %d for malformed auth", resp.StatusCode)
229230
}
230231
})

internal/client/client.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,6 @@ func (c *Client) wrapToolHandler(
435435
errorData := createTokenRequiredError(
436436
serverName,
437437
setupBaseURL,
438-
userAuth,
439438
"configuration error: this service requires user tokens but OAuth is not properly configured.",
440439
)
441440

@@ -469,7 +468,6 @@ func (c *Client) wrapToolHandler(
469468
errorData := createTokenRequiredError(
470469
serverName,
471470
setupBaseURL,
472-
userAuth,
473471
errorMessage,
474472
)
475473

@@ -509,7 +507,7 @@ func (c *Client) Close() error {
509507
}
510508

511509
// createTokenRequiredError creates the structured error for missing user tokens
512-
func createTokenRequiredError(serverName, setupBaseURL string, userAuth *config.UserAuthentication, message string) map[string]interface{} {
510+
func createTokenRequiredError(serverName, setupBaseURL string, message string) map[string]interface{} {
513511
tokenSetupURL := fmt.Sprintf("%s/my/tokens", setupBaseURL)
514512

515513
return map[string]interface{}{

internal/client/session_manager.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ func (sm *StdioSessionManager) GetOrCreateSession(
115115
config *config.MCPClientConfig,
116116
info mcp.Implementation,
117117
baseURL string,
118+
userToken string,
118119
) (*StdioSession, error) {
119120
// Try to get existing session first
120121
if session, ok := sm.GetSession(key); ok {
@@ -125,7 +126,7 @@ func (sm *StdioSessionManager) GetOrCreateSession(
125126
return nil, err
126127
}
127128

128-
return sm.createSession(ctx, key, config, info, baseURL)
129+
return sm.createSession(key, config, userToken)
129130
}
130131

131132
// GetSession retrieves an existing session
@@ -379,14 +380,21 @@ func (sm *StdioSessionManager) getUserSessionCount(userEmail string) int {
379380

380381
// createSession creates a new stdio session
381382
func (sm *StdioSessionManager) createSession(
382-
ctx context.Context,
383383
key SessionKey,
384384
config *config.MCPClientConfig,
385-
info mcp.Implementation,
386-
baseURL string,
385+
userToken string,
387386
) (*StdioSession, error) {
387+
// Create an independent context for the stdio session. We intentionally use
388+
// context.Background() instead of the HTTP request context because stdio
389+
// sessions are long-lived processes that must persist across multiple HTTP
390+
// requests. The session will be cleaned up by the timeout-based cleanup
391+
// routine, not by HTTP request cancellation.
388392
sessionCtx, cancel := context.WithCancel(context.Background())
389393

394+
if userToken != "" && config.RequiresUserToken {
395+
config = config.ApplyUserToken(userToken)
396+
}
397+
390398
client, err := sm.createClient(key.ServerName, config)
391399
if err != nil {
392400
cancel()

internal/client/session_manager_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@ func TestStdioSessionManager_CreateAndRetrieve(t *testing.T) {
4444
}
4545

4646
// Create session
47-
session1, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost")
47+
session1, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost", "")
4848
require.NoError(t, err)
4949
require.NotNil(t, session1)
5050

5151
// Retrieve same session
52-
session2, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost")
52+
session2, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost", "")
5353
require.NoError(t, err)
5454
require.NotNil(t, session2)
5555

@@ -81,22 +81,22 @@ func TestStdioSessionManager_UserLimits(t *testing.T) {
8181

8282
// Create first session
8383
key1 := SessionKey{UserEmail: userEmail, ServerName: "server", SessionID: "1"}
84-
_, err := sm.GetOrCreateSession(context.Background(), key1, config, info, "http://localhost")
84+
_, err := sm.GetOrCreateSession(context.Background(), key1, config, info, "http://localhost", "")
8585
require.NoError(t, err)
8686

8787
// Create second session (at limit)
8888
key2 := SessionKey{UserEmail: userEmail, ServerName: "server", SessionID: "2"}
89-
_, err = sm.GetOrCreateSession(context.Background(), key2, config, info, "http://localhost")
89+
_, err = sm.GetOrCreateSession(context.Background(), key2, config, info, "http://localhost", "")
9090
require.NoError(t, err)
9191

9292
// Try to create third session (should fail)
9393
key3 := SessionKey{UserEmail: userEmail, ServerName: "server", SessionID: "3"}
94-
_, err = sm.GetOrCreateSession(context.Background(), key3, config, info, "http://localhost")
94+
_, err = sm.GetOrCreateSession(context.Background(), key3, config, info, "http://localhost", "")
9595
assert.ErrorIs(t, err, ErrUserLimitExceeded)
9696

9797
// Different user should work
9898
key4 := SessionKey{UserEmail: "[email protected]", ServerName: "server", SessionID: "4"}
99-
_, err = sm.GetOrCreateSession(context.Background(), key4, config, info, "http://localhost")
99+
_, err = sm.GetOrCreateSession(context.Background(), key4, config, info, "http://localhost", "")
100100
require.NoError(t, err)
101101
}
102102

@@ -122,7 +122,7 @@ func TestStdioSessionManager_RemoveSession(t *testing.T) {
122122
info := mcp.Implementation{Name: "test", Version: "1.0"}
123123

124124
// Create session
125-
session, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost")
125+
session, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost", "")
126126
require.NoError(t, err)
127127
require.NotNil(t, session)
128128

@@ -161,7 +161,7 @@ func TestStdioSessionManager_Timeout(t *testing.T) {
161161
info := mcp.Implementation{Name: "test", Version: "1.0"}
162162

163163
// Create session
164-
session, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost")
164+
session, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost", "")
165165
require.NoError(t, err)
166166
require.NotNil(t, session)
167167

@@ -204,7 +204,7 @@ func TestStdioSessionManager_ConcurrentAccess(t *testing.T) {
204204
}
205205

206206
// Create session
207-
_, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost")
207+
_, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost", "")
208208
assert.NoError(t, err)
209209

210210
// Get session
@@ -243,7 +243,7 @@ func TestStdioSessionManager_NoLimitsForAnonymous(t *testing.T) {
243243
SessionID: fmt.Sprintf("session-%d", i),
244244
}
245245

246-
_, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost")
246+
_, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost", "")
247247
require.NoError(t, err, "Anonymous session %d should succeed", i)
248248
}
249249
}

internal/server/auth_handlers.go

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@ import (
55
"encoding/json"
66
"fmt"
77
"net/http"
8+
"net/url"
89
"strings"
910
"time"
1011

1112
"github.com/dgellow/mcp-front/internal/auth"
13+
"github.com/dgellow/mcp-front/internal/config"
1214
"github.com/dgellow/mcp-front/internal/crypto"
1315
jsonwriter "github.com/dgellow/mcp-front/internal/json"
1416
"github.com/dgellow/mcp-front/internal/log"
@@ -18,12 +20,14 @@ import (
1820
// AuthHandlers wraps the auth.Server to provide HTTP handlers
1921
type AuthHandlers struct {
2022
authServer *auth.Server
23+
mcpServers map[string]*config.MCPClientConfig
2124
}
2225

2326
// NewAuthHandlers creates new auth handlers
24-
func NewAuthHandlers(authServer *auth.Server) *AuthHandlers {
27+
func NewAuthHandlers(authServer *auth.Server, mcpServers map[string]*config.MCPClientConfig) *AuthHandlers {
2528
return &AuthHandlers{
2629
authServer: authServer,
30+
mcpServers: mcpServers,
2731
}
2832
}
2933

@@ -222,7 +226,30 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ
222226
"returnURL": returnURL,
223227
})
224228

225-
// Redirect to return URL
229+
// Check if the return URL contains a server parameter for OAuth chaining
230+
parsedURL, err := url.Parse(returnURL)
231+
if err == nil {
232+
serverName := parsedURL.Query().Get("server")
233+
if serverName != "" {
234+
// Check if this server requires OAuth authentication
235+
if serverConfig, exists := h.mcpServers[serverName]; exists {
236+
if serverConfig.RequiresUserToken &&
237+
serverConfig.UserAuthentication != nil &&
238+
serverConfig.UserAuthentication.Type == config.UserAuthTypeOAuth {
239+
encodedReturnURL := url.QueryEscape(returnURL)
240+
oauthURL := fmt.Sprintf("/oauth/connect?service=%s&return=%s", serverName, encodedReturnURL)
241+
log.LogInfoWithFields("auth", "Chaining to server OAuth", map[string]interface{}{
242+
"server": serverName,
243+
"user": userInfo.Email,
244+
})
245+
http.Redirect(w, r, oauthURL, http.StatusFound)
246+
return
247+
}
248+
}
249+
}
250+
}
251+
252+
// Otherwise, redirect to return URL as normal
226253
http.Redirect(w, r, returnURL, http.StatusFound)
227254
return
228255
}

internal/server/handler.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"net/http"
77
"net/url"
8+
"strings"
89
"time"
910

1011
"github.com/dgellow/mcp-front/internal/auth"
@@ -174,7 +175,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (*Server, error) {
174175
recoverMiddleware("mcp"),
175176
}
176177

177-
authHandlers := NewAuthHandlers(s.authServer)
178+
authHandlers := NewAuthHandlers(s.authServer, cfg.MCPServers)
178179
mux.Handle("/.well-known/oauth-authorization-server", chainMiddleware(http.HandlerFunc(authHandlers.WellKnownHandler), oauthMiddlewares...))
179180
mux.Handle("/authorize", chainMiddleware(http.HandlerFunc(authHandlers.AuthorizeHandler), oauthMiddlewares...))
180181
mux.Handle("/oauth/callback", chainMiddleware(http.HandlerFunc(authHandlers.GoogleCallbackHandler), oauthMiddlewares...))
@@ -415,6 +416,27 @@ func isStdioServer(config *config.MCPClientConfig) bool {
415416
return config.Command != ""
416417
}
417418

419+
// formatUserToken formats a stored token according to the user authentication configuration
420+
func formatUserToken(storedToken *storage.StoredToken, auth *config.UserAuthentication) string {
421+
if storedToken == nil {
422+
return ""
423+
}
424+
425+
if storedToken.Type == storage.TokenTypeOAuth && storedToken.OAuthData != nil {
426+
token := storedToken.OAuthData.AccessToken
427+
if auth.TokenFormat != "" && auth.TokenFormat != "{{token}}" {
428+
return strings.ReplaceAll(auth.TokenFormat, "{{token}}", token)
429+
}
430+
return token
431+
}
432+
433+
token := storedToken.Value
434+
if auth != nil && auth.TokenFormat != "" && auth.TokenFormat != "{{token}}" {
435+
return strings.ReplaceAll(auth.TokenFormat, "{{token}}", token)
436+
}
437+
return token
438+
}
439+
418440
// sessionHandlerKey is the context key for session handlers
419441
type sessionHandlerKey struct{}
420442

@@ -455,12 +477,30 @@ func handleSessionRegistration(
455477
"command": handler.config.Command,
456478
})
457479

480+
var userToken string
481+
if handler.config.RequiresUserToken && handler.userEmail != "" && handler.h.storage != nil {
482+
storedToken, err := handler.h.storage.GetUserToken(sessionCtx, handler.userEmail, handler.h.serverName)
483+
if err != nil {
484+
log.LogDebugWithFields("server", "No user token found", map[string]interface{}{
485+
"server": handler.h.serverName,
486+
"user": handler.userEmail,
487+
})
488+
} else if storedToken != nil {
489+
if handler.config.UserAuthentication != nil {
490+
userToken = formatUserToken(storedToken, handler.config.UserAuthentication)
491+
} else {
492+
userToken = storedToken.Value
493+
}
494+
}
495+
}
496+
458497
stdioSession, err := sessionManager.GetOrCreateSession(
459498
sessionCtx,
460499
key,
461500
handler.config,
462501
handler.h.info,
463502
handler.h.setupBaseURL,
503+
userToken,
464504
)
465505
if err != nil {
466506
log.LogErrorWithFields("server", "Failed to create stdio session", map[string]interface{}{

0 commit comments

Comments
 (0)