diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..bbc2970 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,79 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Debug MCP Server (tsx watch)", + "type": "node", + "request": "launch", + "program": "${workspaceFolder}/node_modules/.bin/tsx", + "args": ["watch", "src/index.ts"], + "console": "integratedTerminal", + "restart": true, + "skipFiles": ["<node_internals>/**"], + "env": { + "NODE_ENV": "development" + } + }, + { + "name": "Debug Jest Tests", + "type": "node", + "request": "launch", + "runtimeArgs": [ + "--experimental-vm-modules", + "--inspect-brk", + "${workspaceFolder}/node_modules/.bin/jest", + "--runInBand" + ], + "console": "integratedTerminal", + "internalConsoleOptions": "neverOpen", + "skipFiles": ["<node_internals>/**"], + "env": { + "NODE_OPTIONS": "--experimental-vm-modules" + } + }, + { + "name": "Debug Current Jest Test File", + "type": "node", + "request": "launch", + "runtimeArgs": [ + "--experimental-vm-modules", + "--inspect-brk", + "${workspaceFolder}/node_modules/.bin/jest", + "--runInBand", + "${relativeFile}" + ], + "console": "integratedTerminal", + "internalConsoleOptions": "neverOpen", + "skipFiles": ["<node_internals>/**"], + "env": { + "NODE_OPTIONS": "--experimental-vm-modules" + } + }, + { + "name": "Debug Jest Test by Pattern", + "type": "node", + "request": "launch", + "runtimeArgs": [ + "--experimental-vm-modules", + "--inspect-brk", + "${workspaceFolder}/node_modules/.bin/jest", + "--runInBand", + "--testNamePattern", + "${input:testNamePattern}" + ], + "console": "integratedTerminal", + "internalConsoleOptions": "neverOpen", + "skipFiles": ["<node_internals>/**"], + "env": { + "NODE_OPTIONS": "--experimental-vm-modules" + } + } + ], + "inputs": [ + { + "id": "testNamePattern", + "type": "promptString", + "description": "Test name pattern to match (e.g., 'should trigger onsessionclosed')" + } + ] +} \ No newline at end of file diff --git a/docs/streamable-http-design.md 
b/docs/streamable-http-design.md new file mode 100644 index 0000000..2d31081 --- /dev/null +++ b/docs/streamable-http-design.md @@ -0,0 +1,444 @@ +# Design Document: Implementing Streamable HTTP Transport for Example Remote Server + +## Research Summary + +### Current SSE Transport Architecture + +The example remote server currently uses the following architecture: + +1. **SSE Endpoint**: `/sse` - Creates SSE connection using `SSEServerTransport` +2. **Message Endpoint**: `/message` - Receives POST requests and forwards them via Redis +3. **Redis Integration**: Messages are published/subscribed through Redis channels using session IDs +4. **Auth**: Uses `requireBearerAuth` middleware with `EverythingAuthProvider` +5. **Session Management**: Each SSE connection gets a unique session ID used as Redis channel key + +**Key Files:** +- `/src/index.ts:91` - SSE endpoint with auth and headers +- `/src/handlers/mcp.ts:55-118` - SSE connection handler with Redis integration +- `/src/handlers/mcp.ts:120-144` - Message POST handler + +### Streamable HTTP Transport Specification (2025-03-26) + +The new Streamable HTTP transport replaces the old HTTP+SSE approach with a single endpoint that supports: + +1. **Single Endpoint**: One URL that handles GET, POST, and DELETE methods +2. **POST Requests**: Send JSON-RPC messages, can return either JSON responses or SSE streams +3. **GET Requests**: Open SSE streams for server-to-client messages +4. **Session Management**: Optional session IDs in `Mcp-Session-Id` headers +5. **Resumability**: Optional event storage with `Last-Event-ID` support +6. 
**Auth Integration**: Same authentication patterns as SSE + +**Key Specification Requirements:** +- Accept header must include both `application/json` and `text/event-stream` +- Session ID management via `Mcp-Session-Id` headers +- 202 Accepted for notifications/responses only +- SSE streams or JSON responses for requests +- Security: Origin validation, localhost binding, proper auth + +### TypeScript SDK Implementation + +The SDK provides `StreamableHTTPServerTransport` with: + +1. **Two Modes**: + - **Stateful**: Session ID generator provided, maintains sessions in memory + - **Stateless**: Session ID generator undefined, no session state + +2. **Key Features**: + - Built-in session validation + - Event store support for resumability + - Automatic response correlation + - Auth info threading via `req.auth` + +3. **Integration Patterns**: + - **Stateful**: Store transports by session ID, reuse across requests + - **Stateless**: New transport per request, immediate cleanup + - **Auth**: Same bearer auth middleware as SSE + +## Implementation Plan + +### 1. New Streamable HTTP Endpoint + +Add `/mcp` endpoint that handles GET, POST, DELETE methods: + +```typescript +// In src/index.ts +app.get("/mcp", cors(corsOptions), bearerAuth, authContext, handleStreamableHTTP); +app.post("/mcp", cors(corsOptions), bearerAuth, authContext, handleStreamableHTTP); +app.delete("/mcp", cors(corsOptions), bearerAuth, authContext, handleStreamableHTTP); +``` + +### 2. Handler Implementation + +Create new handler in `/src/handlers/mcp.ts`: + +```typescript +export async function handleStreamableHTTP(req: Request, res: Response) { + // Use same Redis-based architecture as SSE transport + // but with StreamableHTTPServerTransport instead of SSEServerTransport +} +``` + +### 3. 
Transport Integration Strategy (Horizontally Scalable) + +**Redis-based Session Management (Required for Horizontal Scaling)** +- Store session state in Redis, not in-memory +- Any server instance can handle any request for any session +- Session lifecycle independent of SSE connection lifecycle +- Message buffering in Redis when SSE connection is down +- Session TTL of 5 minutes to prevent Redis bloat + +### 4. Redis Integration + +Maintain current Redis architecture: +- Use session ID as Redis channel key +- Same message publishing/subscribing pattern +- Same MCP server creation logic +- Transport acts as bridge to Redis like current SSE implementation + +### 5. Auth Integration + +Use identical auth setup as SSE: +- Same `bearerAuth` middleware +- Same `authContext` middleware +- Same `EverythingAuthProvider` +- Auth info flows through `req.auth` to transport + +### 6. Backwards Compatibility + +Keep existing `/sse` and `/message` endpoints: +- Maintain current SSE transport for existing clients +- Add new `/mcp` endpoint alongside +- Both transports share same Redis infrastructure +- Same auth provider serves both + +## Key Differences from Current SSE Implementation + +1. **Single Endpoint**: `/mcp` handles all HTTP methods vs separate `/sse` + `/message` +2. **Transport Class**: `StreamableHTTPServerTransport` vs `SSEServerTransport` +3. **Session Headers**: `Mcp-Session-Id` headers vs URL session ID +4. **Request Handling**: Transport handles HTTP details vs manual SSE headers +5. **Response Correlation**: Built into transport vs manual request tracking + +## Benefits of This Approach + +1. **Spec Compliance**: Follows 2025-03-26 MCP specification exactly +2. **Minimal Changes**: Reuses existing Redis infrastructure and auth +3. **Feature Parity**: Same functionality as current SSE transport +4. **Future Proof**: Can add resumability with event store later +5. **Clean Integration**: Same auth patterns and middleware stack + +## Implementation Steps + +1. 
**Add Dependencies**: `StreamableHTTPServerTransport` from SDK +2. **Create Redis Session Management**: Implement `SessionManager` and `MessageDelivery` classes +3. **Create Handler**: New streamable HTTP handler function with Redis integration +4. **Add Routes**: New `/mcp` endpoint with all HTTP methods +5. **Session Management**: Redis-based session storage with TTL +6. **Message Buffering**: Redis-based message buffering for disconnected clients +7. **Testing**: Verify auth, Redis integration, horizontal scaling, and MCP protocol compliance +8. **Documentation**: Update README with new endpoint usage + +### Additional Implementation Considerations + +- **Redis Connection Management**: Ensure Redis connections are properly pooled and cleaned up +- **Error Handling**: Robust error handling for Redis operations and session timeouts +- **Monitoring**: Add logging for session creation, cleanup, and message buffering metrics +- **Performance**: Consider Redis memory usage and implement appropriate limits on message buffer size +- **Security**: Ensure session IDs are cryptographically secure and validate all session operations + +## Technical Details + +### Session Management Architecture (Redis-based) + +**Redis Data Structures for Horizontal Scaling:** + +```typescript +// Redis keys for session management +const SESSION_METADATA_KEY = (sessionId: string) => `session:${sessionId}:metadata`; +const SESSION_MESSAGES_KEY = (sessionId: string) => `session:${sessionId}:messages`; +const SESSION_CONNECTION_KEY = (sessionId: string) => `session:${sessionId}:connection`; + +// Session metadata structure +interface SessionMetadata { + sessionId: string; + clientId: string; + createdAt: number; + lastActivity: number; +} + +// Session lifecycle management +class SessionManager { + static SESSION_TTL = 5 * 60; // 5 minutes in seconds (accessed by MessageDelivery, so not private) + + static async createSession(sessionId: string, clientId: string): Promise<void> { + const metadata: SessionMetadata = { + sessionId, 
clientId, + createdAt: Date.now(), + lastActivity: Date.now() + }; + + // Store session metadata with TTL + await redisClient.set( + SESSION_METADATA_KEY(sessionId), + JSON.stringify(metadata), + { EX: this.SESSION_TTL } + ); + + // Initialize empty message buffer + await redisClient.del(SESSION_MESSAGES_KEY(sessionId)); + + // Mark connection as disconnected initially + await redisClient.set(SESSION_CONNECTION_KEY(sessionId), 'disconnected', { EX: this.SESSION_TTL }); + } + + static async refreshSession(sessionId: string): Promise<boolean> { + const metadata = await this.getSessionMetadata(sessionId); + if (!metadata) return false; + + // Update last activity and refresh TTL + metadata.lastActivity = Date.now(); + await redisClient.set( + SESSION_METADATA_KEY(sessionId), + JSON.stringify(metadata), + { EX: this.SESSION_TTL } + ); + + // Refresh other keys too + await redisClient.expire(SESSION_MESSAGES_KEY(sessionId), this.SESSION_TTL); + await redisClient.expire(SESSION_CONNECTION_KEY(sessionId), this.SESSION_TTL); + + return true; + } + + static async deleteSession(sessionId: string): Promise<void> { + await redisClient.del(SESSION_METADATA_KEY(sessionId)); + await redisClient.del(SESSION_MESSAGES_KEY(sessionId)); + await redisClient.del(SESSION_CONNECTION_KEY(sessionId)); + } + + static async getSessionMetadata(sessionId: string): Promise<SessionMetadata | null> { + const data = await redisClient.get(SESSION_METADATA_KEY(sessionId)); + return data ? JSON.parse(data) : null; + } + + // Mark SSE connection as connected/disconnected + static async setConnectionState(sessionId: string, connected: boolean): Promise<void> { + await redisClient.set( + SESSION_CONNECTION_KEY(sessionId), + connected ? 
'connected' : 'disconnected', + { EX: this.SESSION_TTL } + ); + } + + static async isConnected(sessionId: string): Promise<boolean> { + const state = await redisClient.get(SESSION_CONNECTION_KEY(sessionId)); + return state === 'connected'; + } +} +``` + +### Redis Integration Pattern with Message Buffering + +The implementation extends the current Redis pattern to support message buffering: + +```typescript +// Message delivery with buffering support +class MessageDelivery { + static async deliverMessage(sessionId: string, message: JSONRPCMessage): Promise<void> { + const isConnected = await SessionManager.isConnected(sessionId); + + if (isConnected) { + // Direct delivery via existing Redis pub/sub + const redisChannel = `mcp:${sessionId}`; + await redisClient.publish(redisChannel, JSON.stringify(message)); + } else { + // Buffer the message for later delivery + await redisClient.lpush( + SESSION_MESSAGES_KEY(sessionId), + JSON.stringify(message) + ); + // Set TTL on the messages list + await redisClient.expire(SESSION_MESSAGES_KEY(sessionId), SessionManager.SESSION_TTL); + } + } + + static async deliverBufferedMessages(sessionId: string, transport: StreamableHTTPServerTransport): Promise<void> { + // Get all buffered messages + const bufferedMessages = await redisClient.lrange(SESSION_MESSAGES_KEY(sessionId), 0, -1); + + // Deliver buffered messages in order (reverse because lpush) + for (let i = bufferedMessages.length - 1; i >= 0; i--) { + const message = JSON.parse(bufferedMessages[i]); + await transport.send(message); + } + + // Clear the buffer after delivery + await redisClient.del(SESSION_MESSAGES_KEY(sessionId)); + } +} + +// Enhanced Redis subscription for SSE connections +const setupRedisSubscription = async (sessionId: string, transport: StreamableHTTPServerTransport) => { + const redisChannel = `mcp:${sessionId}`; + + const redisCleanup = await redisClient.createSubscription( + redisChannel, + async (message) => { + const jsonMessage = JSON.parse(message); + try { + await 
transport.send(jsonMessage); + } catch (error) { + console.error(`Failed to send message on transport for session ${sessionId}:`, error); + // Mark connection as disconnected so future messages get buffered + await SessionManager.setConnectionState(sessionId, false); + } + }, + async (error) => { + console.error('Redis subscription error:', error); + await SessionManager.setConnectionState(sessionId, false); + } + ); + + return redisCleanup; +}; +``` + +### Handler Implementation Flow + +The new streamable HTTP handler integrates with the Redis-based session management: + +```typescript +export async function handleStreamableHTTP(req: Request, res: Response) { + const method = req.method; + const sessionId = req.headers['mcp-session-id'] as string | undefined; + + if (method === 'POST') { + // Handle POST requests (initialization or message sending) + + if (isInitializeRequest(req.body) && !sessionId) { + // New session initialization + const newSessionId = randomUUID(); + const authInfo = req.auth; + + // Create session in Redis + await SessionManager.createSession(newSessionId, authInfo?.clientId || 'unknown'); + + // Create transport with Redis-based session management + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => newSessionId, + // Custom implementation - don't store transport in memory + }); + + const { server: mcpServer, cleanup: mcpCleanup } = createMcpServer(); + + // Set up Redis subscription for this session but don't store transport globally + // Instead, rely on Redis for all message routing + + await mcpServer.connect(transport); + await transport.handleRequest(req, res, req.body); + + } else if (sessionId) { + // Existing session - validate and handle request + const sessionValid = await SessionManager.refreshSession(sessionId); + if (!sessionValid) { + res.writeHead(404).end(JSON.stringify({ + jsonrpc: "2.0", + error: { code: -32001, message: "Session not found" }, + id: null + })); + return; + } + + // Create 
ephemeral transport for this request + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => sessionId, // Use existing session ID + }); + + const { server: mcpServer, cleanup: mcpCleanup } = createMcpServer(); + await mcpServer.connect(transport); + await transport.handleRequest(req, res, req.body); + + // Clean up after request completes + res.on('close', mcpCleanup); + } + + } else if (method === 'GET') { + // Handle SSE stream requests + + if (!sessionId) { + res.writeHead(400).end('Session ID required'); + return; + } + + const sessionValid = await SessionManager.refreshSession(sessionId); + if (!sessionValid) { + res.writeHead(404).end('Session not found'); + return; + } + + // Create transport for SSE stream + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => sessionId, + }); + + // Mark connection as active + await SessionManager.setConnectionState(sessionId, true); + + // Deliver any buffered messages first + await MessageDelivery.deliverBufferedMessages(sessionId, transport); + + // Set up Redis subscription for live messages + const redisCleanup = await setupRedisSubscription(sessionId, transport); + + // Handle connection cleanup + res.on('close', async () => { + await SessionManager.setConnectionState(sessionId, false); + redisCleanup(); + }); + + await transport.handleRequest(req, res); + + } else if (method === 'DELETE') { + // Handle session deletion + + if (!sessionId) { + res.writeHead(400).end('Session ID required'); + return; + } + + // Delete session from Redis + await SessionManager.deleteSession(sessionId); + res.writeHead(200).end(); + } +} +``` + +### Auth Information Flow + +Auth information flows through the middleware stack: + +```typescript +// Auth middleware adds req.auth +const authInfo: AuthInfo = req.auth; + +// Transport receives auth info +await transport.handleRequest(req, res); + +// Auth info is available in MCP server handlers +server.tool('example', 'description', 
schema, async (params, { authInfo }) => { + // authInfo contains token, clientId, scopes, etc. +}); +``` + +## Conclusion + +This design provides horizontally scalable streamable HTTP support by using Redis for all session state management and message buffering. Key advantages: + +1. **Horizontal Scalability**: Any server instance can handle any request for any session +2. **Resilient Connection Handling**: SSE disconnects don't end sessions; messages are buffered +3. **Automatic Cleanup**: 5-minute session TTL prevents Redis bloat +4. **Backwards Compatibility**: Existing `/sse` and `/message` endpoints remain unchanged +5. **Spec Compliance**: Follows 2025-03-26 MCP specification exactly + +The implementation is more complex than a single-instance approach but essential for production deployment in a horizontally scaled environment. The Redis-based architecture ensures sessions persist across server instances and SSE connection interruptions. \ No newline at end of file diff --git a/docs/user-id-system.md b/docs/user-id-system.md new file mode 100644 index 0000000..eb4daff --- /dev/null +++ b/docs/user-id-system.md @@ -0,0 +1,392 @@ +# User ID System Documentation + +## Overview + +The MCP server implements a comprehensive user identification and session ownership system that ensures secure multi-user access to MCP resources. This system integrates localStorage-based user management, OAuth authentication flows, and Redis-backed session isolation. + +## Architecture Components + +### 1. User ID Management (localStorage) +### 2. OAuth Authorization Flow +### 3. Redis Session Ownership +### 4. Session Access Validation + +--- + +## 1. User ID Management (localStorage) + +The fake upstream authentication system uses browser localStorage to manage user identities for testing and development purposes. 
+ +### localStorage Schema + +```typescript +// Stored in browser localStorage +{ + "mcpUserId": "550e8400-e29b-41d4-a716-446655440000" // UUID v4 +} +``` + +### User ID Generation Flow + +```mermaid +sequenceDiagram + participant Browser + participant LocalStorage + participant AuthPage as Fake Auth Page + + Browser->>AuthPage: Load /fakeupstreamauth/authorize + AuthPage->>LocalStorage: Check mcpUserId + + alt User ID exists + LocalStorage-->>AuthPage: Return existing UUID + AuthPage->>Browser: Display existing ID + else User ID missing + AuthPage->>AuthPage: Generate new UUID v4 + AuthPage->>LocalStorage: Store new mcpUserId + AuthPage->>Browser: Display new ID + end + + Note over AuthPage: User can edit or regenerate ID + AuthPage->>LocalStorage: Update mcpUserId (if changed) +``` + +### User ID Operations + +| Operation | Description | Implementation | +|-----------|-------------|----------------| +| **Generate** | Create new UUID v4 | `generateUUID()` function | +| **Retrieve** | Get existing or create new | `getUserId()` function | +| **Update** | Edit existing ID | `editUserId()` function | +| **Persist** | Store in localStorage | `localStorage.setItem('mcpUserId', userId)` | + +--- + +## 2. OAuth Authorization Flow + +The OAuth flow integrates user IDs from localStorage into the MCP authorization process. 
+ +### Complete OAuth Flow with User ID + +```mermaid +sequenceDiagram + participant Client + participant MCPServer as MCP Server + participant AuthPage as Auth Page + participant FakeAuth as Fake Upstream Auth + participant LocalStorage + participant Redis + + Client->>MCPServer: Request authorization + MCPServer->>AuthPage: Redirect to auth page + AuthPage->>Client: Show MCP authorization page + Client->>FakeAuth: Click "Continue to Authentication" + + FakeAuth->>LocalStorage: Get/Create userId + LocalStorage-->>FakeAuth: Return userId + FakeAuth->>Client: Show userId management UI + + Client->>FakeAuth: Complete authentication + FakeAuth->>MCPServer: Redirect with code + userId + MCPServer->>Redis: Store userId in McpInstallation + MCPServer->>Client: Return access token + + Note over Redis: McpInstallation.userId = userId +``` + +### OAuth Data Flow + +```mermaid +graph TD + A[Browser localStorage] -->|userId| B[Fake Auth Page] + B -->|userId in query params| C[Authorization Callback] + C -->|userId| D[McpInstallation Object] + D -->|access_token| E[Redis Storage] + E -->|AuthInfo.extra.userId| F[Session Ownership] +``` + +### Authorization Code Exchange + +The userId is embedded in the authorization flow: + +```javascript +// In fake auth page +function authorize() { + const userId = getUserId(); // From localStorage + const url = new URL(redirectUri); + url.searchParams.set('userId', userId); + url.searchParams.set('code', 'fakecode'); + window.location.href = url.toString(); +} +``` + +--- + +## 3. Redis Session Ownership + +Redis stores session ownership information using a structured key system. 
### Redis Key Structure + +``` +session:{sessionId}:owner → userId # Session ownership +mcp:shttp:toserver:{sessionId} → [pub/sub channel] # Client→Server messages (also indicates liveness) +mcp:shttp:toclient:{sessionId}:{requestId} → [pub/sub channel] # Server→Client responses +mcp:control:{sessionId} → [pub/sub channel] # Control messages +``` + +### Redis Operations + +| Operation | Key Pattern | Value | Purpose | +|-----------|-------------|--------|---------| +| **Set Owner** | `session:{sessionId}:owner` | `userId` | Store session owner | +| **Get Owner** | `session:{sessionId}:owner` | `userId` | Retrieve session owner | +| **Check Live** | `mcp:shttp:toserver:{sessionId}` | `numsub > 0` | Check if session active via pub/sub subscribers | + +### Session Liveness Mechanism + +Session liveness is determined by **pub/sub subscription count** rather than explicit keys: + +```mermaid +graph TD + A[MCP Server Starts] --> B[Subscribe to mcp:shttp:toserver:sessionId] + B --> C[numsub = 1 → Session is LIVE] + C --> D[Session Processing] + D --> E[MCP Server Shutdown] + E --> F[Unsubscribe from channel] + F --> G[numsub = 0 → Session is DEAD] + + H[isLive() function] --> I[Check numsub count] + I --> J{numsub > 0?} + J -->|Yes| K[Session is Live] + J -->|No| L[Session is Dead] +``` + +**Why this works:** +- When an MCP server starts, it subscribes to `mcp:shttp:toserver:{sessionId}` +- When it shuts down (gracefully or crashes), Redis automatically removes the subscription +- `numsub` reflects the actual state without requiring explicit cleanup + +### Session Ownership Functions + +```typescript +// Core ownership functions +export async function setSessionOwner(sessionId: string, userId: string): Promise<void> +export async function getSessionOwner(sessionId: string): Promise<string | null> +export async function validateSessionOwnership(sessionId: string, userId: string): Promise<boolean> +export async function isSessionOwnedBy(sessionId: string, userId: string): Promise<boolean> +export async 
function isLive(sessionId: string): Promise<boolean> // Uses numsub count +``` + +--- + +## 4. Session Access Validation + +Session access is validated at multiple points in the request lifecycle. + +### Session Validation Flow + +```mermaid +sequenceDiagram + participant Client + participant Handler as shttp Handler + participant Auth as Auth Middleware + participant Redis + + Client->>Handler: MCP Request with session-id + Handler->>Auth: Extract userId from token + Auth-->>Handler: Return userId + + alt New Session (Initialize) + Handler->>Handler: Generate new sessionId + Handler->>Redis: setSessionOwner(sessionId, userId) + Handler->>Handler: Start MCP server (subscribes to channel) + Note over Handler: Session becomes "live" via pub/sub subscription + Handler->>Client: Return with new session + else Existing Session + Handler->>Redis: isSessionOwnedBy(sessionId, userId) + Redis-->>Handler: Return ownership status + + alt Session Owned by User + Handler->>Client: Process request + else Session Not Owned + Handler->>Client: 400 Bad Request + end + end +``` + +### DELETE Request Validation + +```mermaid +sequenceDiagram + participant Client + participant Handler as shttp Handler + participant Redis + + Client->>Handler: DELETE /mcp (session-id: xyz) + Handler->>Handler: Extract userId from auth + Handler->>Redis: isSessionOwnedBy(sessionId, userId) + + alt Session Owned by User + Redis-->>Handler: true + Handler->>Redis: shutdownSession(sessionId) + Handler->>Client: 200 OK (Session terminated) + else Session Not Owned + Redis-->>Handler: false + Handler->>Client: 404 Not Found (Session not found or access denied) + end +``` + +### Request Authorization Matrix + +| Request Type | Session ID | User ID | Authorization Check | +|-------------|-----------|---------|-------------------| +| **Initialize** | None | Required | Create new session | +| **Existing Session** | Required | Required | `isSessionOwnedBy()` | +| **DELETE Session** | Required | Required | 
`isSessionOwnedBy()` | + +--- + +## 5. Security Model + +### Multi-User Isolation + +```mermaid +graph TB + subgraph "User A" + A1[localStorage: userA-id] + A2[Session: session-A] + A3[Redis: session:A:owner → userA] + end + + subgraph "User B" + B1[localStorage: userB-id] + B2[Session: session-B] + B3[Redis: session:B:owner → userB] + end + + subgraph "Redis Isolation" + R1[session:A:owner → userA-id] + R2[session:B:owner → userB-id] + R3[Ownership Validation] + end + + A3 --> R1 + B3 --> R2 + R1 --> R3 + R2 --> R3 +``` + +### Security Guarantees + +1. **Session Isolation**: Users can only access sessions they own +2. **Identity Verification**: User ID is validated from authenticated token +3. **Ownership Persistence**: Session ownership is stored in Redis +4. **Access Control**: All session operations validate ownership +5. **Secure Cleanup**: DELETE operations verify ownership before termination + +### Attack Prevention + +| Attack Vector | Prevention | Implementation | +|---------------|------------|----------------| +| **Session Hijacking** | Ownership validation | `isSessionOwnedBy()` check | +| **Cross-User Access** | User ID verification | Extract userId from AuthInfo | +| **Session Spoofing** | Token validation | Bearer token middleware | +| **Unauthorized DELETE** | Ownership check | Validate before shutdown | + +--- + +## 6. Implementation Details + +### Error Handling + +```typescript +// Session access errors +if (!userId) { + return 401; // Unauthorized: User ID required +} + +if (!await isSessionOwnedBy(sessionId, userId)) { + return 400; // Bad Request: Session access denied +} +``` + +### Testing Strategy + +The system includes comprehensive tests for: + +- **User session isolation**: Users cannot access other users' sessions +- **DELETE request validation**: Only owners can delete sessions +- **Redis cleanup**: Proper cleanup of ownership data +- **Auth flow integration**: User ID propagation through OAuth + +### Performance Considerations + +1. 
**Redis Efficiency**: O(1) lookups for session ownership +2. **Session Reuse**: Existing sessions are reused when ownership matches +3. **Cleanup**: Automatic cleanup prevents resource leaks +4. **Caching**: Session ownership is cached in Redis + +--- + +## 7. Configuration + +### Environment Variables + +```bash +# Redis configuration for session storage +REDIS_HOST=localhost +REDIS_PORT=6379 +REDIS_PASSWORD=optional + +# Base URI for OAuth redirects +BASE_URI=http://localhost:3000 +``` + +### Development Testing + +```bash +# Run multi-user tests +npm test -- --testNamePattern="User Session Isolation" + +# Test session ownership +npm test -- --testNamePattern="session ownership" + +# Full integration test +npm test +``` + +--- + +## 8. Monitoring and Debugging + +### Redis Key Monitoring + +```bash +# Monitor session ownership keys +redis-cli KEYS "session:*:owner" + +# Watch session ownership operations +redis-cli MONITOR | grep "session:" + +# Check active (live) sessions via pub/sub +redis-cli PUBSUB CHANNELS "mcp:shttp:toserver:*" +redis-cli PUBSUB NUMSUB "mcp:shttp:toserver:*" +``` + +### Debugging Commands + +```bash +# Check session ownership +redis-cli GET "session:550e8400-e29b-41d4-a716-446655440000:owner" + +# List all session owners +redis-cli KEYS "session:*:owner" + +# Check if specific session is live +redis-cli PUBSUB NUMSUB "mcp:shttp:toserver:550e8400-e29b-41d4-a716-446655440000" + +# Monitor pub/sub activity +redis-cli MONITOR +``` + +This system provides robust multi-user session management with strong security guarantees and comprehensive testing coverage. 
\ No newline at end of file diff --git a/package-lock.json b/package-lock.json index e6ec5b9..0eb858d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -8,7 +8,7 @@ "name": "mcp-server-everything", "version": "0.1.0", "dependencies": { - "@modelcontextprotocol/sdk": "^1.6.1", + "@modelcontextprotocol/sdk": "^1.15.1", "@redis/client": "^1.6.0", "cors": "^2.8.5", "dotenv": "^16.4.7", @@ -1072,32 +1072,6 @@ "url": "https://opencollective.com/eslint" } }, - "node_modules/@eslint/eslintrc/node_modules/ajv": { - "version": "6.12.6", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", - "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", - "dev": true, - "license": "MIT", - "peer": true, - "dependencies": { - "fast-deep-equal": "^3.1.1", - "fast-json-stable-stringify": "^2.0.0", - "json-schema-traverse": "^0.4.1", - "uri-js": "^4.2.2" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/@eslint/eslintrc/node_modules/json-schema-traverse": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", - "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", - "dev": true, - "license": "MIT", - "peer": true - }, "node_modules/@eslint/js": { "version": "9.27.0", "resolved": "https://registry.npmjs.org/@eslint/js/-/js-9.27.0.tgz", @@ -1671,16 +1645,17 @@ } }, "node_modules/@modelcontextprotocol/sdk": { - "version": "1.11.4", - "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.11.4.tgz", - "integrity": "sha512-OTbhe5slIjiOtLxXhKalkKGhIQrwvhgCDs/C2r8kcBTy5HR/g43aDQU0l7r8O0VGbJPTNJvDc7ZdQMdQDJXmbw==", + "version": "1.15.1", + "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.15.1.tgz", + "integrity": 
"sha512-W/XlN9c528yYn+9MQkVjxiTPgPxoxt+oczfjHBDsJx0+59+O7B75Zhsp0B16Xbwbz8ANISDajh6+V7nIcPMc5w==", "license": "MIT", "dependencies": { - "ajv": "^8.17.1", + "ajv": "^6.12.6", "content-type": "^1.0.5", "cors": "^2.8.5", "cross-spawn": "^7.0.5", "eventsource": "^3.0.2", + "eventsource-parser": "^3.0.0", "express": "^5.0.1", "express-rate-limit": "^7.5.0", "pkce-challenge": "^5.0.0", @@ -2547,15 +2522,15 @@ } }, "node_modules/ajv": { - "version": "8.17.1", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", - "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", + "version": "6.12.6", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", "license": "MIT", "dependencies": { - "fast-deep-equal": "^3.1.3", - "fast-uri": "^3.0.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2" + "fast-deep-equal": "^3.1.1", + "fast-json-stable-stringify": "^2.0.0", + "json-schema-traverse": "^0.4.1", + "uri-js": "^4.2.2" }, "funding": { "type": "github", @@ -3570,32 +3545,6 @@ "url": "https://opencollective.com/eslint" } }, - "node_modules/eslint/node_modules/ajv": { - "version": "6.12.6", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", - "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", - "dev": true, - "license": "MIT", - "peer": true, - "dependencies": { - "fast-deep-equal": "^3.1.1", - "fast-json-stable-stringify": "^2.0.0", - "json-schema-traverse": "^0.4.1", - "uri-js": "^4.2.2" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/eslint/node_modules/json-schema-traverse": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", - "integrity": 
"sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", - "dev": true, - "license": "MIT", - "peer": true - }, "node_modules/espree": { "version": "10.3.0", "resolved": "https://registry.npmjs.org/espree/-/espree-10.3.0.tgz", @@ -3875,7 +3824,6 @@ "version": "2.1.0", "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", - "dev": true, "license": "MIT" }, "node_modules/fast-levenshtein": { @@ -3886,22 +3834,6 @@ "license": "MIT", "peer": true }, - "node_modules/fast-uri": { - "version": "3.0.6", - "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.0.6.tgz", - "integrity": "sha512-Atfo14OibSv5wAp4VWNsFYE1AchQRTv9cBGWET4pZWHzYshFSS9NQI6I57rdKn9croWVMbYFbLhJ+yJvmZIIHw==", - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/fastify" - }, - { - "type": "opencollective", - "url": "https://opencollective.com/fastify" - } - ], - "license": "BSD-3-Clause" - }, "node_modules/fastq": { "version": "1.19.1", "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.19.1.tgz", @@ -5298,9 +5230,9 @@ "license": "MIT" }, "node_modules/json-schema-traverse": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", - "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", + "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", "license": "MIT" }, "node_modules/json-stable-stringify-without-jsonify": { @@ -6036,9 +5968,7 @@ "version": "2.3.1", "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", "integrity": 
"sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==", - "dev": true, "license": "MIT", - "peer": true, "engines": { "node": ">=6" } @@ -6149,15 +6079,6 @@ "node": ">=0.10.0" } }, - "node_modules/require-from-string": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz", - "integrity": "sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==", - "license": "MIT", - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/resolve": { "version": "1.22.10", "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.10.tgz", @@ -6967,9 +6888,7 @@ "version": "4.4.1", "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", - "dev": true, "license": "BSD-2-Clause", - "peer": true, "dependencies": { "punycode": "^2.1.0" } diff --git a/package.json b/package.json index c5ff68e..da588bb 100644 --- a/package.json +++ b/package.json @@ -7,7 +7,8 @@ "scripts": { "start": "node dist/index.js", "dev": "tsx watch src/index.ts", - "build": "tsc", + "build": "tsc && npm run copy-static", + "copy-static": "mkdir -p dist/static && cp -r src/static/* dist/static/", "lint": "eslint src/", "test": "NODE_OPTIONS=--experimental-vm-modules jest" }, @@ -25,11 +26,15 @@ "typescript-eslint": "^8.18.0" }, "dependencies": { - "@modelcontextprotocol/sdk": "^1.6.1", + "@modelcontextprotocol/sdk": "^1.15.1", "@redis/client": "^1.6.0", "cors": "^2.8.5", "dotenv": "^16.4.7", "express": "^4.21.2", "raw-body": "^3.0.0" + }, + "overrides": { + "@types/express": "^5.0.0", + "@types/express-serve-static-core": "^5.0.2" } } diff --git a/src/auth/provider.test.ts b/src/auth/provider.test.ts index 5f7c82b..261159b 100644 --- a/src/auth/provider.test.ts +++ b/src/auth/provider.test.ts @@ -24,7 +24,8 @@ function createMockResponse() { 
redirect: jest.fn().mockReturnThis(), status: jest.fn().mockReturnThis(), json: jest.fn().mockReturnThis(), - send: jest.fn().mockReturnThis() + send: jest.fn().mockReturnThis(), + setHeader: jest.fn().mockReturnThis() }; return res as unknown as jest.Mocked; } @@ -53,6 +54,7 @@ function getMockAuthValues() { }, clientId: client.client_id, issuedAt: Date.now() / 1000, + userId: "test-user-id", }; return { @@ -135,7 +137,8 @@ describe("EverythingAuthProvider", () => { // Verify HTML sent with redirect expect(res.send).toHaveBeenCalled(); const sentHtml = (res.send as jest.Mock).mock.calls[0][0]; - expect(sentHtml).toContain('MCP Auth Page'); + expect(sentHtml).toContain('MCP Server Authorization'); + expect(sentHtml).toContain('Authorization Required'); expect(sentHtml).toContain('fakeupstreamauth/authorize?redirect_uri=/fakeupstreamauth/callback&state='); }); }); @@ -279,6 +282,7 @@ describe("EverythingAuthProvider", () => { }, clientId: "different-client-id", issuedAt: Date.now() / 1000, + userId: "test-user-id", }; await authService.saveRefreshToken(refreshToken, accessToken); @@ -314,6 +318,7 @@ describe("EverythingAuthProvider", () => { }, clientId: "client-id", issuedAt: Date.now() / 1000, + userId: "test-user-id", }; await authService.saveMcpInstallation(accessToken, mcpInstallation); @@ -325,6 +330,9 @@ describe("EverythingAuthProvider", () => { clientId: mcpInstallation.clientId, scopes: ['mcp'], expiresAt: mcpInstallation.mcpTokens.expires_in! 
+ mcpInstallation.issuedAt, + extra: { + userId: "test-user-id" + } }); }); @@ -351,6 +359,7 @@ describe("EverythingAuthProvider", () => { }, clientId: "client-id", issuedAt: twoDaysAgoInSeconds, // 2 days ago, with 1-day expiry + userId: "test-user-id", }; await authService.saveMcpInstallation(accessToken, mcpInstallation); @@ -377,6 +386,7 @@ describe("EverythingAuthProvider", () => { }, clientId: client.client_id, issuedAt: Date.now() / 1000, + userId: "test-user-id", }; // Save the installation diff --git a/src/auth/provider.ts b/src/auth/provider.ts index 53cb73e..3d77c61 100644 --- a/src/auth/provider.ts +++ b/src/auth/provider.ts @@ -72,22 +72,196 @@ export class EverythingAuthProvider implements OAuthServerProvider { // You can redirect to another page, or you can send an html response directly // res.redirect(new URL(`fakeupstreamauth/authorize?metadata=${authorizationCode}`, BASE_URI).href); + // Set permissive CSP for styling + res.setHeader('Content-Security-Policy', [ + "default-src 'self'", + "style-src 'self' 'unsafe-inline'", + "script-src 'self' 'unsafe-inline'", + "img-src 'self' data:", + "object-src 'none'", + "frame-ancestors 'none'", + "form-action 'self'", + "base-uri 'self'" + ].join('; ')); + res.send(` - + + - MCP Auth Page + + + MCP Server Authorization + -

MCP Server Auth Page

-

- This page is the authorization page presented by the MCP server, routing the user upstream. This is only - needed on 2025-03-26 Auth spec, where the MCP server acts as it's own authoriztion server. This page should - be present to avoid confused deputy attacks. -

-

- Click here - to continue to the upstream auth -

+
+
+ +
MCP
+
+ +

Authorization Required

+

This client wants to connect to your MCP server

+ +
+

Client Application

+
${client.client_id}
+
+ +
+

What happens next?

+

You'll be redirected to authenticate with the upstream provider. Once verified, you'll be granted access to this MCP server's resources.

+
+ + + Continue to Authentication + + +
+ Model Context Protocol (MCP) Server +
+
`); @@ -156,6 +330,7 @@ export class EverythingAuthProvider implements OAuthServerProvider { ...mcpInstallation, mcpTokens: newTokens, issuedAt: Date.now() / 1000, + userId: mcpInstallation.userId, // Preserve the user ID }); return newTokens; @@ -183,7 +358,10 @@ export class EverythingAuthProvider implements OAuthServerProvider { token, clientId: installation.clientId, scopes: ['mcp'], - expiresAt + expiresAt, + extra: { + userId: installation.userId + } }; } diff --git a/src/handlers/common.ts b/src/handlers/common.ts new file mode 100644 index 0000000..3a95c18 --- /dev/null +++ b/src/handlers/common.ts @@ -0,0 +1,63 @@ +import { NextFunction, Request, Response } from "express"; +import { withContext } from "../context.js"; +import { readMcpInstallation } from "../services/auth.js"; +import { logger } from "../utils/logger.js"; + +import { JSONRPCError, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse } from "@modelcontextprotocol/sdk/types.js"; + +export function logMcpMessage( + message: JSONRPCError | JSONRPCNotification | JSONRPCRequest | JSONRPCResponse, + sessionId: string, +) { + // check if message has a method field + if ("method" in message) { + if (message.method === "tools/call") { + logger.info('Processing MCP method', { + sessionId, + method: message.method, + toolName: message.params?.name + }); + } else { + logger.info('Processing MCP method', { + sessionId, + method: message.method + }); + } + } else if ("error" in message) { + logger.warning('Received error message', { + sessionId, + errorMessage: message.error.message, + errorCode: message.error.code + }); + } +} + + +export async function authContext( + req: Request, + res: Response, + next: NextFunction, +) { + const authInfo = req.auth + + if (!authInfo) { + res.set("WWW-Authenticate", 'Bearer error="invalid_token"'); + res.status(401).json({ error: "Invalid access token" }); + return; + } + + const token = authInfo.token; + + // Load UpstreamInstallation based on the access token + 
const mcpInstallation = await readMcpInstallation(token); + if (!mcpInstallation) { + res.set("WWW-Authenticate", 'Bearer error="invalid_token"'); + res.status(401).json({ error: "Invalid access token" }); + return; + } + + // Wrap the rest of the request handling in the context + withContext({ mcpAccessToken: token, fakeUpstreamInstallation: mcpInstallation.fakeUpstreamInstallation }, () => + next(), + ); +} diff --git a/src/handlers/fakeauth.ts b/src/handlers/fakeauth.ts index d19e786..49f1f70 100644 --- a/src/handlers/fakeauth.ts +++ b/src/handlers/fakeauth.ts @@ -12,18 +12,245 @@ export async function handleFakeAuthorize(req: Request, res: Response) { // get the redirect_uri and state from the query params const { redirect_uri, state } = req.query; - // TODO, mint actual codes? + // Set a more permissive CSP for auth pages to allow inline styles and scripts + res.setHeader('Content-Security-Policy', [ + "default-src 'self'", + "style-src 'self' 'unsafe-inline'", // Allow inline styles for auth page styling + "script-src 'self' 'unsafe-inline'", // Allow inline scripts for auth page functionality + "object-src 'none'", + "frame-ancestors 'none'", + "form-action 'self'", + "base-uri 'self'" + ].join('; ')); res.send(` - + + - Fake Upstream Auth Provider! + + + Upstream Provider Authentication + -

Fake Auth

-

Fake auth page

-

Click here to authorize

-

Click here to fail authorization

+
+ +

Upstream Authentication

+

Please verify your identity with the upstream provider

+ +
+

Your User Identity

+
Loading...
+ +
+ + + +
+ Testing Multiple Users: Open this page in different browser windows or incognito tabs to simulate different users. Each will have their own unique User ID and separate MCP sessions. +
+
+ + `); @@ -36,6 +263,7 @@ export async function handleFakeAuthorizeRedirect(req: Request, res: Response) { // The state returned from the upstream auth server is actually the authorization code state: mcpAuthorizationCode, code: upstreamAuthorizationCode, + userId, // User ID from the authorization flow } = req.query; // This is where you'd exchange the upstreamAuthorizationCode for access/refresh tokens @@ -63,6 +291,7 @@ export async function handleFakeAuthorizeRedirect(req: Request, res: Response) { mcpTokens, clientId: pendingAuth.clientId, issuedAt: Date.now() / 1000, + userId: (userId as string) || 'anonymous-user', // Include user ID from auth flow } // Store the upstream authorization data diff --git a/src/handlers/mcp.ts b/src/handlers/mcp.ts deleted file mode 100644 index 1a70833..0000000 --- a/src/handlers/mcp.ts +++ /dev/null @@ -1,144 +0,0 @@ -import { SSEServerTransport } from "@modelcontextprotocol/sdk/server/sse.js"; -import contentType from "content-type"; -import { NextFunction, Request, Response } from "express"; -import getRawBody from "raw-body"; -import { readMcpInstallation } from "../services/auth.js"; -import { withContext } from "../context.js"; -import { createMcpServer } from "../services/mcp.js"; -import { redisClient } from "../redis.js"; -import { AuthInfo } from "@modelcontextprotocol/sdk/server/auth/types.js"; - -const MAXIMUM_MESSAGE_SIZE = "4mb"; - -declare module "express-serve-static-core" { - interface Request { - /** - * Information about the validated access token, if the `requireBearerAuth` middleware was used. 
- */ - auth?: AuthInfo; - } -} - -export async function authContext( - req: Request, - res: Response, - next: NextFunction, -) { - const authInfo = req.auth - - if (!authInfo) { - res.set("WWW-Authenticate", 'Bearer error="invalid_token"'); - res.status(401).json({ error: "Invalid access token" }); - return; - } - - const token = authInfo.token; - - // Load UpstreamInstallation based on the access token - const mcpInstallation = await readMcpInstallation(token); - if (!mcpInstallation) { - res.set("WWW-Authenticate", 'Bearer error="invalid_token"'); - res.status(401).json({ error: "Invalid access token" }); - return; - } - - // Wrap the rest of the request handling in the context - withContext({ mcpAccessToken: token, fakeUpstreamInstallation: mcpInstallation.fakeUpstreamInstallation }, () => - next(), - ); -} - -function redisChannelForSession(sessionId: string): string { - return `mcp:${sessionId}`; -} - -export async function handleSSEConnection(req: Request, res: Response) { - const { server: mcpServer, cleanup: mcpCleanup } = createMcpServer(); - const transport = new SSEServerTransport("/message", res); - console.info(`[session ${transport.sessionId}] Received MCP SSE connection`); - - const redisCleanup = await redisClient.createSubscription( - redisChannelForSession(transport.sessionId), - (json) => { - const message = JSON.parse(json); - - if (message.method) { - if (message.method === "tools/call") { - console.info( - `[session ${transport.sessionId}] Processing ${message.method}, for tool ${message.params?.name}`, - ); - } else { - console.info( - `[session ${transport.sessionId}] Processing ${message.method} method`, - ); - } - } else if (message.error) { - console.warn( - `[session ${transport.sessionId}] Received error message: ${message.error.message}, ${message.error.code}`, - ) - } - transport.handleMessage(message).catch((error) => { - console.error( - `[session ${transport.sessionId}] Error handling message:`, - error, - ); - }); - }, - (error) 
=> { - console.error( - `[session ${transport.sessionId}] Disconnecting due to error in Redis subscriber:`, - error, - ); - transport - .close() - .catch((error) => - console.error( - `[session ${transport.sessionId}] Error closing transport:`, - error, - ), - ); - }, - ); - - const cleanup = () => { - void mcpCleanup(); - redisCleanup().catch((error) => - console.error( - `[session ${transport.sessionId}] Error disconnecting Redis subscriber:`, - error, - ), - ); - } - - // Clean up Redis subscription when the connection closes - mcpServer.onclose = cleanup - - console.info(`[session ${transport.sessionId}] Listening on Redis channel`); - await mcpServer.connect(transport); -} - -export async function handleMessage(req: Request, res: Response) { - const sessionId = req.query.sessionId; - let body: string; - try { - if (typeof sessionId !== "string") { - throw new Error("Only one sessionId allowed"); - } - - const ct = contentType.parse(req.headers["content-type"] ?? ""); - if (ct.type !== "application/json") { - throw new Error(`Unsupported content-type: ${ct}`); - } - - body = await getRawBody(req, { - limit: MAXIMUM_MESSAGE_SIZE, - encoding: ct.parameters.charset ?? 
"utf-8", - }); - } catch (error) { - res.status(400).json(error); - console.error("Bad POST request:", error); - return; - } - await redisClient.publish(redisChannelForSession(sessionId), body); - res.status(202).end(); -} diff --git a/src/handlers/shttp.integration.test.ts b/src/handlers/shttp.integration.test.ts new file mode 100644 index 0000000..b345d28 --- /dev/null +++ b/src/handlers/shttp.integration.test.ts @@ -0,0 +1,694 @@ +import { jest } from '@jest/globals'; +import { Request, Response } from 'express'; +import { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; +import { MockRedisClient, setRedisClient } from '../redis.js'; +import { handleStreamableHTTP } from './shttp.js'; +import { AuthInfo } from '@modelcontextprotocol/sdk/server/auth/types.js'; +// import { randomUUID } from 'crypto'; // Currently unused but may be needed for future tests +import { shutdownSession } from '../services/redisTransport.js'; + +// Type for MCP initialization response +interface MCPInitResponse { + jsonrpc: string; + id: string | number; + result?: { + _meta?: { + sessionId?: string; + }; + [key: string]: unknown; + }; +} + +describe('Streamable HTTP Handler Integration Tests', () => { + let mockRedis: MockRedisClient; + let mockReq: Partial; + let mockRes: Partial; + + beforeEach(() => { + mockRedis = new MockRedisClient(); + setRedisClient(mockRedis); + jest.resetAllMocks(); + + // Create mock response with chainable methods + mockRes = { + status: jest.fn().mockReturnThis(), + json: jest.fn().mockReturnThis(), + on: jest.fn().mockReturnThis(), + once: jest.fn().mockReturnThis(), + emit: jest.fn().mockReturnThis(), + headersSent: false, + setHeader: jest.fn().mockReturnThis(), + writeHead: jest.fn().mockReturnThis(), + write: jest.fn().mockReturnThis(), + end: jest.fn().mockReturnThis(), + getHeader: jest.fn(), + removeHeader: jest.fn().mockReturnThis(), + socket: { + setTimeout: jest.fn(), + }, + } as unknown as Partial; + + // Create mock request + 
mockReq = { + method: 'POST', + headers: { + 'content-type': 'application/json', + 'accept': 'application/json, text/event-stream', + 'mcp-protocol-version': '2024-11-05', + }, + body: {}, + }; + }); + + // Helper function to trigger cleanup after handleStreamableHTTP calls + const triggerResponseCleanup = async () => { + // Find all finish handlers registered during the test + const finishHandlers = (mockRes.on as jest.Mock).mock.calls + .filter(([event]) => event === 'finish') + .map(([, handler]) => handler); + + // Trigger all finish handlers + for (const handler of finishHandlers) { + if (typeof handler === 'function') { + await handler(); + } + } + }; + + // Helper to extract session ID from test context + const getSessionIdFromTest = (): string | undefined => { + // Try to get from response headers first + const setHeaderCalls = (mockRes.setHeader as jest.Mock).mock.calls; + const sessionIdHeader = setHeaderCalls.find(([name]) => name === 'mcp-session-id'); + if (sessionIdHeader?.[1]) { + return sessionIdHeader[1] as string; + } + + // Fall back to extracting from Redis channels + const allChannels = Array.from(mockRedis.subscribers.keys()); + const serverChannel = allChannels.find(channel => channel.includes('mcp:shttp:toserver:')); + return serverChannel?.split(':')[3]; + }; + + afterEach(async () => { + // Always trigger cleanup for any MCP servers created during tests + await triggerResponseCleanup(); + mockRedis.clear(); + jest.clearAllMocks(); + }); + + describe('Redis Subscription Cleanup', () => { + it('should clean up Redis subscriptions after shttp response completes', async () => { + // Set up initialization request (no session ID for new initialization) + const initRequest: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'init-1', + method: 'initialize', + params: { + protocolVersion: '2024-11-05', + capabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + }; + + mockReq.body = initRequest; + mockReq.auth = { + clientId: 
'test-client-123', + token: 'test-token', + scopes: ['mcp'], + extra: { userId: 'test-user-123' } + } as AuthInfo; + + // Call the handler + await handleStreamableHTTP(mockReq as Request, mockRes as Response); + + // Wait longer for async initialization to complete + await new Promise(resolve => setTimeout(resolve, 200)); + + // get the sessionId from the response + const sessionId = getSessionIdFromTest(); + expect(sessionId).toBeDefined(); + + // Check if any subscriptions were created on any channels + // Since we don't know the exact sessionId generated, check all channels + const allChannels = Array.from(mockRedis.subscribers.keys()); + const totalSubscriptions = allChannels.reduce((sum, channel) => sum + (mockRedis.subscribers.get(channel)?.length || 0), 0); + + // Should have created at least one subscription (server channel) + expect(totalSubscriptions).toBeGreaterThan(0); + expect(allChannels.some(channel => channel.includes('mcp:shttp:toserver:'))).toBe(true); + + // Find the finish handler that was registered + const finishHandler = (mockRes.on as jest.Mock).mock.calls.find( + ([event]) => event === 'finish' + )?.[1] as (() => Promise) | undefined; + + expect(finishHandler).toBeDefined(); + + // Simulate response completion to trigger cleanup + if (finishHandler) { + await finishHandler(); + } + + // Verify cleanup handler was registered + expect(mockRes.on).toHaveBeenCalledWith('finish', expect.any(Function)); + + if (sessionId) { + await shutdownSession(sessionId) + } + }); + + it('should handle cleanup errors gracefully', async () => { + const initRequest: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'init-1', + method: 'initialize', + params: { + protocolVersion: '2024-11-05', + capabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + }; + + mockReq.body = initRequest; + mockReq.auth = { + clientId: 'test-client-123', + token: 'test-token', + scopes: ['mcp'], + extra: { userId: 'test-user-123' } + } as AuthInfo; + + // Call the 
handler + await handleStreamableHTTP(mockReq as Request, mockRes as Response); + + // Use Redis error simulation to test error handling + const finishHandler = (mockRes.on as jest.Mock).mock.calls.find( + ([event]) => event === 'finish' + )?.[1] as (() => Promise) | undefined; + + // Simulate error during cleanup + + // Cleanup should not throw error even if Redis operations fail + if (finishHandler) { + await expect(finishHandler()).resolves.not.toThrow(); + } + + // Clean up the MCP server by sending DELETE request + const cleanupSessionId = getSessionIdFromTest(); + + if (cleanupSessionId) { + // Send DELETE request to clean up MCP server + jest.clearAllMocks(); + mockReq.method = 'DELETE'; + if (mockReq.headers) { + mockReq.headers['mcp-session-id'] = cleanupSessionId; + } + mockReq.body = {}; + + await handleStreamableHTTP(mockReq as Request, mockRes as Response); + + // Wait a bit for cleanup to complete + await new Promise(resolve => setTimeout(resolve, 50)); + } + }); + }); + + describe('DELETE Request Session Cleanup', () => { + it('should trigger onsessionclosed callback which sends shutdown control message', async () => { + // First, create a session with an initialization request + const initRequest: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'init-1', + method: 'initialize', + params: { + protocolVersion: '2024-11-05', + capabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + }; + + mockReq.body = initRequest; + mockReq.auth = { + clientId: 'test-client-123', + token: 'test-token', + scopes: ['mcp'], + extra: { userId: 'test-user-123' } + } as AuthInfo; + + // Initialize session + await handleStreamableHTTP(mockReq as Request, mockRes as Response); + + // Wait for async initialization + await new Promise(resolve => setTimeout(resolve, 100)); + + // For initialization requests with StreamableHTTPServerTransport, + // the handler might not immediately return a response if using SSE mode + // Let's check different possible 
locations for the session ID + + // Check JSON responses + const jsonCalls = (mockRes.json as jest.Mock).mock.calls; + let sessionId: string | undefined; + + if (jsonCalls.length > 0) { + const response = jsonCalls[0][0] as MCPInitResponse; + if (response?.result?._meta?.sessionId) { + sessionId = response.result._meta.sessionId; + } + } + + // Check write calls (for SSE responses) + if (!sessionId) { + const writeCalls = (mockRes.write as jest.Mock).mock.calls; + for (const [data] of writeCalls) { + if (typeof data === 'string' && data.includes('sessionId')) { + try { + // SSE data format: "data: {...}\n\n" + const jsonStr = data.replace(/^data: /, '').trim(); + const parsed = JSON.parse(jsonStr) as MCPInitResponse; + if (parsed?.result?._meta?.sessionId) { + sessionId = parsed.result._meta.sessionId; + } + } catch { + // Not valid JSON, continue + } + } + } + } + + // Fallback to getting from Redis channels + if (!sessionId) { + sessionId = getSessionIdFromTest(); + } + + expect(sessionId).toBeDefined(); + + // Reset mocks but keep the session + jest.clearAllMocks(); + + // Now test DELETE request + mockReq.method = 'DELETE'; + mockReq.headers = { + ...mockReq.headers, + 'mcp-session-id': sessionId + }; + mockReq.body = {}; + + // Track control messages sent to Redis + const publishSpy = jest.spyOn(mockRedis, 'publish'); + + // Call DELETE handler - StreamableHTTPServerTransport should handle it + await handleStreamableHTTP(mockReq as Request, mockRes as Response); + + // Wait for async processing and onsessionclosed callback + await new Promise(resolve => setTimeout(resolve, 100)); + + // The StreamableHTTPServerTransport should handle the DELETE and trigger onsessionclosed + // which calls shutdownSession, sending the control message + const controlCalls = publishSpy.mock.calls.filter(call => + call[0] === `mcp:control:${sessionId}` + ); + + expect(controlCalls.length).toBeGreaterThan(0); + + // Verify the control message content + const controlCall = 
publishSpy.mock.calls.find(call => + call[0] === `mcp:control:${sessionId}` + ); + if (controlCall) { + const message = JSON.parse(controlCall[1]); + expect(message.type).toBe('control'); + expect(message.action).toBe('SHUTDOWN'); + } + }); + + it('should return 401 for DELETE request with wrong user', async () => { + // First, create a session as user1 + const initRequest: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'init-1', + method: 'initialize', + params: { + protocolVersion: '2024-11-05', + capabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + }; + + mockReq.body = initRequest; + mockReq.auth = { + clientId: 'test-client-123', + token: 'test-token', + scopes: ['mcp'], + extra: { userId: 'user1' } + } as AuthInfo; + + // Initialize session as user1 + await handleStreamableHTTP(mockReq as Request, mockRes as Response); + + // Wait for async initialization + await new Promise(resolve => setTimeout(resolve, 100)); + + // Get the session ID from response + let sessionId: string | undefined; + + // Check JSON responses + const jsonCalls = (mockRes.json as jest.Mock).mock.calls; + if (jsonCalls.length > 0) { + const response = jsonCalls[0][0] as MCPInitResponse; + if (response?.result?._meta?.sessionId) { + sessionId = response.result._meta.sessionId; + } + } + + // Check write calls (for SSE responses) + if (!sessionId) { + const writeCalls = (mockRes.write as jest.Mock).mock.calls; + for (const [data] of writeCalls) { + if (typeof data === 'string' && data.includes('sessionId')) { + try { + const jsonStr = data.replace(/^data: /, '').trim(); + const parsed = JSON.parse(jsonStr) as MCPInitResponse; + if (parsed?.result?._meta?.sessionId) { + sessionId = parsed.result._meta.sessionId; + } + } catch { + // Ignore JSON parse errors + } + } + } + } + + if (!sessionId) { + sessionId = getSessionIdFromTest(); + } + + // Reset mocks + jest.clearAllMocks(); + + // Now test DELETE request as user2 + mockReq.method = 'DELETE'; + mockReq.headers = { + 
...mockReq.headers, + 'mcp-session-id': sessionId + }; + mockReq.body = {}; + mockReq.auth = { + clientId: 'test-client-456', + token: 'test-token-2', + scopes: ['mcp'], + extra: { userId: 'user2' } + } as AuthInfo; + + await handleStreamableHTTP(mockReq as Request, mockRes as Response); + + // Should return 401 for unauthorized access to another user's session + expect(mockRes.status).toHaveBeenCalledWith(401); + + // shutdown the session + if (sessionId) { + await shutdownSession(sessionId) + } + }); + }); + + describe('User Session Isolation', () => { + it('should prevent users from accessing sessions created by other users', async () => { + // Create session for user 1 + const user1Auth: AuthInfo = { + clientId: 'user1-client', + token: 'user1-token', + scopes: ['mcp'], + extra: { userId: 'user1' } + }; + + const user2Auth: AuthInfo = { + clientId: 'user2-client', + token: 'user2-token', + scopes: ['mcp'], + extra: { userId: 'user2' } + }; + + const initRequest: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'init-1', + method: 'initialize', + params: { + protocolVersion: '2024-11-05', + capabilities: {}, + clientInfo: { name: 'user1-client', version: '1.0.0' } + } + }; + + // User 1 creates session + mockReq.body = initRequest; + mockReq.auth = user1Auth; + + await handleStreamableHTTP(mockReq as Request, mockRes as Response); + + // Wait for async initialization to complete + await new Promise(resolve => setTimeout(resolve, 100)); + + // Get the actual session ID from response + let actualSessionId: string | undefined; + + // Check JSON responses + const jsonCalls = (mockRes.json as jest.Mock).mock.calls; + if (jsonCalls.length > 0) { + const response = jsonCalls[0][0] as MCPInitResponse; + if (response?.result?._meta?.sessionId) { + actualSessionId = response.result._meta.sessionId; + } + } + + // Check write calls (for SSE responses) + if (!actualSessionId) { + const writeCalls = (mockRes.write as jest.Mock).mock.calls; + for (const [data] of writeCalls) { + if 
(typeof data === 'string' && data.includes('sessionId')) { + try { + const jsonStr = data.replace(/^data: /, '').trim(); + const parsed = JSON.parse(jsonStr) as MCPInitResponse; + if (parsed?.result?._meta?.sessionId) { + actualSessionId = parsed.result._meta.sessionId; + } + } catch { + // Ignore JSON parse errors + } + } + } + } + + if (!actualSessionId) { + actualSessionId = getSessionIdFromTest(); + } + + expect(actualSessionId).toBeDefined(); + + // Store finish handler before clearing mocks + const finishHandler1 = (mockRes.on as jest.Mock).mock.calls.find( + ([event]) => event === 'finish' + )?.[1] as (() => Promise) | undefined; + + // Reset mocks + jest.clearAllMocks(); + + // Trigger cleanup for the MCP server created in this step + if (finishHandler1) { + await finishHandler1(); + } + + // User 2 tries to access user 1's session + mockReq.headers = { + ...mockReq.headers, + 'mcp-session-id': actualSessionId + }; + mockReq.body = { + jsonrpc: '2.0', + id: 'user2-request', + method: 'tools/list', + params: {} + }; + mockReq.auth = user2Auth; + + await handleStreamableHTTP(mockReq as Request, mockRes as Response); + + // Should return 401 for unauthorized access to another user's session + expect(mockRes.status).toHaveBeenCalledWith(401); + + // Clean up the MCP server by sending DELETE request + if (actualSessionId) { + jest.clearAllMocks(); + mockReq.method = 'DELETE'; + mockReq.headers['mcp-session-id'] = actualSessionId; + mockReq.body = {}; + mockReq.auth = user1Auth; // Use user1's auth to delete their session + + await handleStreamableHTTP(mockReq as Request, mockRes as Response); + + // Wait a bit for cleanup to complete + await new Promise(resolve => setTimeout(resolve, 50)); + } + }); + + it('should allow users to create separate sessions with same session ID pattern', async () => { + // This test shows that different users should be able to use sessions + // without interfering with each other, even if session IDs might collide + + const 
user1Auth: AuthInfo = { + clientId: 'user1-client', + token: 'user1-token', + scopes: ['mcp'] + }; + + const user2Auth: AuthInfo = { + clientId: 'user2-client', + token: 'user2-token', + scopes: ['mcp'] + }; + + const initRequest: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'init-1', + method: 'initialize', + params: { + protocolVersion: '2024-11-05', + capabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + }; + + // User 1 creates session + mockReq.body = initRequest; + mockReq.auth = user1Auth; + + await handleStreamableHTTP(mockReq as Request, mockRes as Response); + + // Store finish handler before clearing mocks + const finishHandler1 = (mockRes.on as jest.Mock).mock.calls.find( + ([event]) => event === 'finish' + )?.[1] as (() => Promise) | undefined; + + // Reset for user 2 + jest.clearAllMocks(); + + // Trigger cleanup for User 1's MCP server + if (finishHandler1) { + await finishHandler1(); + } + + // User 2 creates their own session + mockReq.body = { + ...initRequest, + id: 'init-2', + params: { + ...initRequest.params, + clientInfo: { name: 'user2-client', version: '1.0.0' } + } + }; + mockReq.auth = user2Auth; + delete mockReq.headers!['mcp-session-id']; // New initialization + + await handleStreamableHTTP(mockReq as Request, mockRes as Response); + + // Trigger cleanup for User 2's MCP server + const finishHandler2 = (mockRes.on as jest.Mock).mock.calls.find( + ([event]) => event === 'finish' + )?.[1] as (() => Promise) | undefined; + + if (finishHandler2) { + await finishHandler2(); + } + + // Both users should be able to create sessions successfully + // Sessions should be isolated in Redis using user-scoped keys + expect(mockRes.status).not.toHaveBeenCalledWith(400); + expect(mockRes.status).not.toHaveBeenCalledWith(403); + }); + + it('should clean up only the requesting user\'s session on DELETE', async () => { + // Create sessions for both users + const user1Auth: AuthInfo = { + clientId: 'user1-client', + token: 
'user1-token', + scopes: ['mcp'] + }; + + const user2Auth: AuthInfo = { + clientId: 'user2-client', + token: 'user2-token', + scopes: ['mcp'] + }; + + // Create session for user 1 + const initRequest: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'init-1', + method: 'initialize', + params: { + protocolVersion: '2024-11-05', + capabilities: {}, + clientInfo: { name: 'user1-client', version: '1.0.0' } + } + }; + + mockReq.body = initRequest; + mockReq.auth = user1Auth; + + await handleStreamableHTTP(mockReq as Request, mockRes as Response); + + // Trigger cleanup for User 1's MCP server + const finishHandler1 = (mockRes.on as jest.Mock).mock.calls.find( + ([event]) => event === 'finish' + )?.[1] as (() => Promise) | undefined; + + if (finishHandler1) { + await finishHandler1(); + } + + // Track session 1 ID (would be returned in response headers) + const session1Id = 'user1-session-id'; // In real implementation, extract from response + + // Create session for user 2 + jest.clearAllMocks(); + mockReq.body = { + ...initRequest, + id: 'init-2', + params: { + ...initRequest.params, + clientInfo: { name: 'user2-client', version: '1.0.0' } + } + }; + mockReq.auth = user2Auth; + delete mockReq.headers!['mcp-session-id']; + + await handleStreamableHTTP(mockReq as Request, mockRes as Response); + + // Trigger cleanup for User 2's MCP server + const finishHandler2 = (mockRes.on as jest.Mock).mock.calls.find( + ([event]) => event === 'finish' + )?.[1] as (() => Promise) | undefined; + + if (finishHandler2) { + await finishHandler2(); + } + + // Track session 2 ID (placeholder for actual implementation) + + // User 1 deletes their session + jest.clearAllMocks(); + mockReq.method = 'DELETE'; + mockReq.headers = { + ...mockReq.headers, + 'mcp-session-id': session1Id + }; + mockReq.body = {}; + mockReq.auth = user1Auth; + + await handleStreamableHTTP(mockReq as Request, mockRes as Response); + + // Only user 1's session should be cleaned up + // User 2's session should remain active 
+ // This test documents expected behavior for proper user isolation + }); + }); +}); \ No newline at end of file diff --git a/src/handlers/shttp.test.ts b/src/handlers/shttp.test.ts new file mode 100644 index 0000000..a068209 --- /dev/null +++ b/src/handlers/shttp.test.ts @@ -0,0 +1,229 @@ +import { jest } from '@jest/globals'; +import { Request, Response } from 'express'; +import { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; +import { MockRedisClient, setRedisClient } from '../redis.js'; + +describe('Streamable HTTP Handler', () => { + let mockRedis: MockRedisClient; + + beforeEach(() => { + mockRedis = new MockRedisClient(); + setRedisClient(mockRedis); + jest.resetAllMocks(); + }); + + afterEach(() => { + mockRedis.clear(); + }); + + describe('Helper function tests', () => { + it('should verify Redis mock is working', async () => { + await mockRedis.set('test-key', 'test-value'); + const value = await mockRedis.get('test-key'); + expect(value).toBe('test-value'); + }); + + it('should handle Redis pub/sub', async () => { + const messageHandler = jest.fn(); + const cleanup = await mockRedis.createSubscription( + 'test-channel', + messageHandler, + jest.fn() + ); + + await mockRedis.publish('test-channel', 'test-message'); + + expect(messageHandler).toHaveBeenCalledWith('test-message'); + + await cleanup(); + }); + }); + + describe('Request validation', () => { + it('should identify initialize requests correctly', async () => { + const { isInitializeRequest } = await import('@modelcontextprotocol/sdk/types.js'); + + const initRequest: JSONRPCMessage = { + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: '2024-11-05', + capabilities: {}, + clientInfo: { name: 'test', version: '1.0' } + } + }; + + const nonInitRequest: JSONRPCMessage = { + jsonrpc: '2.0', + id: 2, + method: 'tools/list', + params: {} + }; + + expect(isInitializeRequest(initRequest)).toBe(true); + expect(isInitializeRequest(nonInitRequest)).toBe(false); 
+ }); + }); + + describe('HTTP response mock behavior', () => { + it('should create proper response mock with chainable methods', () => { + const mockRes = { + status: jest.fn().mockReturnThis(), + json: jest.fn().mockReturnThis(), + on: jest.fn().mockReturnThis(), + headersSent: false, + } as Partial; + + // Test chaining + const result = mockRes.status!(400).json!({ + jsonrpc: '2.0', + error: { code: -32000, message: 'Bad Request' }, + id: null + }); + + expect(mockRes.status).toHaveBeenCalledWith(400); + expect(mockRes.json).toHaveBeenCalledWith({ + jsonrpc: '2.0', + error: { code: -32000, message: 'Bad Request' }, + id: null + }); + expect(result).toBe(mockRes); + }); + }); + + describe('Session ID generation', () => { + it('should generate valid UUIDs', async () => { + const { randomUUID } = await import('crypto'); + + const sessionId = randomUUID(); + + // UUID v4 format: xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx + expect(sessionId).toMatch(/^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i); + }); + }); + + describe('Redis channel naming', () => { + it('should create correct channel names for server communication', () => { + const sessionId = 'test-session-123'; + const requestId = 'req-456'; + + const toServerChannel = `mcp:shttp:toserver:${sessionId}`; + const toClientChannel = `mcp:shttp:toclient:${sessionId}:${requestId}`; + const notificationChannel = `mcp:shttp:toclient:${sessionId}:__GET_stream`; + + expect(toServerChannel).toBe('mcp:shttp:toserver:test-session-123'); + expect(toClientChannel).toBe('mcp:shttp:toclient:test-session-123:req-456'); + expect(notificationChannel).toBe('mcp:shttp:toclient:test-session-123:__GET_stream'); + }); + }); + + describe('Error response formatting', () => { + it('should format JSON-RPC error responses correctly', () => { + const errorResponse = { + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Bad Request: No valid session ID provided', + }, + id: null, + }; + + 
expect(errorResponse.jsonrpc).toBe('2.0'); + expect(errorResponse.error.code).toBe(-32000); + expect(errorResponse.error.message).toContain('Bad Request'); + expect(errorResponse.id).toBe(null); + }); + + it('should format internal error responses correctly', () => { + const internalErrorResponse = { + jsonrpc: '2.0', + error: { + code: -32603, + message: 'Internal error', + }, + id: null, + }; + + expect(internalErrorResponse.jsonrpc).toBe('2.0'); + expect(internalErrorResponse.error.code).toBe(-32603); + expect(internalErrorResponse.error.message).toBe('Internal error'); + expect(internalErrorResponse.id).toBe(null); + }); + }); + + describe('Request/Response patterns', () => { + it('should handle typical MCP message structures', () => { + const initializeRequest: JSONRPCMessage = { + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: '2024-11-05', + capabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + }; + + const toolsListRequest: JSONRPCMessage = { + jsonrpc: '2.0', + id: 2, + method: 'tools/list', + params: {} + }; + + const toolsListResponse: JSONRPCMessage = { + jsonrpc: '2.0', + id: 2, + result: { + tools: [ + { + name: 'echo', + description: 'Echo the input', + inputSchema: { + type: 'object', + properties: { + text: { type: 'string' } + } + } + } + ] + } + }; + + expect(initializeRequest.method).toBe('initialize'); + expect(toolsListRequest.method).toBe('tools/list'); + expect(toolsListResponse.result).toBeDefined(); + expect(Array.isArray(toolsListResponse.result?.tools)).toBe(true); + }); + }); + + describe('HTTP header handling', () => { + it('should extract session ID from headers', () => { + const mockReq = { + headers: { + 'mcp-session-id': 'test-session-123', + 'content-type': 'application/json' + }, + body: {} + } as Partial; + + const sessionId = mockReq.headers!['mcp-session-id'] as string; + + expect(sessionId).toBe('test-session-123'); + }); + + it('should handle missing session ID in 
headers', () => { + const mockReq = { + headers: { + 'content-type': 'application/json' + }, + body: {} + } as Partial; + + const sessionId = mockReq.headers!['mcp-session-id'] as string | undefined; + + expect(sessionId).toBeUndefined(); + }); + }); +}); \ No newline at end of file diff --git a/src/handlers/shttp.ts b/src/handlers/shttp.ts new file mode 100644 index 0000000..37fee5e --- /dev/null +++ b/src/handlers/shttp.ts @@ -0,0 +1,128 @@ +import { AuthInfo } from "@modelcontextprotocol/sdk/server/auth/types.js"; +import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; +import { Request, Response } from "express"; +import { getShttpTransport, isSessionOwnedBy, redisRelayToMcpServer, ServerRedisTransport, setSessionOwner, shutdownSession } from "../services/redisTransport.js"; +import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; +import { randomUUID } from "crypto"; +import { createMcpServer } from "../services/mcp.js"; +import { logger } from "../utils/logger.js"; + + +declare module "express-serve-static-core" { + interface Request { + /** + * Information about the validated access token, if the `requireBearerAuth` middleware was used. 
+ */ + auth?: AuthInfo; + } +} + +function getUserIdFromAuth(auth?: AuthInfo): string | null { + return auth?.extra?.userId as string || null; +} + +export async function handleStreamableHTTP(req: Request, res: Response) { + let shttpTransport: StreamableHTTPServerTransport | undefined = undefined; + + res.on('finish', async () => { + await shttpTransport?.close(); + }); + + const onsessionclosed = async (sessionId: string) => { + logger.info('Session closed callback triggered', { + sessionId, + userId: getUserIdFromAuth(req.auth) + }); + await shutdownSession(sessionId); + } + + try { + // Check for existing session ID + const sessionId = req.headers['mcp-session-id'] as string | undefined; + const userId = getUserIdFromAuth(req.auth); + + // if no userid, return 401, we shouldn't get here ideally + if (!userId) { + logger.warning('Request without user ID', { + sessionId, + hasAuth: !!req.auth + }); + res.status(401) + return; + } + + const isGetRequest = req.method === 'GET'; + + // incorrect session for the authed user, return 401 + if (sessionId) { + if (!(await isSessionOwnedBy(sessionId, userId))) { + logger.warning('Session ownership mismatch', { + sessionId, + userId, + requestMethod: req.method + }); + res.status(401) + return; + } + // Reuse existing transport for owned session + logger.info('Reusing existing session', { + sessionId, + userId, + isGetRequest + }); + shttpTransport = await getShttpTransport(sessionId, onsessionclosed, isGetRequest); + } else if (isInitializeRequest(req.body)) { + // New initialization request - use JSON response mode + const onsessioninitialized = async (sessionId: string) => { + logger.info('Initializing new session', { + sessionId, + userId + }); + + const { server, cleanup: mcpCleanup } = createMcpServer(); + + const serverRedisTransport = new ServerRedisTransport(sessionId); + serverRedisTransport.onclose = mcpCleanup; + await server.connect(serverRedisTransport) + + // Set session ownership + await 
setSessionOwner(sessionId, userId); + + logger.info('Session initialized successfully', { + sessionId, + userId + }); + } + + const sessionId = randomUUID(); + shttpTransport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => sessionId, + onsessionclosed, + onsessioninitialized, + }); + shttpTransport.onclose = await redisRelayToMcpServer(sessionId, shttpTransport); + } else { + // Invalid request - no session ID and not initialization request + logger.warning('Invalid request: no session ID and not initialization', { + hasSessionId: !!sessionId, + isInitRequest: false, + userId, + method: req.method + }); + res.status(400) + return; + } + // Handle the request with existing transport - no need to reconnect + await shttpTransport.handleRequest(req, res, req.body); + } catch (error) { + logger.error('Error handling MCP request', error as Error, { + sessionId: req.headers['mcp-session-id'] as string | undefined, + method: req.method, + userId: getUserIdFromAuth(req.auth) + }); + + if (!res.headersSent) { + res.status(500) + } + } +} diff --git a/src/handlers/sse.ts b/src/handlers/sse.ts new file mode 100644 index 0000000..e80c67a --- /dev/null +++ b/src/handlers/sse.ts @@ -0,0 +1,107 @@ +import { AuthInfo } from "@modelcontextprotocol/sdk/server/auth/types.js"; +import { SSEServerTransport } from "@modelcontextprotocol/sdk/server/sse.js"; +import contentType from "content-type"; +import { Request, Response } from "express"; +import getRawBody from "raw-body"; +import { redisClient } from "../redis.js"; +import { createMcpServer } from "../services/mcp.js"; +import { logMcpMessage } from "./common.js"; +import { logger } from "../utils/logger.js"; + +const MAXIMUM_MESSAGE_SIZE = "4mb"; + +declare module "express-serve-static-core" { + interface Request { + /** + * Information about the validated access token, if the `requireBearerAuth` middleware was used. 
+ */ + auth?: AuthInfo; + } +} + +function redisChannelForSession(sessionId: string): string { + return `mcp:${sessionId}`; +} + +export async function handleSSEConnection(req: Request, res: Response) { + const { server: mcpServer, cleanup: mcpCleanup } = createMcpServer(); + const transport = new SSEServerTransport("/message", res); + logger.info('Received MCP SSE connection', { + sessionId: transport.sessionId + }); + + const redisCleanup = await redisClient.createSubscription( + redisChannelForSession(transport.sessionId), + (json) => { + // TODO handle DELETE messages + // TODO set timeout to kill the session + + const message = JSON.parse(json); + logMcpMessage(message, transport.sessionId); + transport.handleMessage(message).catch((error) => { + logger.error('Error handling message', error as Error, { + sessionId: transport.sessionId + }); + }); + }, + (error) => { + logger.error('Disconnecting due to error in Redis subscriber', error as Error, { + sessionId: transport.sessionId + }); + transport + .close() + .catch((error) => + logger.error('Error closing transport', error as Error, { + sessionId: transport.sessionId + }), + ); + }, + ); + + const cleanup = () => { + void mcpCleanup(); + redisCleanup().catch((error) => + logger.error('Error disconnecting Redis subscriber', error as Error, { + sessionId: transport.sessionId + }), + ); + } + + // Clean up Redis subscription when the connection closes + mcpServer.onclose = cleanup + + logger.info('Listening on Redis channel', { + sessionId: transport.sessionId, + channel: redisChannelForSession(transport.sessionId) + }); + await mcpServer.connect(transport); +} + +export async function handleMessage(req: Request, res: Response) { + const sessionId = req.query.sessionId; + let body: string; + try { + if (typeof sessionId !== "string") { + throw new Error("Only one sessionId allowed"); + } + + const ct = contentType.parse(req.headers["content-type"] ?? 
""); + if (ct.type !== "application/json") { + throw new Error(`Unsupported content-type: ${ct}`); + } + + body = await getRawBody(req, { + limit: MAXIMUM_MESSAGE_SIZE, + encoding: ct.parameters.charset ?? "utf-8", + }); + } catch (error) { + res.status(400).json(error); + logger.error('Bad POST request', error as Error, { + sessionId, + contentType: req.headers['content-type'] + }); + return; + } + await redisClient.publish(redisChannelForSession(sessionId), body); + res.status(202).end(); +} diff --git a/src/index.ts b/src/index.ts index 4fc70ab..7b67dbd 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,15 +1,24 @@ +import { BearerAuthMiddlewareOptions, requireBearerAuth } from "@modelcontextprotocol/sdk/server/auth/middleware/bearerAuth.js"; +import { AuthRouterOptions, mcpAuthRouter } from "@modelcontextprotocol/sdk/server/auth/router.js"; import cors from "cors"; import express from "express"; -import { BASE_URI, PORT } from "./config.js"; -import { AuthRouterOptions, mcpAuthRouter } from "@modelcontextprotocol/sdk/server/auth/router.js"; +import path from "path"; +import { fileURLToPath } from "url"; import { EverythingAuthProvider } from "./auth/provider.js"; -import { handleMessage, handleSSEConnection, authContext } from "./handlers/mcp.js"; -import { handleFakeAuthorizeRedirect, handleFakeAuthorize } from "./handlers/fakeauth.js"; +import { BASE_URI, PORT } from "./config.js"; +import { authContext } from "./handlers/common.js"; +import { handleFakeAuthorize, handleFakeAuthorizeRedirect } from "./handlers/fakeauth.js"; +import { handleStreamableHTTP } from "./handlers/shttp.js"; +import { handleMessage, handleSSEConnection } from "./handlers/sse.js"; import { redisClient } from "./redis.js"; -import { requireBearerAuth } from "@modelcontextprotocol/sdk/server/auth/middleware/bearerAuth.js"; +import { logger } from "./utils/logger.js"; const app = express(); +// Get the directory of the current module +const __filename = fileURLToPath(import.meta.url); 
+const __dirname = path.dirname(__filename); + // Base security middleware - applied to all routes const baseSecurityHeaders = (req: express.Request, res: express.Response, next: express.NextFunction) => { // Basic security headers @@ -32,14 +41,45 @@ const baseSecurityHeaders = (req: express.Request, res: express.Response, next: next(); }; -// simple logging middleware -const logger = (req: express.Request, res: express.Response, next: express.NextFunction) => { - console.log(`${req.method} ${req.url}`); - next(); - // Log the response status code +// Structured logging middleware +const loggingMiddleware = (req: express.Request, res: express.Response, next: express.NextFunction) => { + const startTime = Date.now(); + + // Sanitize headers to remove sensitive information + const sanitizedHeaders = { ...req.headers }; + delete sanitizedHeaders.authorization; + delete sanitizedHeaders.cookie; + delete sanitizedHeaders['x-api-key']; + + // Log request (without sensitive data) + logger.info('Request received', { + method: req.method, + url: req.url, + // Only log specific safe headers + headers: { + 'content-type': sanitizedHeaders['content-type'], + 'user-agent': sanitizedHeaders['user-agent'], + 'mcp-protocol-version': sanitizedHeaders['mcp-protocol-version'], + 'mcp-session-id': sanitizedHeaders['mcp-session-id'], + 'accept': sanitizedHeaders['accept'], + 'x-cloud-trace-context': sanitizedHeaders['x-cloud-trace-context'] + }, + // Don't log request body as it may contain sensitive data + bodySize: req.headers['content-length'] + }); + + // Log response when finished res.on('finish', () => { - console.log(`Response status code: ${res.statusCode}`); + const duration = Date.now() - startTime; + logger.info('Request completed', { + method: req.method, + url: req.url, + statusCode: res.statusCode, + duration: `${duration}ms` + }); }); + + next(); }; @@ -61,11 +101,19 @@ const sseHeaders = (req: express.Request, res: express.Response, next: express.N const corsOptions = 
{ origin: true, // Allow any origin methods: ['GET', 'POST'], - allowedHeaders: ['Content-Type', 'Authorization', "MCP-Protocol-Version"], + allowedHeaders: ['Content-Type', 'Authorization', "Mcp-Protocol-Version", "Mcp-Protocol-Id"], + exposedHeaders: ["Mcp-Protocol-Version", "Mcp-Protocol-Id"], credentials: true }; -app.use(logger); + +app.use(express.json()); + +// Add structured logging context middleware first +app.use(logger.middleware()); + +// Then add the logging middleware +app.use(loggingMiddleware); // Apply base security headers to all routes app.use(baseSecurityHeaders); @@ -73,6 +121,8 @@ app.use(baseSecurityHeaders); // Enable CORS pre-flight requests app.options('*', cors(corsOptions)); + +const authProvider = new EverythingAuthProvider(); // Auth configuration const options: AuthRouterOptions = { provider: new EverythingAuthProvider(), @@ -84,24 +134,47 @@ const options: AuthRouterOptions = { } } }; + +const dearerAuthMiddlewareOptions: BearerAuthMiddlewareOptions = { + // verifyAccessToken(token: string): Promise; + verifier: { + verifyAccessToken: authProvider.verifyAccessToken.bind(authProvider), + } +} + app.use(mcpAuthRouter(options)); -const bearerAuth = requireBearerAuth(options); +const bearerAuth = requireBearerAuth(dearerAuthMiddlewareOptions); -// MCP routes +// MCP routes (legacy SSE transport) app.get("/sse", cors(corsOptions), bearerAuth, authContext, sseHeaders, handleSSEConnection); app.post("/message", cors(corsOptions), bearerAuth, authContext, sensitiveDataHeaders, handleMessage); -// Upstream auth routes +// MCP routes (new streamable HTTP transport) +app.get("/mcp", cors(corsOptions), bearerAuth, authContext, handleStreamableHTTP); +app.post("/mcp", cors(corsOptions), bearerAuth, authContext, handleStreamableHTTP); +app.delete("/mcp", cors(corsOptions), bearerAuth, authContext, handleStreamableHTTP); + +// Static assets +app.get("/mcp-logo.png", (req, res) => { + const logoPath = path.join(__dirname, "static", "mcp.png"); + 
res.sendFile(logoPath); +}); + +// Upstream auth routes app.get("/fakeupstreamauth/authorize", cors(corsOptions), handleFakeAuthorize); app.get("/fakeupstreamauth/callback", cors(corsOptions), handleFakeAuthorizeRedirect); try { await redisClient.connect(); } catch (error) { - console.error("Could not connect to Redis:", error); + logger.error("Could not connect to Redis", error as Error); process.exit(1); } app.listen(PORT, () => { - console.log(`Server running on http://localhost:${PORT}`); + logger.info('Server started', { + port: PORT, + url: `http://localhost:${PORT}`, + environment: process.env.NODE_ENV || 'development' + }); }); diff --git a/src/redis.ts b/src/redis.ts index 7bbdc7e..1826bff 100644 --- a/src/redis.ts +++ b/src/redis.ts @@ -1,4 +1,5 @@ import { createClient, SetOptions } from "@redis/client"; +import { logger } from "./utils/logger.js"; /** * Describes the Redis primitives we use in this application, to be able to mock @@ -8,9 +9,15 @@ export interface RedisClient { get(key: string): Promise; set(key: string, value: string, options?: SetOptions): Promise; getDel(key: string): Promise; + del(key: string): Promise; + expire(key: string, seconds: number): Promise; + lpush(key: string, ...values: string[]): Promise; + lrange(key: string, start: number, stop: number): Promise; connect(): Promise; on(event: string, callback: (error: Error) => void): void; options?: { url: string }; + exists(key: string): Promise; + numsub(key: string): Promise; /** * Creates a pub/sub subscription. Returns a cleanup function to unsubscribe. 
@@ -40,10 +47,15 @@ export class RedisClientImpl implements RedisClient { constructor() { this.redis.on("error", (error) => - console.error("Redis client error:", error), + logger.error("Redis client error", error as Error), ); } + async numsub(key: string): Promise { + const subs = await this.redis.pubSubNumSub(key); + return subs[key] || 0; + } + async get(key: string): Promise { return await this.redis.get(key); } @@ -60,6 +72,22 @@ export class RedisClientImpl implements RedisClient { ); } + async del(key: string): Promise { + return await this.redis.del(key); + } + + async expire(key: string, seconds: number): Promise { + return await this.redis.expire(key, seconds); + } + + async lpush(key: string, ...values: string[]): Promise { + return await this.redis.lPush(key, values); + } + + async lrange(key: string, start: number, stop: number): Promise { + return await this.redis.lRange(key, start, stop); + } + async connect(): Promise { await this.redis.connect(); } @@ -78,9 +106,15 @@ export class RedisClientImpl implements RedisClient { onError: (error: Error) => void, ): Promise<() => Promise> { const subscriber = this.redis.duplicate(); - subscriber.on("error", onError); + subscriber.on("error", (error) => { + onError(error); + }); + await subscriber.connect(); - await subscriber.subscribe(channel, onMessage); + + await subscriber.subscribe(channel, (message) => { + onMessage(message); + }); return async () => { await subscriber.disconnect(); @@ -90,6 +124,11 @@ export class RedisClientImpl implements RedisClient { async publish(channel: string, message: string): Promise { await this.redis.publish(channel, message); } + + async exists(key: string): Promise { + const result = await this.redis.exists(key); + return result > 0; + } } // Export a mutable reference that can be swapped in tests @@ -103,7 +142,8 @@ export function setRedisClient(client: RedisClient) { export class MockRedisClient implements RedisClient { options = { url: "redis://localhost:6379" }; 
private store = new Map(); - private subscribers = new Map void)[]>(); + private lists = new Map(); + public subscribers = new Map void)[]>(); // Public for testing access private errorCallbacks = new Map void)[]>(); async get(key: string): Promise { @@ -125,6 +165,39 @@ export class MockRedisClient implements RedisClient { return oldValue; } + async del(key: string): Promise { + let deleted = 0; + if (this.store.has(key)) { + this.store.delete(key); + deleted++; + } + if (this.lists.has(key)) { + this.lists.delete(key); + deleted++; + } + return deleted; + } + + async expire(key: string, _seconds: number): Promise { + // Mock implementation - just return true if key exists + return this.store.has(key) || this.lists.has(key); + } + + async lpush(key: string, ...values: string[]): Promise { + const list = this.lists.get(key) || []; + list.unshift(...values); + this.lists.set(key, list); + return list.length; + } + + async lrange(key: string, start: number, stop: number): Promise { + const list = this.lists.get(key) || []; + if (stop === -1) { + return list.slice(start); + } + return list.slice(start, stop + 1); + } + async connect(): Promise { // No-op in mock } @@ -183,8 +256,17 @@ export class MockRedisClient implements RedisClient { } } + async exists(key: string): Promise { + return this.store.has(key) || this.lists.has(key); + } + + async numsub(key: string): Promise { + return (this.subscribers.get(key) || []).length; + } + clear() { this.store.clear(); + this.lists.clear(); this.subscribers.clear(); this.errorCallbacks.clear(); } diff --git a/src/services/auth.test.ts b/src/services/auth.test.ts index f8e167f..6a761bd 100644 --- a/src/services/auth.test.ts +++ b/src/services/auth.test.ts @@ -153,10 +153,16 @@ describe("auth utils", () => { const first = await exchangeToken(authCode); expect(first).toBeDefined(); + // Mock console.error to suppress expected error message + const consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation(() => {}); + // 
Second exchange throws await expect(exchangeToken(authCode)).rejects.toThrow( "Duplicate use of authorization code detected" ); + + // Restore console.error + consoleErrorSpy.mockRestore(); }); it("returns undefined for non-existent code", async () => { @@ -187,6 +193,7 @@ describe("auth utils", () => { }, clientId: "client-id", issuedAt: Date.now() / 1000, + userId: "test-user-id", } await saveMcpInstallation(accessToken, mcpInstallation); @@ -233,6 +240,7 @@ describe("auth utils", () => { }, clientId: "client-id", issuedAt: Date.now() / 1000, + userId: "test-user-id", }); const getDel = jest.spyOn(mockRedis, 'getDel').mockImplementationOnce(() => { @@ -246,6 +254,7 @@ describe("auth utils", () => { }, clientId: "client-id", issuedAt: Date.now() / 1000, + userId: "test-user-id", }; const value = JSON.stringify(mcpInstallation); const iv = crypto.randomBytes(16); diff --git a/src/services/auth.ts b/src/services/auth.ts index c4fa5ee..02dae0f 100644 --- a/src/services/auth.ts +++ b/src/services/auth.ts @@ -3,6 +3,7 @@ import crypto from "crypto"; import { redisClient } from "../redis.js"; import { McpInstallation, PendingAuthorization, TokenExchange } from "../types.js"; import { OAuthClientInformationFull, OAuthTokens } from "@modelcontextprotocol/sdk/shared/auth.js"; +import { logger } from "../utils/logger.js"; export function generatePKCEChallenge(verifier: string): string { const buffer = Buffer.from(verifier); @@ -243,7 +244,9 @@ export async function exchangeToken( const tokenExchange: TokenExchange = JSON.parse(decoded); if (tokenExchange.alreadyUsed) { - console.error("Duplicate use of authorization code detected; revoking tokens"); + logger.error('Duplicate use of authorization code detected; revoking tokens', undefined, { + authorizationCode: authorizationCode.substring(0, 8) + '...' 
+ }); await revokeMcpInstallation(tokenExchange.mcpAccessToken); throw new Error("Duplicate use of authorization code detected; tokens revoked"); } @@ -257,7 +260,9 @@ export async function exchangeToken( if (rereadData !== data) { // Data concurrently changed while we were updating it. This necessarily means a duplicate use. - console.error("Duplicate use of authorization code detected; revoking tokens"); + logger.error('Duplicate use of authorization code detected (concurrent update); revoking tokens', undefined, { + authorizationCode: authorizationCode.substring(0, 8) + '...' + }); await revokeMcpInstallation(tokenExchange.mcpAccessToken); throw new Error("Duplicate use of authorization code detected; tokens revoked"); } diff --git a/src/services/mcp.ts b/src/services/mcp.ts index 16f20fc..04a5e65 100644 --- a/src/services/mcp.ts +++ b/src/services/mcp.ts @@ -15,13 +15,16 @@ import { SetLevelRequestSchema, SubscribeRequestSchema, Tool, - ToolSchema, UnsubscribeRequestSchema, } from "@modelcontextprotocol/sdk/types.js"; import { z } from "zod"; import { zodToJsonSchema } from "zod-to-json-schema"; -type ToolInput = z.infer; +type ToolInput = { + type: "object"; + properties?: Record; + required?: string[]; +}; /* Input schemas for tools implemented in this server */ const EchoSchema = z.object({ @@ -92,7 +95,12 @@ enum PromptName { RESOURCE = "resource_prompt", } -export const createMcpServer = () => { +interface McpServerWrapper { + server: Server; + cleanup: () => void; +} + +export const createMcpServer = (): McpServerWrapper => { const server = new Server( { name: "example-servers/everything", diff --git a/src/services/redisTransport.integration.test.ts b/src/services/redisTransport.integration.test.ts new file mode 100644 index 0000000..3fcc778 --- /dev/null +++ b/src/services/redisTransport.integration.test.ts @@ -0,0 +1,291 @@ +import { jest } from '@jest/globals'; +import { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; +import { 
MockRedisClient, setRedisClient } from '../redis.js'; +import { + ServerRedisTransport, + redisRelayToMcpServer, + shutdownSession +} from './redisTransport.js'; +import { createMcpServer } from './mcp.js'; +import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; + +describe('Redis Transport Integration', () => { + let mockRedis: MockRedisClient; + + beforeEach(() => { + mockRedis = new MockRedisClient(); + setRedisClient(mockRedis); + jest.resetAllMocks(); + }); + + afterEach(() => { + mockRedis.clear(); + }); + + describe('MCP Initialization Flow', () => { + const sessionId = 'test-init-session'; + + it('should relay initialization request from client to server through Redis', async () => { + // 1. Start the server listening to Redis + const { server, cleanup: serverCleanup } = createMcpServer(); + const serverTransport = new ServerRedisTransport(sessionId); + serverTransport.onclose = serverCleanup; + await server.connect(serverTransport); + + // 2. Create a mock client transport (simulating the streamable HTTP client side) + const mockClientTransport: Transport = { + onmessage: undefined, + onclose: undefined, + onerror: undefined, + send: jest.fn(() => Promise.resolve()), + close: jest.fn(() => Promise.resolve()), + start: jest.fn(() => Promise.resolve()) + }; + + // 3. Set up the Redis relay (this is what happens in the HTTP handler) + const cleanup = await redisRelayToMcpServer(sessionId, mockClientTransport); + + // Track messages received by server + const serverReceivedMessages: JSONRPCMessage[] = []; + const originalServerOnMessage = serverTransport.onmessage; + serverTransport.onmessage = (message, extra) => { + serverReceivedMessages.push(message); + originalServerOnMessage?.(message, extra); + }; + + // 4. 
Simulate client sending initialization request + const initMessage: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'init-1', + method: 'initialize', + params: { + protocolVersion: '2024-11-05', + capabilities: {}, + clientInfo: { + name: 'test-client', + version: '1.0.0' + } + } + }; + + // Trigger the client transport onmessage (simulates HTTP request) + mockClientTransport.onmessage?.(initMessage); + + // Wait for message to be relayed through Redis + await new Promise(resolve => setTimeout(resolve, 50)); + + // 5. Verify server received the init message + expect(serverReceivedMessages).toHaveLength(1); + expect(serverReceivedMessages[0]).toMatchObject({ + jsonrpc: '2.0', + id: 'init-1', + method: 'initialize' + }); + + // 6. Simulate server responding (this should get relayed back to client) + const initResponse: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'init-1', + result: { + protocolVersion: '2024-11-05', + capabilities: { + tools: {}, + prompts: {}, + resources: {} + }, + serverInfo: { + name: 'example-server', + version: '1.0.0' + } + } + }; + + await serverTransport.send(initResponse, { relatedRequestId: 'init-1' }); + + // Wait for response to be relayed back + await new Promise(resolve => setTimeout(resolve, 50)); + + // 7. 
Verify client transport received the response + expect(mockClientTransport.send).toHaveBeenCalledWith( + expect.objectContaining({ + jsonrpc: '2.0', + id: 'init-1', + result: expect.objectContaining({ + protocolVersion: '2024-11-05', + serverInfo: expect.objectContaining({ + name: 'example-server' + }) + }) + }), + { relatedRequestId: 'init-1' } + ); + + // Cleanup + await cleanup(); + await shutdownSession(sessionId); + serverCleanup(); // Clean up MCP server intervals + + // Ensure server transport is closed + await serverTransport.close(); + + await new Promise(resolve => setTimeout(resolve, 10)); + }); + + it('should handle tools/list request through Redis relay', async () => { + // Set up server and mock client + const { server, cleanup: serverCleanup } = createMcpServer(); + const serverTransport = new ServerRedisTransport(sessionId); + serverTransport.onclose = serverCleanup; + await server.connect(serverTransport); + + const mockClientTransport: Transport = { + onmessage: undefined, + onclose: undefined, + onerror: undefined, + send: jest.fn(() => Promise.resolve()), + close: jest.fn(() => Promise.resolve()), + start: jest.fn(() => Promise.resolve()) + }; + + const cleanup = await redisRelayToMcpServer(sessionId, mockClientTransport); + + // Send tools/list request + const toolsListMessage: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'tools-1', + method: 'tools/list', + params: {} + }; + + mockClientTransport.onmessage?.(toolsListMessage); + + // Wait for processing and response + await new Promise(resolve => setTimeout(resolve, 100)); + + // Verify client received a response with tools + expect(mockClientTransport.send).toHaveBeenCalledWith( + expect.objectContaining({ + jsonrpc: '2.0', + id: 'tools-1', + result: expect.objectContaining({ + tools: expect.any(Array) + }) + }), + undefined + ); + + // Cleanup + await cleanup(); + await shutdownSession(sessionId); + serverCleanup(); // Clean up MCP server intervals + + // Ensure server transport is closed + 
await serverTransport.close(); + + await new Promise(resolve => setTimeout(resolve, 10)); + }); + + it('should handle notifications through Redis relay', async () => { + // Set up server and mock client + const { server, cleanup: serverCleanup } = createMcpServer(); + const serverTransport = new ServerRedisTransport(sessionId); + serverTransport.onclose = serverCleanup; + await server.connect(serverTransport); + + const mockClientTransport: Transport = { + onmessage: undefined, + onclose: undefined, + onerror: undefined, + send: jest.fn(() => Promise.resolve()), + close: jest.fn(() => Promise.resolve()), + start: jest.fn(() => Promise.resolve()) + }; + + const cleanup = await redisRelayToMcpServer(sessionId, mockClientTransport); + + // Set up notification subscription manually since notifications don't have an id + const notificationChannel = `mcp:shttp:toclient:${sessionId}:__GET_stream`; + const notificationCleanup = await mockRedis.createSubscription(notificationChannel, async (redisMessageJson) => { + const redisMessage = JSON.parse(redisMessageJson); + if (redisMessage.type === 'mcp') { + await mockClientTransport.send(redisMessage.message, redisMessage.options); + } + }, (error) => { + mockClientTransport.onerror?.(error); + }); + + // Send a notification from server (notifications don't have an id) + const notification: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'notifications/message', + params: { + level: 'info', + logger: 'test', + data: 'Test notification' + } + }; + + await serverTransport.send(notification); + + // Wait for notification to be delivered + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify client received the notification + expect(mockClientTransport.send).toHaveBeenCalledWith( + expect.objectContaining({ + jsonrpc: '2.0', + method: 'notifications/message', + params: expect.objectContaining({ + level: 'info', + data: 'Test notification' + }) + }), + undefined + ); + + // Cleanup notification subscription + await 
notificationCleanup(); + + // Cleanup + await cleanup(); + await shutdownSession(sessionId); + serverCleanup(); // Clean up MCP server intervals + + // Ensure server transport is closed + await serverTransport.close(); + + await new Promise(resolve => setTimeout(resolve, 10)); + }); + + it('should not create response subscriptions for notifications', async () => { + const mockClientTransport: Transport = { + onmessage: undefined, + onclose: undefined, + onerror: undefined, + send: jest.fn(() => Promise.resolve()), + close: jest.fn(() => Promise.resolve()), + start: jest.fn(() => Promise.resolve()) + }; + + const cleanup = await redisRelayToMcpServer(sessionId, mockClientTransport); + + // Send a notification (no id field) + const notification: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'notifications/initialized', + params: {} + }; + + mockClientTransport.onmessage?.(notification); + + // Wait for processing + await new Promise(resolve => setTimeout(resolve, 50)); + + // Should not create any response channel subscriptions for notifications + // (we can't easily test this directly, but we can ensure no errors occur) + expect(mockClientTransport.send).not.toHaveBeenCalled(); + + await cleanup(); + }); + }); +}); \ No newline at end of file diff --git a/src/services/redisTransport.test.ts b/src/services/redisTransport.test.ts new file mode 100644 index 0000000..60e7e33 --- /dev/null +++ b/src/services/redisTransport.test.ts @@ -0,0 +1,536 @@ +import { jest } from '@jest/globals'; +import { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; +import { MockRedisClient, setRedisClient } from '../redis.js'; +import { + ServerRedisTransport, + redisRelayToMcpServer, + isLive, + shutdownSession, + setSessionOwner, + getSessionOwner, + validateSessionOwnership, + isSessionOwnedBy +} from './redisTransport.js'; +import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; + +describe('Redis Transport', () => { + let mockRedis: MockRedisClient; + + 
beforeEach(() => { + mockRedis = new MockRedisClient(); + setRedisClient(mockRedis); + jest.resetAllMocks(); + }); + + afterEach(() => { + // Clear all Redis data and subscriptions + mockRedis.clear(); + }); + + describe('ServerRedisTransport', () => { + let transport: ServerRedisTransport; + const sessionId = 'test-session-123'; + + beforeEach(() => { + transport = new ServerRedisTransport(sessionId); + }); + + afterEach(async () => { + if (transport) { + await transport.close(); + } + }); + + it('should create transport with session ID', () => { + expect(transport).toBeInstanceOf(ServerRedisTransport); + }); + + it('should send response messages to request-specific channels', async () => { + const responseMessage: JSONRPCMessage = { + jsonrpc: '2.0', + id: 123, + result: { data: 'test response' } + }; + + const mockSubscriber = jest.fn(); + await mockRedis.createSubscription( + `mcp:shttp:toclient:${sessionId}:123`, + mockSubscriber, + jest.fn() + ); + + await transport.send(responseMessage, { relatedRequestId: 123 }); + + expect(mockSubscriber).toHaveBeenCalledWith( + JSON.stringify({ + type: 'mcp', + message: responseMessage, + options: { relatedRequestId: 123 } + }) + ); + }); + + it('should send notification messages to notification channel', async () => { + const notificationMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'notifications/message', + params: { message: 'test notification' } + }; + + const mockSubscriber = jest.fn(); + await mockRedis.createSubscription( + `mcp:shttp:toclient:${sessionId}:__GET_stream`, + mockSubscriber, + jest.fn() + ); + + await transport.send(notificationMessage); + + expect(mockSubscriber).toHaveBeenCalledWith( + JSON.stringify({ + type: 'mcp', + message: notificationMessage, + options: undefined + }) + ); + }); + + it('should handle close gracefully', async () => { + const onCloseMock = jest.fn(); + transport.onclose = onCloseMock; + + await transport.close(); + + expect(onCloseMock).toHaveBeenCalled(); + }); + + 
it('should respond to shutdown control messages', async () => { + await transport.start(); + + const onCloseMock = jest.fn(); + transport.onclose = onCloseMock; + + // Send a shutdown control message + await shutdownSession(sessionId); + + // Wait for async processing + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(onCloseMock).toHaveBeenCalled(); + }); + + it('should receive MCP messages from clients and call onmessage', async () => { + const onMessageMock = jest.fn(); + transport.onmessage = onMessageMock; + + await transport.start(); + + // Simulate client sending a message to server + const clientMessage: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'test-req', + method: 'tools/list', + params: {} + }; + + await mockRedis.publish( + `mcp:shttp:toserver:${sessionId}`, + JSON.stringify({ + type: 'mcp', + message: clientMessage, + extra: { authInfo: { token: 'test-token', clientId: 'test-client', scopes: [] } } + }) + ); + + // Wait for async processing + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(onMessageMock).toHaveBeenCalledWith( + clientMessage, + { authInfo: { token: 'test-token', clientId: 'test-client', scopes: [] } } + ); + + await transport.close(); + }); + }); + + + describe('redisRelayToMcpServer', () => { + let mockTransport: Transport; + const sessionId = 'test-session-456'; + + beforeEach(() => { + mockTransport = { + onmessage: undefined, + onclose: undefined, + onerror: undefined, + send: jest.fn(() => Promise.resolve()), + close: jest.fn(() => Promise.resolve()), + start: jest.fn(() => Promise.resolve()) + }; + }); + + it('should set up message relay from transport to server', async () => { + const cleanup = await redisRelayToMcpServer(sessionId, mockTransport); + + // Simulate a message from the transport + const requestMessage: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'req-123', + method: 'tools/list', + params: {} + }; + + // Trigger the onmessage handler + mockTransport.onmessage?.(requestMessage, { 
authInfo: { token: 'test-token', clientId: 'test-client', scopes: [] } }); + + // Wait a bit for async processing + await new Promise(resolve => setTimeout(resolve, 10)); + + // Check that message was published to server channel + const serverSubscriber = jest.fn(); + await mockRedis.createSubscription( + `mcp:shttp:toserver:${sessionId}`, + serverSubscriber, + jest.fn() + ); + + // The message should have been published + expect(mockRedis.numsub(`mcp:shttp:toserver:${sessionId}`)).resolves.toBe(1); + + await cleanup(); + }); + + it('should subscribe to response channel for request messages', async () => { + const cleanup = await redisRelayToMcpServer(sessionId, mockTransport); + + const requestMessage: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'req-456', + method: 'tools/call', + params: { name: 'echo', arguments: { text: 'hello' } } + }; + + // Trigger the onmessage handler + mockTransport.onmessage?.(requestMessage, { authInfo: { token: 'test-token', clientId: 'test-client', scopes: [] } }); + + // Wait for subscription setup + await new Promise(resolve => setTimeout(resolve, 10)); + + // Now simulate a response from the server + const responseMessage: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'req-456', + result: { content: [{ type: 'text', text: 'hello' }] } + }; + + await mockRedis.publish( + `mcp:shttp:toclient:${sessionId}:req-456`, + JSON.stringify({ + type: 'mcp', + message: responseMessage, + options: undefined + }) + ); + + // Check that the response was sent back to the transport + expect(mockTransport.send).toHaveBeenCalledWith(responseMessage, undefined); + + await cleanup(); + }); + + it('should not subscribe for notification messages (no id)', async () => { + const cleanup = await redisRelayToMcpServer(sessionId, mockTransport); + + const notificationMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'notifications/message', + params: { message: 'test' } + }; + + // Trigger the onmessage handler + 
mockTransport.onmessage?.(notificationMessage); + + // Wait a bit + await new Promise(resolve => setTimeout(resolve, 10)); + + // Should not create any response channel subscriptions + expect(await mockRedis.numsub(`mcp:shttp:toclient:${sessionId}:undefined`)).toBe(0); + + await cleanup(); + }); + }); + + describe('isLive', () => { + const sessionId = 'test-session-789'; + + it('should return true when session has active subscribers', async () => { + // Create a subscription to the server channel + await mockRedis.createSubscription( + `mcp:shttp:toserver:${sessionId}`, + jest.fn(), + jest.fn() + ); + + expect(await isLive(sessionId)).toBe(true); + }); + + it('should return false when session has no subscribers', async () => { + expect(await isLive(sessionId)).toBe(false); + }); + }); + + describe('Session Ownership', () => { + const sessionId = 'test-session-ownership'; + const userId = 'test-user-123'; + + it('should set and get session owner', async () => { + await setSessionOwner(sessionId, userId); + const owner = await getSessionOwner(sessionId); + expect(owner).toBe(userId); + }); + + it('should validate session ownership correctly', async () => { + await setSessionOwner(sessionId, userId); + + expect(await validateSessionOwnership(sessionId, userId)).toBe(true); + expect(await validateSessionOwnership(sessionId, 'different-user')).toBe(false); + }); + + it('should check if session is owned by user including liveness', async () => { + // Session not live yet + expect(await isSessionOwnedBy(sessionId, userId)).toBe(false); + + // Make session live + await mockRedis.createSubscription( + `mcp:shttp:toserver:${sessionId}`, + jest.fn(), + jest.fn() + ); + + // Still false because no owner set + expect(await isSessionOwnedBy(sessionId, userId)).toBe(false); + + // Set owner + await setSessionOwner(sessionId, userId); + + // Now should be true + expect(await isSessionOwnedBy(sessionId, userId)).toBe(true); + + // False for different user + expect(await 
isSessionOwnedBy(sessionId, 'different-user')).toBe(false); + }); + }); + + describe('Integration: Redis message flow', () => { + const sessionId = 'integration-test-session'; + + it('should relay messages between client and server through Redis', async () => { + // Set up client-side transport simulation + const clientTransport: Transport = { + onmessage: undefined, + onclose: undefined, + onerror: undefined, + send: jest.fn(() => Promise.resolve()), + close: jest.fn(() => Promise.resolve()), + start: jest.fn(() => Promise.resolve()) + }; + + const cleanup = await redisRelayToMcpServer(sessionId, clientTransport); + + // Client sends a request + const listToolsRequest: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'integration-req-1', + method: 'tools/list', + params: {} + }; + + // Set up subscription to simulate server receiving the message + const serverSubscriber = jest.fn(); + await mockRedis.createSubscription( + `mcp:shttp:toserver:${sessionId}`, + serverSubscriber, + jest.fn() + ); + + // Simulate client sending request + clientTransport.onmessage?.(listToolsRequest); + + // Wait for async processing + await new Promise(resolve => setTimeout(resolve, 10)); + + // Verify the message was published to server channel + expect(serverSubscriber).toHaveBeenCalledWith( + JSON.stringify({ + type: 'mcp', + message: listToolsRequest, + extra: undefined, + options: undefined + }) + ); + + // Simulate server sending response + const serverResponse: JSONRPCMessage = { + jsonrpc: '2.0', + id: 'integration-req-1', + result: { tools: [{ name: 'echo', description: 'Echo tool' }] } + }; + + await mockRedis.publish( + `mcp:shttp:toclient:${sessionId}:integration-req-1`, + JSON.stringify({ + type: 'mcp', + message: serverResponse, + options: undefined + }) + ); + + // Wait for response processing + await new Promise(resolve => setTimeout(resolve, 10)); + + // Verify the response was sent back to client + expect(clientTransport.send).toHaveBeenCalledWith(serverResponse, 
undefined); + + await cleanup(); + }); + }); + + describe('Control Messages', () => { + const sessionId = 'test-control-session'; + + it('should send shutdown control messages', async () => { + const controlSubscriber = jest.fn(); + await mockRedis.createSubscription( + `mcp:control:${sessionId}`, + controlSubscriber, + jest.fn() + ); + + await shutdownSession(sessionId); + + const callArgs = controlSubscriber.mock.calls[0][0] as string; + const message = JSON.parse(callArgs); + + expect(message.type).toBe('control'); + expect(message.action).toBe('SHUTDOWN'); + expect(typeof message.timestamp).toBe('number'); + }); + + it('should properly shutdown server transport via control message', async () => { + const transport = new ServerRedisTransport(sessionId); + const onCloseMock = jest.fn(); + transport.onclose = onCloseMock; + + await transport.start(); + + // Send shutdown signal + await shutdownSession(sessionId); + + // Wait for async processing + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(onCloseMock).toHaveBeenCalled(); + }); + }); + + describe('Inactivity Timeout', () => { + const sessionId = 'test-inactivity-session'; + + beforeEach(() => { + jest.useFakeTimers({ doNotFake: ['setImmediate', 'nextTick'] }); + }); + + afterEach(() => { + jest.useRealTimers(); + }); + + it('should shutdown session after 5 minutes of inactivity', async () => { + const transport = new ServerRedisTransport(sessionId); + const shutdownSpy = jest.spyOn(mockRedis, 'publish'); + + await transport.start(); + + // Fast-forward time by 5 minutes + jest.advanceTimersByTime(5 * 60 * 1000); + + // Should have published shutdown control message + expect(shutdownSpy).toHaveBeenCalledWith( + `mcp:control:${sessionId}`, + expect.stringContaining('"action":"SHUTDOWN"') + ); + + await transport.close(); + }); + + it('should reset timeout when message is received', async () => { + const transport = new ServerRedisTransport(sessionId); + const onMessageMock = jest.fn(); + 
transport.onmessage = onMessageMock; + + await transport.start(); + + // Fast-forward 4 minutes + jest.advanceTimersByTime(4 * 60 * 1000); + + // Manually publish a message to trigger the subscription handler + const testMessage = { jsonrpc: '2.0', method: 'ping' }; + await mockRedis.publish( + `mcp:shttp:toserver:${sessionId}`, + JSON.stringify({ + type: 'mcp', + message: testMessage + }) + ); + + // Wait for message to be processed + await new Promise(resolve => setImmediate(resolve)); + + // Verify message was received + expect(onMessageMock).toHaveBeenCalledWith(testMessage, undefined); + + // Clear the publish spy to check only future calls + const shutdownSpy = jest.spyOn(mockRedis, 'publish'); + shutdownSpy.mockClear(); + + // Fast-forward 4 more minutes (total 8, but only 4 since last message) + jest.advanceTimersByTime(4 * 60 * 1000); + + // Should not have shutdown yet + expect(shutdownSpy).not.toHaveBeenCalledWith( + `mcp:control:${sessionId}`, + expect.stringContaining('"action":"SHUTDOWN"') + ); + + // Fast-forward 2 more minutes to exceed timeout + jest.advanceTimersByTime(2 * 60 * 1000); + + // Now should have shutdown + expect(shutdownSpy).toHaveBeenCalledWith( + `mcp:control:${sessionId}`, + expect.stringContaining('"action":"SHUTDOWN"') + ); + + await transport.close(); + }, 10000); + + it('should clear timeout on close', async () => { + const transport = new ServerRedisTransport(sessionId); + const shutdownSpy = jest.spyOn(mockRedis, 'publish'); + + await transport.start(); + + // Close transport before timeout + await transport.close(); + + // Fast-forward past timeout + jest.advanceTimersByTime(10 * 60 * 1000); + + // Should not have triggered shutdown + expect(shutdownSpy).not.toHaveBeenCalledWith( + `mcp:control:${sessionId}`, + expect.stringContaining('"action":"SHUTDOWN"') + ); + }); + }); +}); \ No newline at end of file diff --git a/src/services/redisTransport.ts b/src/services/redisTransport.ts new file mode 100644 index 0000000..e35eb9e 
--- /dev/null +++ b/src/services/redisTransport.ts @@ -0,0 +1,351 @@ +import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; +import { redisClient } from "../redis.js"; +import { Transport, TransportSendOptions } from "@modelcontextprotocol/sdk/shared/transport.js"; +import { AuthInfo } from "@modelcontextprotocol/sdk/server/auth/types.js"; +import { JSONRPCMessage, MessageExtraInfo } from "@modelcontextprotocol/sdk/types.js"; +import { logger } from "../utils/logger.js"; + +let redisTransportCounter = 0; +const notificationStreamId = "__GET_stream"; + +// Message types for Redis transport +type RedisMessage = + | { + type: 'mcp'; + message: JSONRPCMessage; + extra?: MessageExtraInfo; + options?: TransportSendOptions; + } + | { + type: 'control'; + action: 'SHUTDOWN' | 'PING' | 'STATUS'; + timestamp?: number; + }; + +function sendToMcpServer(sessionId: string, message: JSONRPCMessage, extra?: { authInfo?: AuthInfo; }, options?: TransportSendOptions): Promise { + const toServerChannel = getToServerChannel(sessionId); + + logger.debug('Sending message to MCP server via Redis', { + sessionId, + channel: toServerChannel, + method: ('method' in message ? message.method : undefined), + id: ('id' in message ? 
message.id : undefined) + }); + + const redisMessage: RedisMessage = { type: 'mcp', message, extra, options }; + return redisClient.publish(toServerChannel, JSON.stringify(redisMessage)); +} + +function getToServerChannel(sessionId: string): string { + return `mcp:shttp:toserver:${sessionId}`; +} + +function getToClientChannel(sessionId: string, relatedRequestId: string): string { + return `mcp:shttp:toclient:${sessionId}:${relatedRequestId}`; +} + +function getControlChannel(sessionId: string): string { + return `mcp:control:${sessionId}`; +} + +function sendControlMessage(sessionId: string, action: 'SHUTDOWN' | 'PING' | 'STATUS'): Promise { + const controlChannel = getControlChannel(sessionId); + const redisMessage: RedisMessage = { + type: 'control', + action, + timestamp: Date.now() + }; + return redisClient.publish(controlChannel, JSON.stringify(redisMessage)); +} + +export async function shutdownSession(sessionId: string): Promise { + logger.info('Sending shutdown control message', { sessionId }); + return sendControlMessage(sessionId, 'SHUTDOWN'); +} + +export async function isLive(sessionId: string): Promise { + // Check if the session is live by checking if the key exists in Redis + const numSubs = await redisClient.numsub(getToServerChannel(sessionId)); + return numSubs > 0; +} + +export async function setSessionOwner(sessionId: string, userId: string): Promise { + logger.debug('Setting session owner', { sessionId, userId }); + await redisClient.set(`session:${sessionId}:owner`, userId); +} + +export async function getSessionOwner(sessionId: string): Promise { + return await redisClient.get(`session:${sessionId}:owner`); +} + +export async function validateSessionOwnership(sessionId: string, userId: string): Promise { + const owner = await getSessionOwner(sessionId); + return owner === userId; +} + +export async function isSessionOwnedBy(sessionId: string, userId: string): Promise { + const isLiveSession = await isLive(sessionId); + if (!isLiveSession) { 
+ logger.debug('Session not live', { sessionId }); + return false; + } + const isOwned = await validateSessionOwnership(sessionId, userId); + logger.debug('Session ownership check', { sessionId, userId, isOwned }); + return isOwned; +} + + +export async function redisRelayToMcpServer(sessionId: string, transport: Transport, isGetRequest: boolean = false): Promise<() => Promise> { + logger.debug('Setting up Redis relay to MCP server', { + sessionId, + isGetRequest + }); + + let redisCleanup: (() => Promise) | undefined = undefined; + const cleanup = async () => { + // TODO: solve race conditions where we call cleanup while the subscription is being created / before it is created + if (redisCleanup) { + logger.debug('Cleaning up Redis relay', { sessionId }); + await redisCleanup(); + } + } + + const subscribe = async (requestId: string) => { + const toClientChannel = getToClientChannel(sessionId, requestId); + + logger.debug('Subscribing to client channel', { + sessionId, + requestId, + channel: toClientChannel + }); + + redisCleanup = await redisClient.createSubscription(toClientChannel, async (redisMessageJson) => { + const redisMessage = JSON.parse(redisMessageJson) as RedisMessage; + if (redisMessage.type === 'mcp') { + logger.debug('Relaying message from Redis to client', { + sessionId, + requestId, + method: ('method' in redisMessage.message ? 
redisMessage.message.method : undefined) + }); + await transport.send(redisMessage.message, redisMessage.options); + } + }, (error) => { + logger.error('Error in Redis relay subscription', error, { + sessionId, + channel: toClientChannel + }); + transport.onerror?.(error); + }); + } + + if (isGetRequest) { + await subscribe(notificationStreamId); + } else { + const messagePromise = new Promise((resolve) => { + transport.onmessage = async (message, extra) => { + // First, set up response subscription if needed + if ("id" in message) { + logger.debug('Setting up response subscription', { + sessionId, + messageId: message.id, + method: ('method' in message ? message.method : undefined) + }); + await subscribe(message.id.toString()); + } + // Now send the message to the MCP server + await sendToMcpServer(sessionId, message, extra); + resolve(message); + } + }); + + messagePromise.catch((error) => { + transport.onerror?.(error); + cleanup(); + }); + } + return cleanup; +} + + +// New Redis transport for server->client messages using request-id based channels +export class ServerRedisTransport implements Transport { + private counter: number; + private _sessionId: string; + private controlCleanup?: (() => Promise); + private serverCleanup?: (() => Promise); + private shouldShutdown = false; + private inactivityTimeout?: NodeJS.Timeout; + private readonly INACTIVITY_TIMEOUT_MS = 5 * 60 * 1000; // 5 minutes + + onclose?: (() => void) | undefined; + onerror?: ((error: Error) => void) | undefined; + onmessage?: ((message: JSONRPCMessage, extra?: { authInfo?: AuthInfo; }) => void) | undefined; + + constructor(sessionId: string) { + this.counter = redisTransportCounter++; + this._sessionId = sessionId; + } + + private resetInactivityTimer(): void { + // Clear existing timeout if any + if (this.inactivityTimeout) { + clearTimeout(this.inactivityTimeout); + } + + // Set new timeout + this.inactivityTimeout = setTimeout(() => { + logger.info('Session timed out due to inactivity', 
{ + sessionId: this._sessionId, + timeoutMs: this.INACTIVITY_TIMEOUT_MS + }); + void shutdownSession(this._sessionId); + }, this.INACTIVITY_TIMEOUT_MS); + } + + private clearInactivityTimer(): void { + if (this.inactivityTimeout) { + clearTimeout(this.inactivityTimeout); + this.inactivityTimeout = undefined; + } + } + + async start(): Promise { + logger.info('Starting ServerRedisTransport', { + sessionId: this._sessionId, + inactivityTimeoutMs: this.INACTIVITY_TIMEOUT_MS + }); + + // Start inactivity timer + this.resetInactivityTimer(); + + // Subscribe to MCP messages from clients + const serverChannel = getToServerChannel(this._sessionId); + logger.debug('Subscribing to server channel', { + sessionId: this._sessionId, + channel: serverChannel + }); + + this.serverCleanup = await redisClient.createSubscription( + serverChannel, + (messageJson) => { + const redisMessage = JSON.parse(messageJson) as RedisMessage; + if (redisMessage.type === 'mcp') { + // Reset inactivity timer on each message from client + this.resetInactivityTimer(); + + logger.debug('Received MCP message from client', { + sessionId: this._sessionId, + method: ('method' in redisMessage.message ? redisMessage.message.method : undefined), + id: ('id' in redisMessage.message ? 
redisMessage.message.id : undefined) + }); + + this.onmessage?.(redisMessage.message, redisMessage.extra); + } + }, + (error) => { + logger.error('Error in server channel subscription', error, { + sessionId: this._sessionId, + channel: serverChannel + }); + this.onerror?.(error); + } + ); + + // Subscribe to control messages for shutdown + const controlChannel = getControlChannel(this._sessionId); + logger.debug('Subscribing to control channel', { + sessionId: this._sessionId, + channel: controlChannel + }); + + this.controlCleanup = await redisClient.createSubscription( + controlChannel, + (messageJson) => { + const redisMessage = JSON.parse(messageJson) as RedisMessage; + if (redisMessage.type === 'control') { + logger.info('Received control message', { + sessionId: this._sessionId, + action: redisMessage.action + }); + + if (redisMessage.action === 'SHUTDOWN') { + logger.info('Shutting down transport due to control message', { + sessionId: this._sessionId + }); + this.shouldShutdown = true; + this.close(); + } + } + }, + (error) => { + logger.error('Error in control channel subscription', error, { + sessionId: this._sessionId, + channel: controlChannel + }); + this.onerror?.(error); + } + ); + + } + + async send(message: JSONRPCMessage, options?: TransportSendOptions): Promise { + const relatedRequestId = options?.relatedRequestId?.toString() ?? ("id" in message ? message.id?.toString() : notificationStreamId); + const channel = getToClientChannel(this._sessionId, relatedRequestId) + + logger.debug('Sending message to client', { + sessionId: this._sessionId, + channel, + method: ('method' in message ? message.method : undefined), + id: ('id' in message ? 
message.id : undefined), + relatedRequestId + }); + + const redisMessage: RedisMessage = { type: 'mcp', message, options }; + const messageStr = JSON.stringify(redisMessage); + await redisClient.publish(channel, messageStr); + } + + async close(): Promise { + logger.info('Closing ServerRedisTransport', { + sessionId: this._sessionId, + wasShutdown: this.shouldShutdown + }); + + // Clear inactivity timer + this.clearInactivityTimer(); + + // Clean up server message subscription + if (this.serverCleanup) { + await this.serverCleanup(); + this.serverCleanup = undefined; + } + + // Clean up control message subscription + if (this.controlCleanup) { + await this.controlCleanup(); + this.controlCleanup = undefined; + } + + this.onclose?.(); + } +} + +export async function getShttpTransport(sessionId: string, onsessionclosed: (sessionId: string) => void | Promise, isGetRequest: boolean = false): Promise { + logger.debug('Getting StreamableHTTPServerTransport for existing session', { + sessionId, + isGetRequest + }); + + // Giving undefined here and setting the sessionId means the + // transport wont try to create a new session. 
+ const shttpTransport = new StreamableHTTPServerTransport({ + sessionIdGenerator: undefined, + onsessionclosed, + }) + shttpTransport.sessionId = sessionId; + + // Use the new request-id based relay approach + const cleanup = await redisRelayToMcpServer(sessionId, shttpTransport, isGetRequest); + shttpTransport.onclose = cleanup; + return shttpTransport; +} \ No newline at end of file diff --git a/src/static/mcp.png b/src/static/mcp.png new file mode 100644 index 0000000..86c5266 Binary files /dev/null and b/src/static/mcp.png differ diff --git a/src/types.ts b/src/types.ts index 7f168d8..0b15454 100644 --- a/src/types.ts +++ b/src/types.ts @@ -27,4 +27,5 @@ export interface McpInstallation { mcpTokens: OAuthTokens; clientId: string; issuedAt: number; + userId: string; // Unique identifier for the user (not client) } \ No newline at end of file diff --git a/src/utils/logger.ts b/src/utils/logger.ts new file mode 100644 index 0000000..f9879db --- /dev/null +++ b/src/utils/logger.ts @@ -0,0 +1,176 @@ +import { AsyncLocalStorage } from 'async_hooks'; +import { Request, Response, NextFunction } from 'express'; + +// Severity levels as per Google Cloud Logging +export enum LogSeverity { + DEFAULT = 'DEFAULT', + DEBUG = 'DEBUG', + INFO = 'INFO', + NOTICE = 'NOTICE', + WARNING = 'WARNING', + ERROR = 'ERROR', + CRITICAL = 'CRITICAL', + ALERT = 'ALERT', + EMERGENCY = 'EMERGENCY' +} + +interface LogContext { + trace?: string; + spanId?: string; + requestId?: string; + userAgent?: string; + method?: string; + path?: string; + [key: string]: string | undefined; +} + +interface StructuredLogEntry { + severity: LogSeverity; + message: string; + timestamp: string; + 'logging.googleapis.com/trace'?: string; + 'logging.googleapis.com/spanId'?: string; + [key: string]: unknown; +} + +class StructuredLogger { + private asyncLocalStorage = new AsyncLocalStorage(); + private projectId: string | undefined; + + constructor() { + // Get project ID from environment or metadata server + 
this.projectId = process.env.GOOGLE_CLOUD_PROJECT || process.env.GCP_PROJECT; + } + + /** + * Run a function with a specific logging context + */ + runWithContext(context: LogContext, fn: () => T): T { + return this.asyncLocalStorage.run(context, fn); + } + + /** + * Extract trace context from Cloud Run request + */ + extractTraceContext(req: Request): LogContext { + const context: LogContext = {}; + + const traceHeader = req.header('X-Cloud-Trace-Context'); + if (traceHeader && this.projectId) { + const [trace, spanId] = traceHeader.split('/'); + context.trace = `projects/${this.projectId}/traces/${trace}`; + if (spanId) { + context.spanId = spanId.split(';')[0]; // Remove any trace flags + } + } + + // Add other useful request context + context.requestId = req.header('X-Request-Id'); + context.userAgent = req.header('User-Agent'); + context.method = req.method; + context.path = req.path; + + return context; + } + + /** + * Create Express middleware for request context + */ + middleware() { + return (req: Request, res: Response, next: NextFunction) => { + const context = this.extractTraceContext(req); + this.runWithContext(context, () => { + next(); + }); + }; + } + + /** + * Log a structured message + */ + private log(severity: LogSeverity, message: string, metadata?: Record) { + const context = this.asyncLocalStorage.getStore() || {}; + + const entry: StructuredLogEntry = { + severity, + message, + timestamp: new Date().toISOString(), + ...metadata + }; + + // Add trace context if available + if (context.trace) { + entry['logging.googleapis.com/trace'] = context.trace; + } + if (context.spanId) { + entry['logging.googleapis.com/spanId'] = context.spanId; + } + + // Add any other context fields + Object.keys(context).forEach(key => { + if (key !== 'trace' && key !== 'spanId') { + entry[`context.${key}`] = context[key]; + } + }); + + // Output as JSON for Cloud Logging + console.log(JSON.stringify(entry)); + } + + // Convenience methods for different severity 
levels + debug(message: string, metadata?: Record) { + this.log(LogSeverity.DEBUG, message, metadata); + } + + info(message: string, metadata?: Record) { + this.log(LogSeverity.INFO, message, metadata); + } + + notice(message: string, metadata?: Record) { + this.log(LogSeverity.NOTICE, message, metadata); + } + + warning(message: string, metadata?: Record) { + this.log(LogSeverity.WARNING, message, metadata); + } + + error(message: string, error?: Error, metadata?: Record) { + const errorMetadata = { + ...metadata, + error: error ? { + name: error.name, + message: error.message, + stack: error.stack + } : undefined + }; + this.log(LogSeverity.ERROR, message, errorMetadata); + } + + critical(message: string, metadata?: Record) { + this.log(LogSeverity.CRITICAL, message, metadata); + } + + alert(message: string, metadata?: Record) { + this.log(LogSeverity.ALERT, message, metadata); + } + + emergency(message: string, metadata?: Record) { + this.log(LogSeverity.EMERGENCY, message, metadata); + } + + /** + * Add additional context to the current async context + */ + addContext(context: LogContext) { + const currentContext = this.asyncLocalStorage.getStore(); + if (currentContext) { + Object.assign(currentContext, context); + } + } +} + +// Export singleton instance +export const logger = new StructuredLogger(); + +// Re-export for convenience +export type { LogContext, StructuredLogEntry }; \ No newline at end of file