|
| 1 | +package mcp |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + |
| 6 | + "github.com/databricks/cli/experimental/apps-mcp/lib/session" |
| 7 | +) |
| 8 | + |
| 9 | +// MiddlewareContext provides context for middleware execution. |
| 10 | +type MiddlewareContext struct { |
| 11 | + Ctx context.Context |
| 12 | + Request *CallToolRequest |
| 13 | + Session *session.Session |
| 14 | +} |
| 15 | + |
| 16 | +// MiddlewareFunc is a function that processes a tool call request. |
| 17 | +// It can: |
| 18 | +// - Return (nil, nil) to pass execution to the next middleware or tool handler |
| 19 | +// - Return (result, nil) to short-circuit and return a result immediately |
| 20 | +// - Return (nil, error) to abort execution with an error |
| 21 | +type MiddlewareFunc func(*MiddlewareContext, NextFunc) (*CallToolResult, error) |
| 22 | + |
| 23 | +// NextFunc is called by middleware to pass execution to the next middleware or tool handler. |
| 24 | +type NextFunc func() (*CallToolResult, error) |
| 25 | + |
| 26 | +// Middleware represents a middleware component in the chain. |
| 27 | +type Middleware interface { |
| 28 | + // Handle processes the request and optionally calls next to continue the chain. |
| 29 | + Handle(ctx *MiddlewareContext, next NextFunc) (*CallToolResult, error) |
| 30 | +} |
| 31 | + |
| 32 | +// MiddlewareFuncAdapter adapts a MiddlewareFunc to the Middleware interface. |
| 33 | +type MiddlewareFuncAdapter struct { |
| 34 | + fn MiddlewareFunc |
| 35 | +} |
| 36 | + |
| 37 | +// Handle implements the Middleware interface. |
| 38 | +func (m *MiddlewareFuncAdapter) Handle(ctx *MiddlewareContext, next NextFunc) (*CallToolResult, error) { |
| 39 | + return m.fn(ctx, next) |
| 40 | +} |
| 41 | + |
| 42 | +// NewMiddleware creates a Middleware from a MiddlewareFunc. |
| 43 | +func NewMiddleware(fn MiddlewareFunc) Middleware { |
| 44 | + return &MiddlewareFuncAdapter{fn: fn} |
| 45 | +} |
| 46 | + |
| 47 | +// Chain executes a chain of middleware with an existing Session followed by a final handler. |
| 48 | +// The Session persists across multiple tool calls (server session scope). |
| 49 | +func Chain(middlewares []Middleware, sess *session.Session, handler ToolHandler) ToolHandler { |
| 50 | + return func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { |
| 51 | + // Add session to context |
| 52 | + ctx = session.WithSession(ctx, sess) |
| 53 | + |
| 54 | + mwCtx := &MiddlewareContext{ |
| 55 | + Ctx: ctx, |
| 56 | + Request: req, |
| 57 | + Session: sess, |
| 58 | + } |
| 59 | + |
| 60 | + // Build the chain from the end |
| 61 | + var chain NextFunc |
| 62 | + chain = func() (*CallToolResult, error) { |
| 63 | + return handler(ctx, req) |
| 64 | + } |
| 65 | + |
| 66 | + // Wrap each middleware in reverse order |
| 67 | + for i := len(middlewares) - 1; i >= 0; i-- { |
| 68 | + currentMiddleware := middlewares[i] |
| 69 | + next := chain |
| 70 | + chain = func() (*CallToolResult, error) { |
| 71 | + return currentMiddleware.Handle(mwCtx, next) |
| 72 | + } |
| 73 | + } |
| 74 | + |
| 75 | + // Execute the chain |
| 76 | + return chain() |
| 77 | + } |
| 78 | +} |
0 commit comments