Skip to content

Commit c2bba04

Browse files
committed
Add OAuth token storage support
1 parent 1b29b93 commit c2bba04

File tree

7 files changed

+173
-40
lines changed

7 files changed

+173
-40
lines changed

internal/server/mcp_handler.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -295,24 +295,36 @@ func (h *MCPHandler) getUserTokenIfAvailable(ctx context.Context, userEmail stri
295295
return "", fmt.Errorf("storage not configured")
296296
}
297297

298-
token, err := h.storage.GetUserToken(ctx, userEmail, h.serverName)
298+
storedToken, err := h.storage.GetUserToken(ctx, userEmail, h.serverName)
299299
if err != nil {
300300
return "", err
301301
}
302302

303-
// Validate token format if configured
303+
// Extract the actual token string based on type
304+
var tokenString string
305+
switch storedToken.Type {
306+
case storage.TokenTypeManual:
307+
tokenString = storedToken.Value
308+
case storage.TokenTypeOAuth:
309+
if storedToken.OAuthData != nil {
310+
tokenString = storedToken.OAuthData.AccessToken
311+
}
312+
}
313+
314+
// Validate token format if configured (only for manual tokens)
304315
if h.serverConfig.UserAuthentication != nil &&
305316
h.serverConfig.UserAuthentication.Type == config.UserAuthTypeManual &&
306-
h.serverConfig.UserAuthentication.ValidationRegex != nil {
307-
if !h.serverConfig.UserAuthentication.ValidationRegex.MatchString(token) {
317+
h.serverConfig.UserAuthentication.ValidationRegex != nil &&
318+
storedToken.Type == storage.TokenTypeManual {
319+
if !h.serverConfig.UserAuthentication.ValidationRegex.MatchString(tokenString) {
308320
log.LogWarnWithFields("mcp", "User token doesn't match expected format", map[string]any{
309321
"user": userEmail,
310322
"service": h.serverName,
311323
})
312324
}
313325
}
314326

315-
return token, nil
327+
return tokenString, nil
316328
}
317329

318330
func (h *MCPHandler) forwardMessageToBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, config *config.MCPClientConfig) {

internal/server/mcp_handler_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,13 @@ type mockStorage struct {
3232
}
3333

3434
// Override only the methods we want to mock
35-
func (m *mockStorage) GetUserToken(ctx context.Context, userEmail, service string) (string, error) {
35+
func (m *mockStorage) GetUserToken(ctx context.Context, userEmail, service string) (*storage.StoredToken, error) {
3636
if m.Mock.ExpectedCalls != nil {
3737
args := m.Called(ctx, userEmail, service)
38-
return args.String(0), args.Error(1)
38+
if args.Get(0) == nil {
39+
return nil, args.Error(1)
40+
}
41+
return args.Get(0).(*storage.StoredToken), args.Error(1)
3942
}
4043
return m.MemoryStorage.GetUserToken(ctx, userEmail, service)
4144
}

internal/server/token_handlers.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"net/http"
66
"strings"
77
"sync"
8+
"time"
89

910
"github.com/dgellow/mcp-front/internal/config"
1011
"github.com/dgellow/mcp-front/internal/crypto"
@@ -226,7 +227,14 @@ func (h *TokenHandlers) SetTokenHandler(w http.ResponseWriter, r *http.Request)
226227
}
227228
}
228229

