Skip to content

Commit 64d394f

Browse files
committed
Refactor to eliminate runtime type assertions
Changed tokenStore from UserTokenStore to Storage interface to avoid runtime type assertions. This makes the code cleaner and type-safe.
1 parent 41dbe72 commit 64d394f

File tree

3 files changed

+49
-23
lines changed

3 files changed

+49
-23
lines changed

internal/server/handler.go

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ type Server struct {
2323
mux *http.ServeMux
2424
config *config.Config
2525
oauthServer *oauth.Server
26-
tokenStore storage.UserTokenStore
26+
storage storage.Storage
2727
sessionManager *client.StdioSessionManager
2828
sseServers map[string]*server.SSEServer // serverName -> SSE server for stdio servers
2929
}
@@ -143,8 +143,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (*Server, error) {
143143
return nil, fmt.Errorf("failed to create OAuth server: %w", err)
144144
}
145145

146-
// Use the storage directly as token store
147-
s.tokenStore = store
146+
s.storage = store
148147

149148
// Initialize admin users if admin is enabled
150149
if cfg.Proxy.Admin != nil && cfg.Proxy.Admin.Enabled {
@@ -180,7 +179,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (*Server, error) {
180179
mux.Handle("/register", chainMiddleware(http.HandlerFunc(s.oauthServer.RegisterHandler), oauthMiddlewares...))
181180

182181
// Protected endpoints - require authentication
183-
tokenHandlers := NewTokenHandlers(s.tokenStore, cfg.MCPServers, s.oauthServer != nil)
182+
tokenHandlers := NewTokenHandlers(s.storage, cfg.MCPServers, s.oauthServer != nil)
184183
tokenMiddlewares := []MiddlewareFunc{
185184
corsMiddleware(allowedOrigins),
186185
loggerMiddleware("tokens"),
@@ -280,9 +279,8 @@ func NewServer(ctx context.Context, cfg *config.Config) (*Server, error) {
280279
}
281280
s.sessionManager.RemoveSession(key)
282281

283-
// Remove session from storage
284-
if store, ok := handler.h.tokenStore.(storage.Storage); ok {
285-
if err := store.RevokeSession(sessionCtx, session.SessionID()); err != nil {
282+
if handler.h.storage != nil {
283+
if err := handler.h.storage.RevokeSession(sessionCtx, session.SessionID()); err != nil {
286284
internal.LogWarnWithFields("server", "Failed to revoke session from storage", map[string]interface{}{
287285
"error": err.Error(),
288286
"sessionID": session.SessionID(),
@@ -321,7 +319,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (*Server, error) {
321319
handler := NewMCPHandler(
322320
serverName,
323321
serverConfig,
324-
s.tokenStore,
322+
s.storage,
325323
baseURL.String(),
326324
info,
327325
s.sessionManager,
@@ -371,12 +369,12 @@ func NewServer(ctx context.Context, cfg *config.Config) (*Server, error) {
371369
encryptionKey = oauthAuth.EncryptionKey
372370
}
373371

374-
adminHandlers := NewAdminHandlers(s.tokenStore.(storage.Storage), cfg, s.sessionManager, encryptionKey)
372+
adminHandlers := NewAdminHandlers(s.storage, cfg, s.sessionManager, encryptionKey)
375373
adminMiddlewares := []MiddlewareFunc{
376374
corsMiddleware(allowedOrigins),
377375
loggerMiddleware("admin"),
378-
s.oauthServer.SSOMiddleware(), // Browser SSO
379-
adminMiddleware(cfg.Proxy.Admin, s.tokenStore.(storage.Storage)), // Admin check
376+
s.oauthServer.SSOMiddleware(), // Browser SSO
377+
adminMiddleware(cfg.Proxy.Admin, s.storage), // Admin check
380378
}
381379

382380
// Admin routes - all protected by admin middleware
@@ -478,7 +476,7 @@ func handleSessionRegistration(
478476
handler.mcpServer,
479477
handler.userEmail,
480478
handler.config.RequiresUserToken,
481-
handler.h.tokenStore,
479+
handler.h.storage,
482480
handler.h.serverName,
483481
handler.h.setupBaseURL,
484482
handler.config.TokenSetup,
@@ -495,15 +493,15 @@ func handleSessionRegistration(
495493
}
496494

497495
if handler.userEmail != "" {
498-
if store, ok := handler.h.tokenStore.(storage.Storage); ok {
496+
if handler.h.storage != nil {
499497
activeSession := storage.ActiveSession{
500498
SessionID: session.SessionID(),
501499
UserEmail: handler.userEmail,
502500
ServerName: handler.h.serverName,
503501
Created: time.Now(),
504502
LastActive: time.Now(),
505503
}
506-
if err := store.TrackSession(sessionCtx, activeSession); err != nil {
504+
if err := handler.h.storage.TrackSession(sessionCtx, activeSession); err != nil {
507505
internal.LogWarnWithFields("server", "Failed to track session", map[string]interface{}{
508506
"error": err.Error(),
509507
"sessionID": session.SessionID(),

internal/server/mcp_handler.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ type SessionManager interface {
3131
type MCPHandler struct {
3232
serverName string
3333
serverConfig *config.MCPClientConfig
34-
tokenStore storage.UserTokenStore
34+
storage storage.Storage
3535
setupBaseURL string
3636
info mcp.Implementation
3737
sessionManager SessionManager
@@ -42,7 +42,7 @@ type MCPHandler struct {
4242
func NewMCPHandler(
4343
serverName string,
4444
serverConfig *config.MCPClientConfig,
45-
tokenStore storage.UserTokenStore,
45+
storage storage.Storage,
4646
setupBaseURL string,
4747
info mcp.Implementation,
4848
sessionManager SessionManager,
@@ -51,7 +51,7 @@ func NewMCPHandler(
5151
return &MCPHandler{
5252
serverName: serverName,
5353
serverConfig: serverConfig,
54-
tokenStore: tokenStore,
54+
storage: storage,
5555
setupBaseURL: setupBaseURL,
5656
info: info,
5757
sessionManager: sessionManager,
@@ -136,8 +136,8 @@ func (h *MCPHandler) isMessageRequest(r *http.Request) bool {
136136
// trackUserAccess tracks user access if user email is provided
137137
func (h *MCPHandler) trackUserAccess(ctx context.Context, userEmail string) {
138138
if userEmail != "" {
139-
if store, ok := h.tokenStore.(storage.Storage); ok {
140-
if err := store.UpsertUser(ctx, userEmail); err != nil {
139+
if h.storage != nil {
140+
if err := h.storage.UpsertUser(ctx, userEmail); err != nil {
141141
internal.LogWarnWithFields("mcp", "Failed to track user", map[string]any{
142142
"error": err.Error(),
143143
"user": userEmail,
@@ -263,7 +263,11 @@ func (h *MCPHandler) getUserTokenIfAvailable(ctx context.Context, userEmail stri
263263
return "", fmt.Errorf("authentication required")
264264
}
265265

266-
token, err := h.tokenStore.GetUserToken(ctx, userEmail, h.serverName)
266+
if h.storage == nil {
267+
return "", fmt.Errorf("storage not configured")
268+
}
269+
270+
token, err := h.storage.GetUserToken(ctx, userEmail, h.serverName)
267271
if err != nil {
268272
return "", err
269273
}

internal/server/mcp_handler_test.go

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
"github.com/dgellow/mcp-front/internal/client"
1616
"github.com/dgellow/mcp-front/internal/config"
1717
"github.com/dgellow/mcp-front/internal/jsonrpc"
18-
"github.com/dgellow/mcp-front/internal/testutil"
18+
"github.com/dgellow/mcp-front/internal/storage"
1919
"github.com/mark3labs/mcp-go/mcp"
2020
"github.com/stretchr/testify/assert"
2121
"github.com/stretchr/testify/mock"
@@ -26,6 +26,28 @@ type mockSessionManager struct {
2626
mock.Mock
2727
}
2828

29+
type mockStorage struct {
30+
*storage.MemoryStorage
31+
mock.Mock
32+
}
33+
34+
// Override only the methods we want to mock
35+
func (m *mockStorage) GetUserToken(ctx context.Context, userEmail, service string) (string, error) {
36+
if m.Mock.ExpectedCalls != nil {
37+
args := m.Called(ctx, userEmail, service)
38+
return args.String(0), args.Error(1)
39+
}
40+
return m.MemoryStorage.GetUserToken(ctx, userEmail, service)
41+
}
42+
43+
func (m *mockStorage) UpsertUser(ctx context.Context, email string) error {
44+
if m.Mock.ExpectedCalls != nil {
45+
args := m.Called(ctx, email)
46+
return args.Error(0)
47+
}
48+
return m.MemoryStorage.UpsertUser(ctx, email)
49+
}
50+
2951
func (m *mockSessionManager) GetSession(key client.SessionKey) (*client.StdioSession, bool) {
3052
args := m.Called(key)
3153
if args.Get(0) == nil {
@@ -52,14 +74,16 @@ func (m *mockSessionManager) Shutdown() {
5274

5375
// Test helper to create MCPHandler for SSE tests
5476
func createTestMCPHandler(serverName string, config *config.MCPClientConfig) *MCPHandler {
55-
tokenStore := new(testutil.MockUserTokenStore)
77+
mockStore := &mockStorage{
78+
MemoryStorage: storage.NewMemoryStorage(),
79+
}
5680
sessionManager := new(mockSessionManager)
5781
info := mcp.Implementation{Name: "test", Version: "1.0"}
5882

5983
return NewMCPHandler(
6084
serverName,
6185
config,
62-
tokenStore,
86+
mockStore,
6387
"http://localhost:8080",
6488
info,
6589
sessionManager,

0 commit comments

Comments
 (0)