Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions workflow/packages/backend/api/src/app/mcp/mcp-server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ export type CreateMcpServerRequest = {
mcpId: string
reply: FastifyReply
logger: FastifyBaseLogger
userId?: string // πŸ” NEW: User context for security
projectId?: string // πŸ” NEW: Project context for security
}
export type CreateMcpServerResponse = {
server: McpServer
Expand Down
232 changes: 206 additions & 26 deletions workflow/packages/backend/api/src/app/mcp/mcp-sse-controller.ts
Original file line number Diff line number Diff line change
@@ -1,39 +1,189 @@
import { ALL_PRINCIPAL_TYPES, ApId } from 'workflow-shared'
import { ALL_PRINCIPAL_TYPES, ApId, PRINCIPAL_TYPES, PrincipalType } from 'workflow-shared'
import { FastifyPluginAsyncTypebox, Type } from '@fastify/type-provider-typebox'
import { StatusCodes } from 'http-status-codes'
import { createMcpServer } from './mcp-server'
import { mcpService } from './mcp-service'
import { mcpSessionManager } from './mcp-session-manager'

const HEARTBEAT_INTERVAL = 30 * 1000 // 30 seconds
const MAX_CONNECTIONS_PER_USER = 5 // Maximum concurrent connections per user
const RATE_LIMIT_WINDOW = 60 * 1000 // 1 minute rate limit window
const RATE_LIMIT_MAX_REQUESTS = 10 // Maximum requests per window

// In-memory rate limiting and connection tracking
const userConnections = new Map<string, number>()
const userRateLimit = new Map<string, { count: number, resetTime: number }>()

export const mcpSseController: FastifyPluginAsyncTypebox = async (app) => {

app.get('/:id/sse', SSERequest, async (req, reply) => {
const token = req.params.id
const mcp = await mcpService(req.log).getByToken({
token,
})

const { server, transport } = await createMcpServer({
mcpId: mcp.id,
reply,
logger: req.log,
})

await mcpSessionManager(req.log).add(transport.sessionId, server, transport)

await server.connect(transport)

const heartbeatInterval = setInterval(() => {
reply.raw.write(': heartbeat\n\n')
req.log.info(`Heartbeat sent for session ${transport.sessionId}`)
}, HEARTBEAT_INTERVAL)

reply.raw.on('close', async () => {
clearInterval(heartbeatInterval)
req.log.info(`Connection closed for session ${transport.sessionId}`)
await mcpSessionManager(req.log).publish(transport.sessionId, {}, 'remove')
})

// πŸ” SECURITY FIX: Validate user authentication and session
if (!req.principal || !req.principal.id) {
req.log.warn(`Unauthorized MCP SSE access attempt with token: ${token.substring(0, 8)}...`)
return reply.code(StatusCodes.UNAUTHORIZED).send({
error: 'Authentication required',
message: 'Valid session required for MCP server access'
})
}

// πŸ” SECURITY FIX: Only allow authenticated users (no service principals for SSE)
if (req.principal.type !== PrincipalType.USER) {
req.log.warn(`Invalid principal type for MCP SSE: ${req.principal.type}`)
return reply.code(StatusCodes.FORBIDDEN).send({
error: 'Invalid access method',
message: 'MCP SSE connections require user authentication'
})
}

const userId = req.principal.id
const userProjectId = req.principal.projectId

// πŸ” SECURITY FIX: Rate limiting per user
const now = Date.now()
const userRate = userRateLimit.get(userId) || { count: 0, resetTime: now + RATE_LIMIT_WINDOW }

if (now > userRate.resetTime) {
// Reset rate limit window
userRate.count = 0
userRate.resetTime = now + RATE_LIMIT_WINDOW
}

if (userRate.count >= RATE_LIMIT_MAX_REQUESTS) {
req.log.warn(`Rate limit exceeded for user ${userId}`)
return reply.code(StatusCodes.TOO_MANY_REQUESTS).send({
error: 'Rate limit exceeded',
message: 'Too many MCP connection attempts. Please wait before retrying.',
retryAfter: Math.ceil((userRate.resetTime - now) / 1000)
})
}

userRate.count++
userRateLimit.set(userId, userRate)

// πŸ” SECURITY FIX: Connection limit per user
const currentConnections = userConnections.get(userId) || 0
if (currentConnections >= MAX_CONNECTIONS_PER_USER) {
req.log.warn(`Connection limit exceeded for user ${userId}`)
return reply.code(StatusCodes.TOO_MANY_REQUESTS).send({
error: 'Connection limit exceeded',
message: `Maximum ${MAX_CONNECTIONS_PER_USER} concurrent MCP connections allowed per user`
})
}

// πŸ” SECURITY FIX: Validate token exists and get MCP configuration
let mcp
try {
mcp = await mcpService(req.log).getByToken({ token })
} catch (error) {
req.log.warn(`Invalid or expired MCP token: ${token.substring(0, 8)}...`)
return reply.code(StatusCodes.NOT_FOUND).send({
error: 'Invalid token',
message: 'MCP token not found or has expired'
})
}

// πŸ” SECURITY FIX: Verify user owns the MCP or has access to the project
if (mcp.projectId !== userProjectId) {
req.log.warn(`Unauthorized MCP access attempt: user ${userId} tried to access MCP ${mcp.id} from project ${mcp.projectId}`)
return reply.code(StatusCodes.FORBIDDEN).send({
error: 'Unauthorized access',
message: 'You do not have permission to access this MCP server'
})
}

// πŸ” SECURITY FIX: Input validation for token format
const tokenPattern = /^[a-zA-Z0-9_-]+$/
if (!tokenPattern.test(token) || token.length < 10 || token.length > 50) {
req.log.warn(`Invalid token format: ${token.substring(0, 8)}...`)
return reply.code(StatusCodes.BAD_REQUEST).send({
error: 'Invalid token format',
message: 'Token format is invalid'
})
}

try {
// πŸ” SECURITY FIX: Create MCP server with user context
const { server, transport } = await createMcpServer({
mcpId: mcp.id,
reply,
logger: req.log,
userId: userId, // πŸ” NEW: Bind server to authenticated user
projectId: userProjectId // πŸ” NEW: Add project context
})

// πŸ” SECURITY FIX: Track connection with expiration
const sessionMetadata = {
userId: userId,
projectId: userProjectId,
mcpId: mcp.id,
createdAt: new Date(),
expiresAt: new Date(Date.now() + (24 * 60 * 60 * 1000)), // 24 hour session expiry
ipAddress: req.ip,
userAgent: req.headers['user-agent'] || 'Unknown'
}

await mcpSessionManager(req.log).add(transport.sessionId, server, transport, sessionMetadata)

// πŸ” SECURITY FIX: Update connection count
userConnections.set(userId, (userConnections.get(userId) || 0) + 1)

// πŸ” SECURITY FIX: Comprehensive audit logging
req.log.info({
event: 'MCP_SESSION_CREATED',
userId: userId,
sessionId: transport.sessionId,
mcpId: mcp.id,
projectId: userProjectId,
tokenPrefix: token.substring(0, 8),
ipAddress: req.ip,
userAgent: req.headers['user-agent'],
timestamp: new Date().toISOString()
}, `MCP session established for user ${userId}`)

await server.connect(transport)

const heartbeatInterval = setInterval(() => {
reply.raw.write(': heartbeat\n\n')
req.log.debug(`Heartbeat sent for session ${transport.sessionId}`)
}, HEARTBEAT_INTERVAL)

reply.raw.on('close', async () => {
clearInterval(heartbeatInterval)

// πŸ” SECURITY FIX: Cleanup connection tracking
const currentCount = userConnections.get(userId) || 1
if (currentCount <= 1) {
userConnections.delete(userId)
} else {
userConnections.set(userId, currentCount - 1)
}

req.log.info({
event: 'MCP_SESSION_CLOSED',
userId: userId,
sessionId: transport.sessionId,
mcpId: mcp.id,
timestamp: new Date().toISOString()
}, `MCP session closed for user ${userId}`)

await mcpSessionManager(req.log).publish(transport.sessionId, {}, 'remove')
})

} catch (error) {
req.log.error({
error: error,
userId: userId,
mcpId: mcp.id,
tokenPrefix: token.substring(0, 8)
}, 'Failed to create MCP server')

return reply.code(StatusCodes.INTERNAL_SERVER_ERROR).send({
error: 'Server error',
message: 'Failed to establish MCP connection'
})
}
})

app.post('/messages', MessagesRequest, async (req, reply) => {
Expand Down Expand Up @@ -62,11 +212,41 @@ const MessagesRequest = {

const SSERequest = {
config: {
allowedPrincipals: ALL_PRINCIPAL_TYPES,
// πŸ” SECURITY FIX: Restrict to authenticated users only (no service principals)
allowedPrincipals: [PrincipalType.USER],
},
schema: {
tags: ['mcp'],
description: 'Establish SSE connection to MCP server (requires user authentication)',
params: Type.Object({
id: ApId,
}),
response: {
[StatusCodes.UNAUTHORIZED]: Type.Object({
error: Type.String(),
message: Type.String(),
}),
[StatusCodes.FORBIDDEN]: Type.Object({
error: Type.String(),
message: Type.String(),
}),
[StatusCodes.TOO_MANY_REQUESTS]: Type.Object({
error: Type.String(),
message: Type.String(),
retryAfter: Type.Optional(Type.Number()),
}),
[StatusCodes.NOT_FOUND]: Type.Object({
error: Type.String(),
message: Type.String(),
}),
[StatusCodes.BAD_REQUEST]: Type.Object({
error: Type.String(),
message: Type.String(),
}),
[StatusCodes.INTERNAL_SERVER_ERROR]: Type.Object({
error: Type.String(),
message: Type.String(),
}),
},
},
}