229-
if err := h.tokenStore.SetUserToken(r.Context(), userEmail, serviceName, token); err != nil {
230+
// Create StoredToken for manual entry
231+
storedToken := &storage.StoredToken{
232+
Type: storage.TokenTypeManual,
233+
Value: token,
234+
UpdatedAt: time.Now(),
235+
}
236+
237+
if err := h.tokenStore.SetUserToken(r.Context(), userEmail, serviceName, storedToken); err != nil {
230238
log.LogErrorWithFields("token", "Failed to store token", map[string]interface{}{
231239
"error": err.Error(),
232240
"user": userEmail,

internal/storage/firestore.go

Lines changed: 96 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ var _ fosite.Storage = (*FirestoreStorage)(nil)
3434

3535
// UserTokenDoc represents a user token document in Firestore
3636
type UserTokenDoc struct {
37-
UserEmail string `firestore:"user_email"`
38-
Service string `firestore:"service"`
39-
Token string `firestore:"token"` // Encrypted
40-
UpdatedAt time.Time `firestore:"updated_at"`
37+
UserEmail string `firestore:"user_email"`
38+
Service string `firestore:"service"`
39+
Type TokenType `firestore:"type"`
40+
Value string `firestore:"value,omitempty"` // Encrypted manual token
41+
OAuthData *OAuthTokenData `firestore:"oauth_data,omitempty"` // OAuth metadata (tokens encrypted)
42+
UpdatedAt time.Time `firestore:"updated_at"`
4143
}
4244

4345
// OAuthClientEntity represents the structure stored in Firestore
@@ -361,47 +363,121 @@ func (s *FirestoreStorage) makeUserTokenDocID(userEmail, service string) string
361363
}
362364

363365
// GetUserToken retrieves a user's token for a specific service
364-
func (s *FirestoreStorage) GetUserToken(ctx context.Context, userEmail, service string) (string, error) {
366+
func (s *FirestoreStorage) GetUserToken(ctx context.Context, userEmail, service string) (*StoredToken, error) {
365367
docID := s.makeUserTokenDocID(userEmail, service)
366368
doc, err := s.client.Collection(s.tokenCollection).Doc(docID).Get(ctx)
367369
if err != nil {
368370
if status.Code(err) == codes.NotFound {
369-
return "", ErrUserTokenNotFound
371+
return nil, ErrUserTokenNotFound
370372
}
371-
return "", fmt.Errorf("failed to get token from Firestore: %w", err)
373+
return nil, fmt.Errorf("failed to get token from Firestore: %w", err)
372374
}
373375

374376
var tokenDoc UserTokenDoc
375377
if err := doc.DataTo(&tokenDoc); err != nil {
376-
return "", fmt.Errorf("failed to unmarshal token: %w", err)
378+
return nil, fmt.Errorf("failed to unmarshal token: %w", err)
377379
}
378380

379-
// Decrypt the token
380-
decrypted, err := s.encryptor.Decrypt(tokenDoc.Token)
381-
if err != nil {
382-
return "", fmt.Errorf("failed to decrypt token: %w", err)
381+
// Build StoredToken
382+
storedToken := &StoredToken{
383+
Type: tokenDoc.Type,
384+
UpdatedAt: tokenDoc.UpdatedAt,
385+
}
386+
387+
// Decrypt based on type
388+
switch tokenDoc.Type {
389+
case TokenTypeManual:
390+
if tokenDoc.Value != "" {
391+
decrypted, err := s.encryptor.Decrypt(tokenDoc.Value)
392+
if err != nil {
393+
return nil, fmt.Errorf("failed to decrypt manual token: %w", err)
394+
}
395+
storedToken.Value = decrypted
396+
}
397+
case TokenTypeOAuth:
398+
if tokenDoc.OAuthData != nil {
399+
// Decrypt OAuth tokens
400+
decryptedAccess, err := s.encryptor.Decrypt(tokenDoc.OAuthData.AccessToken)
401+
if err != nil {
402+
return nil, fmt.Errorf("failed to decrypt access token: %w", err)
403+
}
404+
405+
oauthData := &OAuthTokenData{
406+
AccessToken: decryptedAccess,
407+
TokenType: tokenDoc.OAuthData.TokenType,
408+
ExpiresAt: tokenDoc.OAuthData.ExpiresAt,
409+
Scopes: tokenDoc.OAuthData.Scopes,
410+
}
411+
412+
if tokenDoc.OAuthData.RefreshToken != "" {
413+
decryptedRefresh, err := s.encryptor.Decrypt(tokenDoc.OAuthData.RefreshToken)
414+
if err != nil {
415+
return nil, fmt.Errorf("failed to decrypt refresh token: %w", err)
416+
}
417+
oauthData.RefreshToken = decryptedRefresh
418+
}
419+
420+
storedToken.OAuthData = oauthData
421+
}
383422
}
384423

385-
return decrypted, nil
424+
return storedToken, nil
386425
}
387426

388427
// SetUserToken stores or updates a user's token for a specific service
389-
func (s *FirestoreStorage) SetUserToken(ctx context.Context, userEmail, service, token string) error {
390-
// Encrypt the token before storing
391-
encrypted, err := s.encryptor.Encrypt(token)
392-
if err != nil {
393-
return fmt.Errorf("failed to encrypt token: %w", err)
428+
func (s *FirestoreStorage) SetUserToken(ctx context.Context, userEmail, service string, token *StoredToken) error {
429+
if token == nil {
430+
return fmt.Errorf("token cannot be nil")
394431
}
395432

396433
docID := s.makeUserTokenDocID(userEmail, service)
397434
tokenDoc := UserTokenDoc{
398435
UserEmail: userEmail,
399436
Service: service,
400-
Token: encrypted,
437+
Type: token.Type,
401438
UpdatedAt: time.Now(),
402439
}
403440

404-
_, err = s.client.Collection(s.tokenCollection).Doc(docID).Set(ctx, tokenDoc)
441+
// Encrypt based on type
442+
switch token.Type {
443+
case TokenTypeManual:
444+
if token.Value != "" {
445+
encrypted, err := s.encryptor.Encrypt(token.Value)
446+
if err != nil {
447+
return fmt.Errorf("failed to encrypt manual token: %w", err)
448+
}
449+
tokenDoc.Value = encrypted
450+
}
451+
case TokenTypeOAuth:
452+
if token.OAuthData != nil {
453+
// Encrypt OAuth tokens
454+
encryptedAccess, err := s.encryptor.Encrypt(token.OAuthData.AccessToken)
455+
if err != nil {
456+
return fmt.Errorf("failed to encrypt access token: %w", err)
457+
}
458+
459+
oauthData := &OAuthTokenData{
460+
AccessToken: encryptedAccess,
461+
TokenType: token.OAuthData.TokenType,
462+
ExpiresAt: token.OAuthData.ExpiresAt,
463+
Scopes: token.OAuthData.Scopes,
464+
}
465+
466+
if token.OAuthData.RefreshToken != "" {
467+
encryptedRefresh, err := s.encryptor.Encrypt(token.OAuthData.RefreshToken)
468+
if err != nil {
469+
return fmt.Errorf("failed to encrypt refresh token: %w", err)
470+
}
471+
oauthData.RefreshToken = encryptedRefresh
472+
}
473+
474+
tokenDoc.OAuthData = oauthData
475+
}
476+
default:
477+
return fmt.Errorf("unknown token type: %s", token.Type)
478+
}
479+
480+
_, err := s.client.Collection(s.tokenCollection).Doc(docID).Set(ctx, tokenDoc)
405481
if err != nil {
406482
return fmt.Errorf("failed to store token in Firestore: %w", err)
407483
}

internal/storage/memory.go

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package storage
22

33
import (
44
"context"
5+
"fmt"
56
"strings"
67
"sync"
78
"time"
@@ -19,9 +20,9 @@ var _ fosite.Storage = (*MemoryStorage)(nil)
1920
// It extends the MemoryStore with thread-safe client management
2021
type MemoryStorage struct {
2122
*storage.MemoryStore
22-
stateCache sync.Map // map[string]fosite.AuthorizeRequester
23-
clientsMutex sync.RWMutex // For thread-safe client access
24-
userTokens map[string]string // map["email:service"] = token
23+
stateCache sync.Map // map[string]fosite.AuthorizeRequester
24+
clientsMutex sync.RWMutex // For thread-safe client access
25+
userTokens map[string]*StoredToken // map["email:service"] = token
2526
userTokensMutex sync.RWMutex
2627
users map[string]*UserInfo // map[email] = UserInfo
2728
usersMutex sync.RWMutex
@@ -33,7 +34,7 @@ type MemoryStorage struct {
3334
func NewMemoryStorage() *MemoryStorage {
3435
return &MemoryStorage{
3536
MemoryStore: storage.NewMemoryStore(),
36-
userTokens: make(map[string]string),
37+
userTokens: make(map[string]*StoredToken),
3738
users: make(map[string]*UserInfo),
3839
sessions: make(map[string]*ActiveSession),
3940
}
@@ -140,20 +141,24 @@ func (s *MemoryStorage) makeUserTokenKey(userEmail, service string) string {
140141
}
141142

142143
// GetUserToken retrieves a user's token for a specific service
143-
func (s *MemoryStorage) GetUserToken(ctx context.Context, userEmail, service string) (string, error) {
144+
func (s *MemoryStorage) GetUserToken(ctx context.Context, userEmail, service string) (*StoredToken, error) {
144145
s.userTokensMutex.RLock()
145146
defer s.userTokensMutex.RUnlock()
146147

147148
key := s.makeUserTokenKey(userEmail, service)
148149
token, exists := s.userTokens[key]
149150
if !exists {
150-
return "", ErrUserTokenNotFound
151+
return nil, ErrUserTokenNotFound
151152
}
152153
return token, nil
153154
}
154155

155156
// SetUserToken stores or updates a user's token for a specific service
156-
func (s *MemoryStorage) SetUserToken(ctx context.Context, userEmail, service, token string) error {
157+
func (s *MemoryStorage) SetUserToken(ctx context.Context, userEmail, service string, token *StoredToken) error {
158+
if token == nil {
159+
return fmt.Errorf("token cannot be nil")
160+
}
161+
157162
s.userTokensMutex.Lock()
158163
defer s.userTokensMutex.Unlock()
159164

internal/storage/storage.go

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,31 @@ type UserInfo struct {
2727
IsAdmin bool `json:"is_admin"`
2828
}
2929

30+
// TokenType represents the type of stored token
31+
type TokenType string
32+
33+
const (
34+
TokenTypeManual TokenType = "manual"
35+
TokenTypeOAuth TokenType = "oauth"
36+
)
37+
38+
// OAuthTokenData represents OAuth token metadata
39+
type OAuthTokenData struct {
40+
AccessToken string `json:"access_token"`
41+
RefreshToken string `json:"refresh_token,omitempty"`
42+
TokenType string `json:"token_type,omitempty"`
43+
ExpiresAt time.Time `json:"expires_at,omitempty"`
44+
Scopes []string `json:"scopes,omitempty"`
45+
}
46+
47+
// StoredToken represents a token with its metadata
48+
type StoredToken struct {
49+
Type TokenType `json:"type"`
50+
Value string `json:"value,omitempty"` // For manual tokens
51+
OAuthData *OAuthTokenData `json:"oauth,omitempty"` // For OAuth tokens
52+
UpdatedAt time.Time `json:"updated_at"`
53+
}
54+
3055
// ActiveSession represents an active MCP session
3156
type ActiveSession struct {
3257
SessionID string `json:"session_id"`
@@ -40,8 +65,8 @@ type ActiveSession struct {
4065
// This interface is used by handlers that need to access user-specific tokens
4166
// for external services (e.g., Notion, GitHub).
4267
type UserTokenStore interface {
43-
GetUserToken(ctx context.Context, userEmail, service string) (string, error)
44-
SetUserToken(ctx context.Context, userEmail, service, token string) error
68+
GetUserToken(ctx context.Context, userEmail, service string) (*StoredToken, error)
69+
SetUserToken(ctx context.Context, userEmail, service string, token *StoredToken) error
4570
DeleteUserToken(ctx context.Context, userEmail, service string) error
4671
ListUserServices(ctx context.Context, userEmail string) ([]string, error)
4772
}

internal/testutil/mocks.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package testutil
33
import (
44
"context"
55

6+
"github.com/dgellow/mcp-front/internal/storage"
67
"github.com/mark3labs/mcp-go/mcp"
78
"github.com/mark3labs/mcp-go/server"
89
"github.com/stretchr/testify/mock"
@@ -157,12 +158,15 @@ type MockUserTokenStore struct {
157158
mock.Mock
158159
}
159160

160-
func (m *MockUserTokenStore) GetUserToken(ctx context.Context, userEmail, serverName string) (string, error) {
161+
func (m *MockUserTokenStore) GetUserToken(ctx context.Context, userEmail, serverName string) (*storage.StoredToken, error) {
161162
args := m.Called(ctx, userEmail, serverName)
162-
return args.String(0), args.Error(1)
163+
if args.Get(0) == nil {
164+
return nil, args.Error(1)
165+
}
166+
return args.Get(0).(*storage.StoredToken), args.Error(1)
163167
}
164168

165-
func (m *MockUserTokenStore) SetUserToken(ctx context.Context, userEmail, serverName, token string) error {
169+
func (m *MockUserTokenStore) SetUserToken(ctx context.Context, userEmail, serverName string, token *storage.StoredToken) error {
166170
args := m.Called(ctx, userEmail, serverName, token)
167171
return args.Error(0)
168172
}

0 commit comments

Comments
 (0)