diff --git a/examples/client/simple-auth/go.mod b/examples/client/simple-auth/go.mod new file mode 100644 index 00000000..30f8d0c5 --- /dev/null +++ b/examples/client/simple-auth/go.mod @@ -0,0 +1,13 @@ +module simple-auth-client + +go 1.23.0 + +require github.com/modelcontextprotocol/go-sdk v0.3.0 + +require ( + github.com/google/jsonschema-go v0.3.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/oauth2 v0.30.0 // indirect +) + +replace github.com/modelcontextprotocol/go-sdk => ../../../ diff --git a/examples/client/simple-auth/go.sum b/examples/client/simple-auth/go.sum new file mode 100644 index 00000000..08607c99 --- /dev/null +++ b/examples/client/simple-auth/go.sum @@ -0,0 +1,10 @@ +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= +github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= +golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= diff --git a/examples/client/simple-auth/main.go b/examples/client/simple-auth/main.go new file mode 100644 index 00000000..50f9ab10 --- /dev/null +++ b/examples/client/simple-auth/main.go @@ -0,0 +1,573 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +// Simple MCP client example with OAuth authentication support. +// +// This client connects to an MCP server using streamable HTTP or SSE transport. +// +// Usage: +// +// go run main.go +// +// Environment variables: +// +// MCP_SERVER_PORT - Server port (default: 8000) +// MCP_TRANSPORT_TYPE - Transport type: streamable-http (default) or sse +package main + +import ( + "bufio" + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "log" + "net/http" + "net/url" + "os" + "os/exec" + "runtime" + "strings" + "sync" + "time" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +// registerClient performs Dynamic Client Registration (RFC 7591) with the authorization server. +// Returns the client ID and client secret. +func registerClient(ctx context.Context, authServerURL, redirectURI string, authMeta *oauthex.AuthServerMeta) (clientID, clientSecret string, err error) { + clientMeta := &oauthex.ClientRegistrationMetadata{ + ClientName: "Simple Auth Client", + RedirectURIs: []string{redirectURI}, + GrantTypes: []string{"authorization_code", "refresh_token"}, + ResponseTypes: []string{"code"}, + TokenEndpointAuthMethod: "client_secret_post", + Scope: "user", + } + + registrationEndpoint := authMeta.RegistrationEndpoint + if registrationEndpoint == "" { + // Fallback to default registration endpoint if not in metadata + registrationEndpoint = authServerURL + "/register" + } + + fmt.Printf("Registering client at %s\n", registrationEndpoint) + clientInfo, err := oauthex.RegisterClient(ctx, registrationEndpoint, clientMeta, nil) + if err != nil { + return "", "", fmt.Errorf("failed to register client: %w", err) + } + + fmt.Printf("Client registered with ID: %s\n", clientInfo.ClientID) + return clientInfo.ClientID, clientInfo.ClientSecret, nil +} + +// generatePKCE generates PKCE code verifier and challenge using golang.org/x/oauth2. +// Returns the verifier (43-128 characters) and the challenge (SHA256 hash). +func generatePKCE() (verifier, challenge string) { + verifier = oauth2.GenerateVerifier() + challenge = oauth2.S256ChallengeFromVerifier(verifier) + return verifier, challenge +} + +// openBrowser opens the specified URL in the default browser. +func openBrowser(url string) error { + var cmd *exec.Cmd + + switch runtime.GOOS { + case "darwin": + cmd = exec.Command("open", url) + case "linux": + cmd = exec.Command("xdg-open", url) + case "windows": + cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url) + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } + + return cmd.Start() +} + +// performOAuthFlow executes the OAuth 2.0 authorization code flow with PKCE. +// This implements the auth.OAuthHandler signature. +func performOAuthFlow(ctx context.Context, args auth.OAuthHandlerArgs) (oauth2.TokenSource, error) { + fmt.Println("Starting OAuth flow...") + + // Fetch protected resource metadata + if args.ResourceMetadataURL == "" { + return nil, fmt.Errorf("no resource metadata URL provided") + } + + // Extract resource ID from metadata URL + // The metadata URL is like http://host/.well-known/oauth-protected-resource + // GetProtectedResourceMetadataFromID expects just the resource ID (scheme + host + /) + metadataURL, err := url.Parse(args.ResourceMetadataURL) + if err != nil { + return nil, fmt.Errorf("invalid metadata URL: %w", err) + } + + resourceID := url.URL{ + Scheme: metadataURL.Scheme, + Host: metadataURL.Host, + Path: "/", + } + + fmt.Printf("Fetching protected resource metadata for %s\n", resourceID.String()) + metadata, err := oauthex.GetProtectedResourceMetadataFromID(ctx, resourceID.String(), nil) + if err != nil { + return nil, fmt.Errorf("failed to fetch resource metadata: %w", err) + } + + // Extract resource URL for RFC 8707 + resourceURL := metadata.Resource + if resourceURL == "" { + resourceURL = resourceID.String() + } + fmt.Printf("Resource URL: %s\n", resourceURL) + + // Get authorization server metadata + if metadata.AuthorizationServers == nil || len(metadata.AuthorizationServers) == 0 { + return nil, fmt.Errorf("no authorization servers in metadata") + } + + authServerURL := metadata.AuthorizationServers[0] + fmt.Printf("Using authorization server: %s\n", authServerURL) + + authMeta, err := oauthex.GetAuthServerMeta(ctx, authServerURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to fetch authorization server metadata: %w", err) + } + + // Register client dynamically + redirectURI := "http://localhost:3030/callback" + clientID, clientSecret, err := registerClient(ctx, authServerURL, redirectURI, authMeta) + if err != nil { + return nil, err + } + + // Start callback server + callbackServer := NewCallbackServer(3030) + if err := callbackServer.Start(); err != nil { + return nil, fmt.Errorf("failed to start callback server: %w", err) + } + defer callbackServer.Stop() + + // Generate PKCE + verifier, challenge := generatePKCE() + + // Generate state + stateBytes := make([]byte, 16) + if _, err := rand.Read(stateBytes); err != nil { + return nil, fmt.Errorf("failed to generate state: %w", err) + } + state := base64.RawURLEncoding.EncodeToString(stateBytes) + + // Build authorization URL + authURL, err := url.Parse(authMeta.AuthorizationEndpoint) + if err != nil { + return nil, fmt.Errorf("invalid authorization endpoint: %w", err) + } + + q := authURL.Query() + q.Set("response_type", "code") + q.Set("client_id", clientID) + q.Set("redirect_uri", redirectURI) + q.Set("state", state) + q.Set("code_challenge", challenge) + q.Set("code_challenge_method", "S256") + q.Set("scope", "user") + q.Set("resource", resourceURL) // RFC 8707: Resource Indicators for OAuth 2.0 + authURL.RawQuery = q.Encode() + + // Open browser for authorization + fmt.Printf("\nOpening browser for authorization...\n") + fmt.Printf("URL: %s\n\n", authURL.String()) + + if err := openBrowser(authURL.String()); err != nil { + fmt.Printf("Could not open browser automatically. Please visit the URL above.\n\n") + } + + // Wait for callback + fmt.Println("Waiting for authorization callback...") + code, returnedState, err := callbackServer.WaitForCallback(5 * time.Minute) + if err != nil { + return nil, fmt.Errorf("callback error: %w", err) + } + + if returnedState != state { + return nil, fmt.Errorf("state mismatch: expected %s, got %s", state, returnedState) + } + + fmt.Println("Authorization code received") + + // Exchange code for token + tokenURL := authMeta.TokenEndpoint + tokenReq := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + "client_id": {clientID}, + "code_verifier": {verifier}, + "resource": {resourceURL}, // RFC 8707: Resource Indicators for OAuth 2.0 + } + + // Add client secret if provided (client_secret_post method) + if clientSecret != "" { + tokenReq.Set("client_secret", clientSecret) + } + + resp, err := http.PostForm(tokenURL, tokenReq) + if err != nil { + return nil, fmt.Errorf("token exchange failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token exchange failed with status %d", resp.StatusCode) + } + + var token oauth2.Token + if err := json.NewDecoder(resp.Body).Decode(&token); err != nil { + return nil, fmt.Errorf("failed to decode token: %w", err) + } + + fmt.Println("Access token obtained") + + // Create OAuth2 config for token source + oauth2Config := &oauth2.Config{ + Endpoint: oauth2.Endpoint{ + AuthURL: authMeta.AuthorizationEndpoint, + TokenURL: authMeta.TokenEndpoint, + }, + } + + return oauth2Config.TokenSource(ctx, &token), nil +} + +// CallbackServer handles OAuth callbacks on a local HTTP server. +type CallbackServer struct { + port int + server *http.Server + + mu sync.Mutex + code string + state string + err error + resultReceived chan struct{} +} + +// NewCallbackServer creates a new callback server on the specified port. +func NewCallbackServer(port int) *CallbackServer { + return &CallbackServer{ + port: port, + resultReceived: make(chan struct{}), + } +} + +// Start starts the callback server. +func (s *CallbackServer) Start() error { + mux := http.NewServeMux() + mux.HandleFunc("/callback", s.handleCallback) + + s.server = &http.Server{ + Addr: fmt.Sprintf(":%d", s.port), + Handler: mux, + } + + go func() { + if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Printf("Callback server error: %v", err) + } + }() + + fmt.Printf("Started callback server on http://localhost:%d\n", s.port) + return nil +} + +// handleCallback handles the OAuth callback. +func (s *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + defer s.mu.Unlock() + + query := r.URL.Query() + + if code := query.Get("code"); code != "" { + s.code = code + s.state = query.Get("state") + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusOK) + w.Write([]byte(` + +
You can close this window and return to the terminal.
+ + + +`)) + close(s.resultReceived) + } else if errMsg := query.Get("error"); errMsg != "" { + s.err = fmt.Errorf("authorization error: %s", errMsg) + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(fmt.Sprintf(` + +Error: %s
+You can close this window and return to the terminal.
+ + +`, errMsg))) + close(s.resultReceived) + } else { + w.WriteHeader(http.StatusNotFound) + } +} + +// WaitForCallback waits for the OAuth callback with a timeout. +func (s *CallbackServer) WaitForCallback(timeout time.Duration) (code, state string, err error) { + select { + case <-s.resultReceived: + s.mu.Lock() + defer s.mu.Unlock() + return s.code, s.state, s.err + case <-time.After(timeout): + return "", "", fmt.Errorf("timeout waiting for OAuth callback") + } +} + +// Stop stops the callback server. +func (s *CallbackServer) Stop() error { + if s.server != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return s.server.Shutdown(ctx) + } + return nil +} + +// AuthClient is a simple MCP client. +type AuthClient struct { + transport mcp.Transport + session *mcp.ClientSession +} + +// NewAuthClient creates a new client with the given transport. +func NewAuthClient(transport mcp.Transport) *AuthClient { + return &AuthClient{ + transport: transport, + } +} + +// Connect connects to the MCP server. +func (c *AuthClient) Connect(ctx context.Context) error { + fmt.Println("Connecting to MCP server...") + + // Create MCP client + client := mcp.NewClient(&mcp.Implementation{ + Name: "simple-auth-client", + Version: "v1.0.0", + }, nil) + + // Connect to server + session, err := client.Connect(ctx, c.transport, nil) + if err != nil { + return fmt.Errorf("failed to connect: %w", err) + } + + c.session = session + fmt.Println("Connected to MCP server") + + return nil +} + +// ListTools lists available tools from the server. +func (c *AuthClient) ListTools(ctx context.Context) error { + if c.session == nil { + return fmt.Errorf("not connected to server") + } + + fmt.Println("\nAvailable tools:") + count := 0 + for tool, err := range c.session.Tools(ctx, nil) { + if err != nil { + return fmt.Errorf("failed to list tools: %w", err) + } + count++ + fmt.Printf("%d. %s", count, tool.Name) + if tool.Description != "" { + fmt.Printf("\n Description: %s", tool.Description) + } + fmt.Println() + } + + if count == 0 { + fmt.Println("No tools available") + } + + return nil +} + +// CallTool calls a specific tool. +func (c *AuthClient) CallTool(ctx context.Context, toolName string, arguments map[string]any) error { + if c.session == nil { + return fmt.Errorf("not connected to server") + } + + result, err := c.session.CallTool(ctx, &mcp.CallToolParams{ + Name: toolName, + Arguments: arguments, + }) + if err != nil { + return fmt.Errorf("failed to call tool '%s': %w", toolName, err) + } + + fmt.Printf("\nTool '%s' result:\n", toolName) + for _, content := range result.Content { + if textContent, ok := content.(*mcp.TextContent); ok { + fmt.Println(textContent.Text) + } else { + fmt.Printf("%+v\n", content) + } + } + + return nil +} + +// InteractiveLoop runs the interactive command loop. +func (c *AuthClient) InteractiveLoop(ctx context.Context) error { + fmt.Println("\nInteractive MCP Client") + fmt.Println("Commands:") + fmt.Println(" list - List available tools") + fmt.Println(" call