diff --git a/README.md b/README.md index bcdf0d841..8f3c40f98 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ - [Improving Network Efficiency with Notification Debouncing](#improving-network-efficiency-with-notification-debouncing) - [Low-Level Server](#low-level-server) - [Eliciting User Input](#eliciting-user-input) + - [Task-Based Execution](#task-based-execution) - [Writing MCP Clients](#writing-mcp-clients) - [Proxy Authorization Requests Upstream](#proxy-authorization-requests-upstream) - [Backwards Compatibility](#backwards-compatibility) @@ -1382,6 +1383,204 @@ const client = new Client( ); ``` +### Task-Based Execution + +Task-based execution enables "call-now, fetch-later" patterns for long-running operations. This is useful for tools that take significant time to complete, where clients may want to disconnect and check on progress or retrieve results later. + +Common use cases include: + +- Long-running data processing or analysis +- Code migration or refactoring operations +- Complex computational tasks +- Operations that require periodic status updates + +#### Server-Side: Implementing Task Support + +To enable task-based execution, configure your server with a `TaskStore` implementation. The SDK doesn't provide a built-in TaskStore—you'll need to implement one backed by your database of choice: + +```typescript +import { Server } from '@modelcontextprotocol/sdk/server/index.js'; +import { TaskStore } from '@modelcontextprotocol/sdk/shared/task.js'; +import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'; + +// Implement TaskStore backed by your database (e.g., PostgreSQL, Redis, etc.) +class MyTaskStore implements TaskStore { + async createTask(taskParams, requestId, request, sessionId?): Promise { + // Generate unique taskId and lastUpdatedAt/createdAt timestamps + // Store task in your database, using the session ID as a proxy to restrict unauthorized access + // Return final Task object + } + + async getTask(taskId): Promise { + // Retrieve task from your database + } + + async updateTaskStatus(taskId, status, statusMessage?): Promise { + // Update task status in your database + } + + async storeTaskResult(taskId, result): Promise { + // Store task result in your database + } + + async getTaskResult(taskId): Promise { + // Retrieve task result from your database + } + + async listTasks(cursor?, sessionId?): Promise<{ tasks: Task[]; nextCursor?: string }> { + // List tasks with pagination support + } +} + +const taskStore = new MyTaskStore(); + +const server = new Server( + { + name: 'task-enabled-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {}, + // Declare capabilities + tasks: { + list: {}, + cancel: {}, + requests: { + tools: { + // Declares support for tasks on tools/call + call: {} + } + } + } + }, + taskStore // Enable task support + } +); + +// Register a tool that supports tasks +server.registerToolTask( + 'my-echo-tool', + { + title: 'My Echo Tool', + description: 'A simple task-based echo tool.', + inputSchema: { + message: z.string().describe('Message to send') + } + }, + { + async createTask({ message }, { taskStore, taskRequestedTtl, requestId }) { + // Create the task + const task = await taskStore.createTask({ + ttl: taskRequestedTtl + }); + + // Simulate out-of-band work + (async () => { + await new Promise(resolve => setTimeout(resolve, 5000)); + await taskStore.storeTaskResult(task.taskId, 'completed', { + content: [ + { + type: 'text', + text: message + } + ] + }); + })(); + + // Return CreateTaskResult with the created task + return { task }; + }, + async getTask(_args, { taskId, taskStore }) { + // Retrieve the task + return await taskStore.getTask(taskId); + }, + async getTaskResult(_args, { taskId, taskStore }) { + // Retrieve the result of the task + const result = await taskStore.getTaskResult(taskId); + return result as CallToolResult; + } + } +); +``` + +**Note**: See `src/examples/shared/inMemoryTaskStore.ts` in the SDK source for a reference task store implementation suitable for development and testing. + +#### Client-Side: Using Task-Based Execution + +Clients use `callToolStream()` to initiate task-augmented tool calls. The returned `AsyncGenerator` abstracts automatic polling and status updates: + +```typescript +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { CallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'; + +const client = new Client({ + name: 'task-client', + version: '1.0.0' +}); + +// ... connect to server ... + +// Call the tool with task metadata using streaming API +const stream = client.callToolStream( + { + name: 'my-echo-tool', + arguments: { message: 'Hello, world!' } + }, + CallToolResultSchema +); + +// Iterate the stream and handle stream events +let taskId = ''; +for await (const message of stream) { + switch (message.type) { + case 'taskCreated': + console.log('Task created successfully with ID:', message.task.taskId); + taskId = message.task.taskId; + break; + case 'taskStatus': + console.log(` ${message.task.status}${message.task.statusMessage ?? ''}`); + break; + case 'result': + console.log('Task completed! Tool result:'); + message.result.content.forEach(item => { + if (item.type === 'text') { + console.log(` ${item.text}`); + } + }); + break; + case 'error': + throw message.error; + } +} + +// Optional: Fire and forget - disconnect and reconnect later +// (useful when you don't want to wait for long-running tasks) +// Later, after disconnecting and reconnecting to the server: +const taskStatus = await client.getTask({ taskId }); +console.log('Task status:', taskStatus.status); + +if (taskStatus.status === 'completed') { + const taskResult = await client.getTaskResult({ taskId }, CallToolResultSchema); + console.log('Retrieved result after reconnect:', taskResult); +} +``` + +The `callToolStream()` method also works with non-task tools, making it a drop-in replacement for `callTool()` in applications that support it. When used to invoke a tool that doesn't support tasks, the `taskCreated` and `taskStatus` events will not be emitted. + +#### Task Status Lifecycle + +Tasks transition through the following states: + +- **working**: Task is actively being processed +- **input_required**: Task is waiting for additional input (e.g., from elicitation) +- **completed**: Task finished successfully +- **failed**: Task encountered an error +- **cancelled**: Task was cancelled by the client + +The `ttl` parameter suggests how long the server will manage the task for. If the task duration exceeds this, the server may delete the task prematurely. The client's suggested value may be overridden by the server, and the final TTL will be provided in `Task.ttl` in +`taskCreated` and `taskStatus` events. + ### Writing MCP Clients The SDK provides a high-level client interface: diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 4c26c796c..38af2a841 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -11,16 +11,22 @@ import { InitializeRequestSchema, ListResourcesRequestSchema, ListToolsRequestSchema, + ListToolsResultSchema, CallToolRequestSchema, + CallToolResultSchema, CreateMessageRequestSchema, ElicitRequestSchema, ElicitResultSchema, ListRootsRequestSchema, - ErrorCode + ErrorCode, + McpError, + CreateTaskResultSchema } from '../types.js'; import { Transport } from '../shared/transport.js'; import { Server } from '../server/index.js'; +import { McpServer } from '../server/mcp.js'; import { InMemoryTransport } from '../inMemory.js'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.js'; import * as z3 from 'zod/v3'; import * as z4 from 'zod/v4'; @@ -1170,8 +1176,8 @@ test('should handle client cancelling a request', async () => { }); controller.abort('Cancelled by test'); - // Request should be rejected - await expect(listResourcesPromise).rejects.toBe('Cancelled by test'); + // Request should be rejected with an McpError + await expect(listResourcesPromise).rejects.toThrow(McpError); }); /*** @@ -1281,10 +1287,28 @@ describe('outputSchema validation', () => { throw new Error('Unknown tool'); }); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + }, + tasks: { + get: true, + list: {}, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1356,10 +1380,28 @@ describe('outputSchema validation', () => { throw new Error('Unknown tool'); }); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + }, + tasks: { + get: true, + list: {}, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1428,10 +1470,28 @@ describe('outputSchema validation', () => { throw new Error('Unknown tool'); }); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + }, + tasks: { + get: true, + list: {}, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1496,10 +1556,28 @@ describe('outputSchema validation', () => { throw new Error('Unknown tool'); }); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + }, + tasks: { + get: true, + list: {}, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1591,10 +1669,28 @@ describe('outputSchema validation', () => { throw new Error('Unknown tool'); }); - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + }, + tasks: { + get: true, + list: {}, + result: true + } + } + } + } + } + ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -1690,6 +1786,1308 @@ describe('outputSchema validation', () => { }); }); +describe('Task-based execution', () => { + describe('Client calling server', () => { + let serverTaskStore: InMemoryTaskStore; + + beforeEach(() => { + serverTaskStore = new InMemoryTaskStore(); + }); + + afterEach(() => { + serverTaskStore?.cleanup(); + }); + + test('should create task on server via tool call', async () => { + const server = new McpServer( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore: serverTaskStore + } + ); + + server.registerToolTask( + 'test-tool', + { + description: 'A test tool', + inputSchema: {} + }, + { + async createTask(_args, extra) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + + const result = { + content: [{ type: 'text', text: 'Tool executed successfully!' }] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Client creates task on server via tool call + await client.callTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { + ttl: 60000 + } + }); + + // Verify task was created successfully by listing tasks + const taskList = await client.listTasks(); + expect(taskList.tasks.length).toBeGreaterThan(0); + const task = taskList.tasks[0]; + expect(task.status).toBe('completed'); + }); + + test('should query task status from server using getTask', async () => { + const server = new McpServer( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore: serverTaskStore + } + ); + + server.registerToolTask( + 'test-tool', + { + description: 'A test tool', + inputSchema: {} + }, + { + async createTask(_args, extra) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + + const result = { + content: [{ type: 'text', text: 'Success!' }] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create a task + await client.callTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { ttl: 60000 } + }); + + // Query task status by listing tasks and getting the first one + const taskList = await client.listTasks(); + expect(taskList.tasks.length).toBeGreaterThan(0); + const task = taskList.tasks[0]; + expect(task).toBeDefined(); + expect(task.taskId).toBeDefined(); + expect(task.status).toBe('completed'); + }); + + test('should query task result from server using getTaskResult', async () => { + const server = new McpServer( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {}, + list: {} + } + } + } + }, + taskStore: serverTaskStore + } + ); + + server.registerToolTask( + 'test-tool', + { + description: 'A test tool', + inputSchema: {} + }, + { + async createTask(_args, extra) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + + const result = { + content: [{ type: 'text', text: 'Result data!' }] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create a task using callToolStream to capture the task ID + let taskId: string | undefined; + const stream = client.callToolStream({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { ttl: 60000 } + }); + + for await (const message of stream) { + if (message.type === 'taskCreated') { + taskId = message.task.taskId; + } + } + + expect(taskId).toBeDefined(); + + // Query task result using the captured task ID + const result = await client.getTaskResult({ taskId: taskId! }, CallToolResultSchema); + expect(result.content).toEqual([{ type: 'text', text: 'Result data!' }]); + }); + + test('should query task list from server using listTasks', async () => { + const server = new McpServer( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore: serverTaskStore + } + ); + + server.registerToolTask( + 'test-tool', + { + description: 'A test tool', + inputSchema: {} + }, + { + async createTask(_args, extra) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + + const result = { + content: [{ type: 'text', text: 'Success!' }] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create multiple tasks + const createdTaskIds: string[] = []; + + for (let i = 0; i < 2; i++) { + await client.callTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { ttl: 60000 } + }); + + // Get the task ID from the task list + const taskList = await client.listTasks(); + const newTask = taskList.tasks.find(t => !createdTaskIds.includes(t.taskId)); + if (newTask) { + createdTaskIds.push(newTask.taskId); + } + } + + // Query task list + const taskList = await client.listTasks(); + expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); + for (const taskId of createdTaskIds) { + expect(taskList.tasks).toContainEqual( + expect.objectContaining({ + taskId, + status: 'completed' + }) + ); + } + }); + }); + + describe('Server calling client', () => { + let clientTaskStore: InMemoryTaskStore; + + beforeEach(() => { + clientTaskStore = new InMemoryTaskStore(); + }); + + afterEach(() => { + clientTaskStore?.cleanup(); + }); + + test('should create task on client via server elicitation', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + const result = { + action: 'accept', + content: { username: 'list-user' } + }; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; + } + + // Return ElicitResult for non-task requests + return result; + }); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Server creates task on client via elicitation + const createTaskResult = await server.request( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: 'Please provide your username', + requestedSchema: { + type: 'object', + properties: { + username: { type: 'string' } + }, + required: ['username'] + } + } + }, + CreateTaskResultSchema, + { task: { ttl: 60000 } } + ); + + // Verify CreateTaskResult structure + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + const taskId = createTaskResult.task.taskId; + + // Verify task was created + const task = await server.getTask({ taskId }); + expect(task.status).toBe('completed'); + }); + + test('should query task status from client using getTask', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + const result = { + action: 'accept', + content: { username: 'list-user' } + }; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; + } + + // Return ElicitResult for non-task requests + return result; + }); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create a task on client and wait for CreateTaskResult + const createTaskResult = await server.request( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: 'Please provide info', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + CreateTaskResultSchema, + { task: { ttl: 60000 } } + ); + + // Verify CreateTaskResult structure + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + const taskId = createTaskResult.task.taskId; + + // Query task status + const task = await server.getTask({ taskId }); + expect(task).toBeDefined(); + expect(task.taskId).toBe(taskId); + expect(task.status).toBe('completed'); + }); + + test('should query task result from client using getTaskResult', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + const result = { + action: 'accept', + content: { username: 'result-user' } + }; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; + } + + // Return ElicitResult for non-task requests + return result; + }); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create a task on client and wait for CreateTaskResult + const createTaskResult = await server.request( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: 'Please provide info', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + CreateTaskResultSchema, + { task: { ttl: 60000 } } + ); + + // Verify CreateTaskResult structure + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + const taskId = createTaskResult.task.taskId; + + // Query task result using getTaskResult + const taskResult = await server.getTaskResult({ taskId }, ElicitResultSchema); + expect(taskResult.action).toBe('accept'); + expect(taskResult.content).toEqual({ username: 'result-user' }); + }); + + test('should query task list from client using listTasks', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + const result = { + action: 'accept', + content: { username: 'list-user' } + }; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; + } + + // Return ElicitResult for non-task requests + return result; + }); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create multiple tasks on client + const createdTaskIds: string[] = []; + for (let i = 0; i < 2; i++) { + const createTaskResult = await server.request( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: 'Please provide info', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + CreateTaskResultSchema, + { task: { ttl: 60000 } } + ); + + // Verify CreateTaskResult structure and capture taskId + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + createdTaskIds.push(createTaskResult.task.taskId); + } + + // Query task list + const taskList = await server.listTasks(); + expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); + for (const taskId of createdTaskIds) { + expect(taskList.tasks).toContainEqual( + expect.objectContaining({ + taskId, + status: 'completed' + }) + ); + } + }); + }); + + test('should list tasks from server with pagination', async () => { + const serverTaskStore = new InMemoryTaskStore(); + + const server = new McpServer( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore: serverTaskStore + } + ); + + server.registerToolTask( + 'test-tool', + { + description: 'A test tool', + inputSchema: { + id: z4.string() + } + }, + { + async createTask({ id }, extra) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + + const result = { + content: [{ type: 'text', text: `Result for ${id || 'unknown'}` }] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create multiple tasks + const createdTaskIds: string[] = []; + + for (let i = 0; i < 3; i++) { + await client.callTool({ name: 'test-tool', arguments: { id: `task-${i + 1}` } }, CallToolResultSchema, { + task: { ttl: 60000 } + }); + + // Get the task ID from the task list + const taskList = await client.listTasks(); + const newTask = taskList.tasks.find(t => !createdTaskIds.includes(t.taskId)); + if (newTask) { + createdTaskIds.push(newTask.taskId); + } + } + + // List all tasks without cursor + const firstPage = await client.listTasks(); + expect(firstPage.tasks.length).toBeGreaterThan(0); + expect(firstPage.tasks.map(t => t.taskId)).toEqual(expect.arrayContaining(createdTaskIds)); + + // If there's a cursor, test pagination + if (firstPage.nextCursor) { + const secondPage = await client.listTasks({ cursor: firstPage.nextCursor }); + expect(secondPage.tasks).toBeDefined(); + } + + serverTaskStore.cleanup(); + }); + + describe('Error scenarios', () => { + let serverTaskStore: InMemoryTaskStore; + let clientTaskStore: InMemoryTaskStore; + + beforeEach(() => { + serverTaskStore = new InMemoryTaskStore(); + clientTaskStore = new InMemoryTaskStore(); + }); + + afterEach(() => { + serverTaskStore?.cleanup(); + clientTaskStore?.cleanup(); + }); + + test('should throw error when querying non-existent task from server', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore: serverTaskStore + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to get a task that doesn't exist + await expect(client.getTask({ taskId: 'non-existent-task' })).rejects.toThrow(); + }); + + test('should throw error when querying result of non-existent task from server', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore: serverTaskStore + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to get result of a task that doesn't exist + await expect(client.getTaskResult({ taskId: 'non-existent-task' }, CallToolResultSchema)).rejects.toThrow(); + }); + + test('should throw error when server queries non-existent task from client', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'test' } + })); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to query a task that doesn't exist on client + await expect(server.getTask({ taskId: 'non-existent-task' })).rejects.toThrow(); + }); + }); +}); + +test('should respect server task capabilities', async () => { + const serverTaskStore = new InMemoryTaskStore(); + const server = new McpServer( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore: serverTaskStore + } + ); + + server.registerToolTask( + 'test-tool', + { + description: 'A test tool', + inputSchema: {} + }, + { + async createTask(_args, extra) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + + const result = { + content: [{ type: 'text', text: 'Success!' }] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + enforceStrictCapabilities: true + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Server supports task creation for tools/call + expect(client.getServerCapabilities()).toEqual({ + tools: { + listChanged: true + }, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }); + + // These should work because server supports tasks + await expect( + client.callTool({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { ttl: 60000 } + }) + ).resolves.not.toThrow(); + await expect(client.listTasks()).resolves.not.toThrow(); + + // tools/list doesn't support task creation, but it shouldn't throw - it should just ignore the task metadata + await expect( + client.request( + { + method: 'tools/list', + params: {} + }, + ListToolsResultSchema + ) + ).resolves.not.toThrow(); + + serverTaskStore.cleanup(); +}); + +/** + * Test: requestStream() method + */ +test('should expose requestStream() method for streaming responses', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + server.setRequestHandler(CallToolRequestSchema, async () => { + return { + content: [{ type: 'text', text: 'Tool result' }] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: {} + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // First verify that regular request() works + const regularResult = await client.callTool({ name: 'test-tool', arguments: {} }); + expect(regularResult.content).toEqual([{ type: 'text', text: 'Tool result' }]); + + // Test requestStream with non-task request (should yield only result) + const stream = client.requestStream( + { + method: 'tools/call', + params: { name: 'test-tool', arguments: {} } + }, + CallToolResultSchema + ); + + const messages = []; + for await (const message of stream) { + messages.push(message); + } + + // Should have received only a result message (no task messages) + expect(messages.length).toBe(1); + expect(messages[0].type).toBe('result'); + if (messages[0].type === 'result') { + expect(messages[0].result.content).toEqual([{ type: 'text', text: 'Tool result' }]); + } + + await client.close(); + await server.close(); +}); + +/** + * Test: callToolStream() method + */ +test('should expose callToolStream() method for streaming tool calls', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + server.setRequestHandler(CallToolRequestSchema, async () => { + return { + content: [{ type: 'text', text: 'Tool result' }] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: {} + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Test callToolStream + const stream = client.callToolStream({ name: 'test-tool', arguments: {} }); + + const messages = []; + for await (const message of stream) { + messages.push(message); + } + + // Should have received messages ending with result + expect(messages.length).toBe(1); + expect(messages[0].type).toBe('result'); + if (messages[0].type === 'result') { + expect(messages[0].result.content).toEqual([{ type: 'text', text: 'Tool result' }]); + } + + await client.close(); + await server.close(); +}); + +/** + * Test: callToolStream() with output schema validation + */ +test('should validate structured output in callToolStream()', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + server.setRequestHandler(ListToolsRequestSchema, async () => { + return { + tools: [ + { + name: 'structured-tool', + description: 'A tool with output schema', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + value: { type: 'number' } + }, + required: ['value'] + } + } + ] + }; + }); + + server.setRequestHandler(CallToolRequestSchema, async () => { + return { + content: [{ type: 'text', text: 'Result' }], + structuredContent: { value: 42 } + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: {} + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // List tools to cache the output schema + await client.listTools(); + + // Test callToolStream with valid structured output + const stream = client.callToolStream({ name: 'structured-tool', arguments: {} }); + + const messages = []; + for await (const message of stream) { + messages.push(message); + } + + // Should have received result with validated structured content + expect(messages.length).toBe(1); + expect(messages[0].type).toBe('result'); + if (messages[0].type === 'result') { + expect(messages[0].result.structuredContent).toEqual({ value: 42 }); + } + + await client.close(); + await server.close(); +}); + describe('getSupportedElicitationModes', () => { test('should support nothing when capabilities are undefined', () => { const result = getSupportedElicitationModes(undefined); diff --git a/src/client/index.ts b/src/client/index.ts index 823aa790e..367f6c998 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,5 +1,7 @@ import { mergeCapabilities, Protocol, type ProtocolOptions, type RequestOptions } from '../shared/protocol.js'; import type { Transport } from '../shared/transport.js'; +import { ResponseMessage, takeResult } from '../shared/responseMessage.js'; + import { type CallToolRequest, CallToolResultSchema, @@ -38,7 +40,10 @@ import { type Tool, type UnsubscribeRequest, ElicitResultSchema, - ElicitRequestSchema + ElicitRequestSchema, + CreateTaskResultSchema, + CreateMessageRequestSchema, + CreateMessageResultSchema } from '../types.js'; import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js'; import type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator } from '../validation/types.js'; @@ -195,6 +200,7 @@ export class Client< private _instructions?: string; private _jsonSchemaValidator: jsonSchemaValidator; private _cachedToolOutputValidators: Map> = new Map(); + private _cachedKnownTaskTools: Set = new Set(); /** * Initializes this client with the given name and version information. @@ -280,6 +286,20 @@ export class Client< const result = await Promise.resolve(handler(request, extra)); + // When task creation is requested, validate and return CreateTaskResult + if (params.task) { + const taskValidationResult = safeParse(CreateTaskResultSchema, result); + if (!taskValidationResult.success) { + const errorMessage = + taskValidationResult.error instanceof Error + ? taskValidationResult.error.message + : String(taskValidationResult.error); + throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); + } + return taskValidationResult.data; + } + + // For non-task requests, validate against ElicitResultSchema const validationResult = safeParse(ElicitResultSchema, result); if (!validationResult.success) { // Type guard: if success is false, error is guaranteed to exist @@ -308,7 +328,51 @@ export class Client< return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler); } - // Non-elicitation handlers use default behavior + if (method === 'sampling/createMessage') { + const wrappedHandler = async ( + request: SchemaOutput, + extra: RequestHandlerExtra + ): Promise => { + const validatedRequest = safeParse(CreateMessageRequestSchema, request); + if (!validatedRequest.success) { + const errorMessage = + validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); + throw new McpError(ErrorCode.InvalidParams, `Invalid sampling request: ${errorMessage}`); + } + + const { params } = validatedRequest.data; + + const result = await Promise.resolve(handler(request, extra)); + + // When task creation is requested, validate and return CreateTaskResult + if (params.task) { + const taskValidationResult = safeParse(CreateTaskResultSchema, result); + if (!taskValidationResult.success) { + const errorMessage = + taskValidationResult.error instanceof Error + ? taskValidationResult.error.message + : String(taskValidationResult.error); + throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); + } + return taskValidationResult.data; + } + + // For non-task requests, validate against CreateMessageResultSchema + const validationResult = safeParse(CreateMessageResultSchema, result); + if (!validationResult.success) { + const errorMessage = + validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); + throw new McpError(ErrorCode.InvalidParams, `Invalid sampling result: ${errorMessage}`); + } + + return validationResult.data; + }; + + // Install the wrapped handler + return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler); + } + + // Other handlers use default behavior return super.setRequestHandler(requestSchema, handler); } @@ -463,6 +527,12 @@ export class Client< } protected assertRequestHandlerCapability(method: string): void { + // Task handlers are registered in Protocol constructor before _capabilities is initialized + // Skip capability check for task methods during initialization + if (!this._capabilities) { + return; + } + switch (method) { case 'sampling/createMessage': if (!this._capabilities.sampling) { @@ -482,12 +552,73 @@ export class Client< } break; + case 'tasks/get': + case 'tasks/list': + case 'tasks/result': + case 'tasks/cancel': + if (!this._capabilities.tasks) { + throw new Error(`Client does not support tasks capability (required for ${method})`); + } + break; + case 'ping': // No specific capability required for ping break; } } + protected assertTaskCapability(method: string): void { + if (!this._serverCapabilities?.tasks?.requests) { + throw new Error(`Server does not support task creation (required for ${method})`); + } + + const requests = this._serverCapabilities.tasks.requests; + + switch (method) { + case 'tools/call': + if (!requests.tools?.call) { + throw new Error(`Server does not support task creation for tools/call (required for ${method})`); + } + break; + + default: + // Method doesn't support tasks, which is fine - no error + break; + } + } + + protected assertTaskHandlerCapability(method: string): void { + // Task handlers are registered in Protocol constructor before _capabilities is initialized + // Skip capability check for task methods during initialization + if (!this._capabilities) { + return; + } + + if (!this._capabilities.tasks?.requests) { + throw new Error(`Client does not support task creation (required for ${method})`); + } + + const requests = this._capabilities.tasks.requests; + + switch (method) { + case 'sampling/createMessage': + if (!requests.sampling?.createMessage) { + throw new Error(`Client does not support task creation for sampling/createMessage (required for ${method})`); + } + break; + + case 'elicitation/create': + if (!requests.elicitation?.create) { + throw new Error(`Client does not support task creation for elicitation/create (required for ${method})`); + } + break; + + default: + // Method doesn't support tasks, which is fine - no error + break; + } + } + async ping(options?: RequestOptions) { return this.request({ method: 'ping' }, EmptyResultSchema, options); } @@ -528,57 +659,145 @@ export class Client< return this.request({ method: 'resources/unsubscribe', params }, EmptyResultSchema, options); } - async callTool( + /** + * Calls a tool and waits for the result. Automatically validates structured output if the tool has an outputSchema. + * + * For task-based execution with streaming behavior, use callToolStream() instead. + */ + async callTool( params: CallToolRequest['params'], - resultSchema: typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, + resultSchema: T = CallToolResultSchema as T, options?: RequestOptions - ) { - const result = await this.request({ method: 'tools/call', params }, resultSchema, options); + ): Promise> { + return await takeResult(this.callToolStream(params, resultSchema, options)); + } - // Check if the tool has an outputSchema + /** + * Calls a tool and returns an AsyncGenerator that yields response messages. + * The generator is guaranteed to end with either a 'result' or 'error' message. + * + * This method provides streaming access to tool execution, allowing you to + * observe intermediate task status updates for long-running tool calls. + * Automatically validates structured output if the tool has an outputSchema. + * + * For simple tool calls without streaming, use callTool() instead. + * + * @example + * ```typescript + * const stream = client.callToolStream({ name: 'myTool', arguments: {} }); + * for await (const message of stream) { + * switch (message.type) { + * case 'taskCreated': + * console.log('Tool execution started:', message.task.taskId); + * break; + * case 'taskStatus': + * console.log('Tool status:', message.task.status); + * break; + * case 'result': + * console.log('Tool result:', message.result); + * // Structured output is automatically validated + * break; + * case 'error': + * console.error('Tool error:', message.error); + * break; + * } + * } + * ``` + * + * @param params - Tool call parameters (name and arguments) + * @param resultSchema - Zod schema for validating the result (defaults to CallToolResultSchema) + * @param options - Optional request options (timeout, signal, task creation params, etc.) + * @returns AsyncGenerator that yields ResponseMessage objects + */ + async *callToolStream( + params: CallToolRequest['params'], + resultSchema: T = CallToolResultSchema as T, + options?: RequestOptions + ): AsyncGenerator>, void, void> { + // Add task creation parameters if server supports it and not explicitly provided + const optionsWithTask = { + ...options, + // We check if the tool is known to be a task during auto-configuration, but assume + // the caller knows what they're doing if they pass this explicitly + task: options?.task ?? (this.isToolTask(params.name) ? {} : undefined) + }; + + const stream = this.requestStream({ method: 'tools/call', params }, resultSchema, optionsWithTask); + + // Get the validator for this tool (if it has an output schema) const validator = this.getToolOutputValidator(params.name); - if (validator) { - // If tool has outputSchema, it MUST return structuredContent (unless it's an error) - if (!result.structuredContent && !result.isError) { - throw new McpError( - ErrorCode.InvalidRequest, - `Tool ${params.name} has an output schema but did not return structured content` - ); - } - // Only validate structured content if present (not when there's an error) - if (result.structuredContent) { - try { - // Validate the structured content against the schema - const validationResult = validator(result.structuredContent); - - if (!validationResult.valid) { - throw new McpError( - ErrorCode.InvalidParams, - `Structured content does not match the tool's output schema: ${validationResult.errorMessage}` - ); - } - } catch (error) { - if (error instanceof McpError) { - throw error; + // Iterate through the stream and validate the final result if needed + for await (const message of stream) { + // If this is a result message and the tool has an output schema, validate it + if (message.type === 'result' && validator) { + const result = message.result; + + // If tool has outputSchema, it MUST return structuredContent (unless it's an error) + if (!result.structuredContent && !result.isError) { + yield { + type: 'error', + error: new McpError( + ErrorCode.InvalidRequest, + `Tool ${params.name} has an output schema but did not return structured content` + ) + }; + return; + } + + // Only validate structured content if present (not when there's an error) + if (result.structuredContent) { + try { + // Validate the structured content against the schema + const validationResult = validator(result.structuredContent); + + if (!validationResult.valid) { + yield { + type: 'error', + error: new McpError( + ErrorCode.InvalidParams, + `Structured content does not match the tool's output schema: ${validationResult.errorMessage}` + ) + }; + return; + } + } catch (error) { + if (error instanceof McpError) { + yield { type: 'error', error }; + return; + } + yield { + type: 'error', + error: new McpError( + ErrorCode.InvalidParams, + `Failed to validate structured content: ${error instanceof Error ? error.message : String(error)}` + ) + }; + return; } - throw new McpError( - ErrorCode.InvalidParams, - `Failed to validate structured content: ${error instanceof Error ? error.message : String(error)}` - ); } } + + // Yield the message (either validated result or any other message type) + yield message; } + } - return result; + private isToolTask(toolName: string): boolean { + if (!this._serverCapabilities?.tasks?.requests?.tools?.call) { + return false; + } + + return this._cachedKnownTaskTools.has(toolName); } /** * Cache validators for tool output schemas. * Called after listTools() to pre-compile validators for better performance. */ - private cacheToolOutputSchemas(tools: Tool[]): void { + private cacheToolMetadata(tools: Tool[]): void { this._cachedToolOutputValidators.clear(); + this._cachedKnownTaskTools.clear(); for (const tool of tools) { // If the tool has an outputSchema, create and cache the validator @@ -586,6 +805,12 @@ export class Client< const toolValidator = this._jsonSchemaValidator.getValidator(tool.outputSchema as JsonSchemaType); this._cachedToolOutputValidators.set(tool.name, toolValidator); } + + // If the tool supports task-based execution, cache that information + const taskSupport = tool.execution?.taskSupport; + if (taskSupport === 'required' || taskSupport === 'optional') { + this._cachedKnownTaskTools.add(tool.name); + } } } @@ -600,7 +825,7 @@ export class Client< const result = await this.request({ method: 'tools/list', params }, ListToolsResultSchema, options); // Cache the tools and their output schemas for future validation - this.cacheToolOutputSchemas(result.tools); + this.cacheToolMetadata(result.tools); return result; } @@ -608,4 +833,45 @@ export class Client< async sendRootsListChanged() { return this.notification({ method: 'notifications/roots/list_changed' }); } + + /** + * Sends a request and returns an AsyncGenerator that yields response messages. + * The generator is guaranteed to end with either a 'result' or 'error' message. + * + * This method provides streaming access to request processing, allowing you to + * observe intermediate task status updates for task-augmented requests. + * + * @example + * ```typescript + * const stream = client.requestStream(request, resultSchema, options); + * for await (const message of stream) { + * switch (message.type) { + * case 'taskCreated': + * console.log('Task created:', message.task.taskId); + * break; + * case 'taskStatus': + * console.log('Task status:', message.task.status); + * break; + * case 'result': + * console.log('Final result:', message.result); + * break; + * case 'error': + * console.error('Error:', message.error); + * break; + * } + * } + * ``` + * + * @param request - The request to send + * @param resultSchema - Zod schema for validating the result + * @param options - Optional request options (timeout, signal, task creation params, etc.) + * @returns AsyncGenerator that yields ResponseMessage objects + */ + requestStream( + request: ClientRequest | RequestT, + resultSchema: T, + options?: RequestOptions + ): AsyncGenerator>, void, void> { + return super.requestStream(request, resultSchema, options); + } } diff --git a/src/examples/client/simpleOAuthClient.ts b/src/examples/client/simpleOAuthClient.ts index 21dcae012..4dc724d25 100644 --- a/src/examples/client/simpleOAuthClient.ts +++ b/src/examples/client/simpleOAuthClient.ts @@ -200,6 +200,7 @@ class InteractiveOAuthClient { console.log('Commands:'); console.log(' list - List available tools'); console.log(' call [args] - Call a tool'); + console.log(' stream [args] - Call a tool with streaming (shows task status)'); console.log(' quit - Exit the client'); console.log(); @@ -219,8 +220,10 @@ class InteractiveOAuthClient { await this.listTools(); } else if (command.startsWith('call ')) { await this.handleCallTool(command); + } else if (command.startsWith('stream ')) { + await this.handleStreamTool(command); } else { - console.log("❌ Unknown command. Try 'list', 'call ', or 'quit'"); + console.log("❌ Unknown command. Try 'list', 'call ', 'stream ', or 'quit'"); } } catch (error) { if (error instanceof Error && error.message === 'SIGINT') { @@ -321,6 +324,89 @@ class InteractiveOAuthClient { } } + private async handleStreamTool(command: string): Promise { + const parts = command.split(/\s+/); + const toolName = parts[1]; + + if (!toolName) { + console.log('❌ Please specify a tool name'); + return; + } + + // Parse arguments (simple JSON-like format) + let toolArgs: Record = {}; + if (parts.length > 2) { + const argsString = parts.slice(2).join(' '); + try { + toolArgs = JSON.parse(argsString); + } catch { + console.log('❌ Invalid arguments format (expected JSON)'); + return; + } + } + + await this.streamTool(toolName, toolArgs); + } + + private async streamTool(toolName: string, toolArgs: Record): Promise { + if (!this.client) { + console.log('❌ Not connected to server'); + return; + } + + try { + console.log(`\n🔧 Streaming tool '${toolName}'...`); + + const stream = this.client.callToolStream( + { + name: toolName, + arguments: toolArgs + }, + CallToolResultSchema, + { + task: { + taskId: `task-${Date.now()}`, + ttl: 60000 + } + } + ); + + // Iterate through all messages yielded by the generator + for await (const message of stream) { + switch (message.type) { + case 'taskCreated': + console.log(`✓ Task created: ${message.task.taskId}`); + break; + + case 'taskStatus': + console.log(`⟳ Status: ${message.task.status}`); + if (message.task.statusMessage) { + console.log(` ${message.task.statusMessage}`); + } + break; + + case 'result': + console.log('✓ Completed!'); + message.result.content.forEach(content => { + if (content.type === 'text') { + console.log(content.text); + } else { + console.log(content); + } + }); + break; + + case 'error': + console.log('✗ Error:'); + console.log(` ${message.error.message}`); + break; + } + } + } catch (error) { + console.error(`❌ Failed to stream tool '${toolName}':`, error); + } + } + close(): void { this.rl.close(); if (this.client) { diff --git a/src/examples/client/simpleStreamableHttp.ts b/src/examples/client/simpleStreamableHttp.ts index 6627e0b83..4dbd109d6 100644 --- a/src/examples/client/simpleStreamableHttp.ts +++ b/src/examples/client/simpleStreamableHttp.ts @@ -18,6 +18,7 @@ import { ResourceLink, ReadResourceRequest, ReadResourceResultSchema, + RELATED_TASK_META_KEY, ErrorCode, McpError } from '../../types.js'; @@ -60,6 +61,7 @@ function printHelp(): void { console.log(' reconnect - Reconnect to the server'); console.log(' list-tools - List available tools'); console.log(' call-tool [args] - Call a tool with optional JSON arguments'); + console.log(' call-tool-task [args] - Call a tool with task-based execution (example: call-tool-task delay {"duration":3000})'); console.log(' greet [name] - Call the greet tool'); console.log(' multi-greet [name] - Call the multi-greet tool with notifications'); console.log(' collect-info [type] - Test form elicitation with collect-user-info tool (contact/preferences/feedback)'); @@ -143,6 +145,23 @@ function commandLoop(): void { break; } + case 'call-tool-task': + if (args.length < 2) { + console.log('Usage: call-tool-task [args]'); + } else { + const toolName = args[1]; + let toolArgs = {}; + if (args.length > 2) { + try { + toolArgs = JSON.parse(args.slice(2).join(' ')); + } catch { + console.log('Invalid JSON arguments. Using empty args.'); + } + } + await callToolTask(toolName, toolArgs); + } + break; + case 'list-prompts': await listPrompts(); break; @@ -238,6 +257,7 @@ async function connect(url?: string): Promise { } console.log('\n🔔 Elicitation (form) Request Received:'); console.log(`Message: ${request.params.message}`); + console.log(`Related Task: ${request.params._meta?.[RELATED_TASK_META_KEY]?.taskId}`); console.log('Requested Schema:'); console.log(JSON.stringify(request.params.requestedSchema, null, 2)); @@ -784,6 +804,65 @@ async function readResource(uri: string): Promise { } } +async function callToolTask(name: string, args: Record): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + console.log(`Calling tool '${name}' with task-based execution...`); + console.log('Arguments:', args); + + // Use task-based execution - call now, fetch later + console.log('This will return immediately while processing continues in the background...'); + + try { + // Call the tool with task metadata using streaming API + const stream = client.callToolStream( + { + name, + arguments: args + }, + CallToolResultSchema, + { + task: { + ttl: 60000 // Keep results for 60 seconds + } + } + ); + + console.log('Waiting for task completion...'); + + let lastStatus = ''; + for await (const message of stream) { + switch (message.type) { + case 'taskCreated': + console.log('Task created successfully with ID:', message.task.taskId); + break; + case 'taskStatus': + if (lastStatus !== message.task.status) { + console.log(` ${message.task.status}${message.task.statusMessage ? ` - ${message.task.statusMessage}` : ''}`); + } + lastStatus = message.task.status; + break; + case 'result': + console.log('Task completed!'); + console.log('Tool result:'); + message.result.content.forEach(item => { + if (item.type === 'text') { + console.log(` ${item.text}`); + } + }); + break; + case 'error': + throw message.error; + } + } + } catch (error) { + console.log(`Error with task-based execution: ${error}`); + } +} + async function cleanup(): Promise { if (client && transport) { try { diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 33568bc82..4a9c00e95 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -7,6 +7,7 @@ import { getOAuthProtectedResourceMetadataUrl, mcpAuthMetadataRouter } from '../ import { requireBearerAuth } from '../../server/auth/middleware/bearerAuth.js'; import { CallToolResult, + ElicitResultSchema, GetPromptResult, isInitializeRequest, PrimitiveSchemaDefinition, @@ -14,6 +15,7 @@ import { ResourceLink } from '../../types.js'; import { InMemoryEventStore } from '../shared/inMemoryEventStore.js'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../shared/inMemoryTaskStore.js'; import { setupAuthServer } from './demoInMemoryOAuthProvider.js'; import { OAuthMetadata } from '../../shared/auth.js'; import { checkResourceAllowed } from '../../shared/auth-utils.js'; @@ -24,6 +26,9 @@ import cors from 'cors'; const useOAuth = process.argv.includes('--oauth'); const strictOAuth = process.argv.includes('--oauth-strict'); +// Create shared task store for demonstration +const taskStore = new InMemoryTaskStore(); + // Create an MCP server with implementation details const getServer = () => { const server = new McpServer( @@ -33,7 +38,11 @@ const getServer = () => { icons: [{ src: './mcp.svg', sizes: ['512x512'], mimeType: 'image/svg+xml' }], websiteUrl: 'https://github.com/modelcontextprotocol/typescript-sdk' }, - { capabilities: { logging: {} } } + { + capabilities: { logging: {}, tasks: { requests: { tools: { call: {} } } } }, + taskStore, // Enable task support + taskMessageQueue: new InMemoryTaskMessageQueue() + } ); // Register a simple tool that returns a greeting @@ -119,7 +128,7 @@ const getServer = () => { { infoType: z.enum(['contact', 'preferences', 'feedback']).describe('Type of information to collect') }, - async ({ infoType }): Promise => { + async ({ infoType }, extra): Promise => { let message: string; let requestedSchema: { type: 'object'; @@ -214,12 +223,18 @@ const getServer = () => { } try { - // Use the underlying server instance to elicit input from the client - const result = await server.server.elicitInput({ - mode: 'form', - message, - requestedSchema - }); + // Use sendRequest through the extra parameter to elicit input + const result = await extra.sendRequest( + { + method: 'elicitation/create', + params: { + mode: 'form', + message, + requestedSchema + } + }, + ElicitResultSchema + ); if (result.action === 'accept') { return { @@ -440,6 +455,51 @@ const getServer = () => { } ); + // Register a long-running tool that demonstrates task execution + server.registerToolTask( + 'delay', + { + title: 'Delay', + description: 'A simple tool that delays for a specified duration, useful for testing task execution', + inputSchema: { + duration: z.number().describe('Duration in milliseconds').default(5000) + } + }, + { + async createTask({ duration }, { taskStore, taskRequestedTtl }) { + // Create the task + const task = await taskStore.createTask({ + ttl: taskRequestedTtl + }); + + // Simulate out-of-band work + (async () => { + await new Promise(resolve => setTimeout(resolve, duration)); + await taskStore.storeTaskResult(task.taskId, 'completed', { + content: [ + { + type: 'text', + text: `Completed ${duration}ms delay` + } + ] + }); + })(); + + // Return CreateTaskResult with the created task + return { + task + }; + }, + async getTask(_args, { taskId, taskStore }) { + return await taskStore.getTask(taskId); + }, + async getTaskResult(_args, { taskId, taskStore }) { + const result = await taskStore.getTaskResult(taskId); + return result as CallToolResult; + } + } + ); + return server; }; diff --git a/src/examples/shared/inMemoryTaskStore.test.ts b/src/examples/shared/inMemoryTaskStore.test.ts new file mode 100644 index 000000000..658e4deb1 --- /dev/null +++ b/src/examples/shared/inMemoryTaskStore.test.ts @@ -0,0 +1,936 @@ +import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from './inMemoryTaskStore.js'; +import { TaskCreationParams, Request } from '../../types.js'; +import { QueuedMessage } from '../../shared/task.js'; + +describe('InMemoryTaskStore', () => { + let store: InMemoryTaskStore; + + beforeEach(() => { + store = new InMemoryTaskStore(); + }); + + afterEach(() => { + store.cleanup(); + }); + + describe('createTask', () => { + it('should create a new task with working status', async () => { + const taskParams: TaskCreationParams = { + ttl: 60000 + }; + const request: Request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const task = await store.createTask(taskParams, 123, request); + + expect(task).toBeDefined(); + expect(task.taskId).toBeDefined(); + expect(typeof task.taskId).toBe('string'); + expect(task.taskId.length).toBeGreaterThan(0); + expect(task.status).toBe('working'); + expect(task.ttl).toBe(60000); + expect(task.pollInterval).toBeDefined(); + expect(task.createdAt).toBeDefined(); + expect(new Date(task.createdAt).getTime()).toBeGreaterThan(0); + }); + + it('should create task without ttl', async () => { + const taskParams: TaskCreationParams = {}; + const request: Request = { + method: 'tools/call', + params: {} + }; + + const task = await store.createTask(taskParams, 456, request); + + expect(task).toBeDefined(); + expect(task.ttl).toBeNull(); + }); + + it('should generate unique taskIds', async () => { + const taskParams: TaskCreationParams = {}; + const request: Request = { + method: 'tools/call', + params: {} + }; + + const task1 = await store.createTask(taskParams, 789, request); + const task2 = await store.createTask(taskParams, 790, request); + + expect(task1.taskId).not.toBe(task2.taskId); + }); + }); + + describe('getTask', () => { + it('should return null for non-existent task', async () => { + const task = await store.getTask('non-existent'); + expect(task).toBeNull(); + }); + + it('should return task state', async () => { + const taskParams: TaskCreationParams = {}; + const request: Request = { + method: 'tools/call', + params: {} + }; + + const createdTask = await store.createTask(taskParams, 111, request); + await store.updateTaskStatus(createdTask.taskId, 'working'); + + const task = await store.getTask(createdTask.taskId); + expect(task).toBeDefined(); + expect(task?.status).toBe('working'); + }); + }); + + describe('updateTaskStatus', () => { + let taskId: string; + + beforeEach(async () => { + const taskParams: TaskCreationParams = {}; + const createdTask = await store.createTask(taskParams, 222, { + method: 'tools/call', + params: {} + }); + taskId = createdTask.taskId; + }); + + it('should keep task status as working', async () => { + const task = await store.getTask(taskId); + expect(task?.status).toBe('working'); + }); + + it('should update task status to input_required', async () => { + await store.updateTaskStatus(taskId, 'input_required'); + + const task = await store.getTask(taskId); + expect(task?.status).toBe('input_required'); + }); + + it('should update task status to completed', async () => { + await store.updateTaskStatus(taskId, 'completed'); + + const task = await store.getTask(taskId); + expect(task?.status).toBe('completed'); + }); + + it('should update task status to failed with error', async () => { + await store.updateTaskStatus(taskId, 'failed', 'Something went wrong'); + + const task = await store.getTask(taskId); + expect(task?.status).toBe('failed'); + expect(task?.statusMessage).toBe('Something went wrong'); + }); + + it('should update task status to cancelled', async () => { + await store.updateTaskStatus(taskId, 'cancelled'); + + const task = await store.getTask(taskId); + expect(task?.status).toBe('cancelled'); + }); + + it('should throw if task not found', async () => { + await expect(store.updateTaskStatus('non-existent', 'working')).rejects.toThrow('Task with ID non-existent not found'); + }); + + describe('status lifecycle validation', () => { + it('should allow transition from working to input_required', async () => { + await store.updateTaskStatus(taskId, 'input_required'); + const task = await store.getTask(taskId); + expect(task?.status).toBe('input_required'); + }); + + it('should allow transition from working to completed', async () => { + await store.updateTaskStatus(taskId, 'completed'); + const task = await store.getTask(taskId); + expect(task?.status).toBe('completed'); + }); + + it('should allow transition from working to failed', async () => { + await store.updateTaskStatus(taskId, 'failed'); + const task = await store.getTask(taskId); + expect(task?.status).toBe('failed'); + }); + + it('should allow transition from working to cancelled', async () => { + await store.updateTaskStatus(taskId, 'cancelled'); + const task = await store.getTask(taskId); + expect(task?.status).toBe('cancelled'); + }); + + it('should allow transition from input_required to working', async () => { + await store.updateTaskStatus(taskId, 'input_required'); + await store.updateTaskStatus(taskId, 'working'); + const task = await store.getTask(taskId); + expect(task?.status).toBe('working'); + }); + + it('should allow transition from input_required to completed', async () => { + await store.updateTaskStatus(taskId, 'input_required'); + await store.updateTaskStatus(taskId, 'completed'); + const task = await store.getTask(taskId); + expect(task?.status).toBe('completed'); + }); + + it('should allow transition from input_required to failed', async () => { + await store.updateTaskStatus(taskId, 'input_required'); + await store.updateTaskStatus(taskId, 'failed'); + const task = await store.getTask(taskId); + expect(task?.status).toBe('failed'); + }); + + it('should allow transition from input_required to cancelled', async () => { + await store.updateTaskStatus(taskId, 'input_required'); + await store.updateTaskStatus(taskId, 'cancelled'); + const task = await store.getTask(taskId); + expect(task?.status).toBe('cancelled'); + }); + + it('should reject transition from completed to any other status', async () => { + await store.updateTaskStatus(taskId, 'completed'); + await expect(store.updateTaskStatus(taskId, 'working')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'input_required')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'failed')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'cancelled')).rejects.toThrow('Cannot update task'); + }); + + it('should reject transition from failed to any other status', async () => { + await store.updateTaskStatus(taskId, 'failed'); + await expect(store.updateTaskStatus(taskId, 'working')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'input_required')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'completed')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'cancelled')).rejects.toThrow('Cannot update task'); + }); + + it('should reject transition from cancelled to any other status', async () => { + await store.updateTaskStatus(taskId, 'cancelled'); + await expect(store.updateTaskStatus(taskId, 'working')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'input_required')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'completed')).rejects.toThrow('Cannot update task'); + await expect(store.updateTaskStatus(taskId, 'failed')).rejects.toThrow('Cannot update task'); + }); + }); + }); + + describe('storeTaskResult', () => { + let taskId: string; + + beforeEach(async () => { + const taskParams: TaskCreationParams = { + ttl: 60000 + }; + const createdTask = await store.createTask(taskParams, 333, { + method: 'tools/call', + params: {} + }); + taskId = createdTask.taskId; + }); + + it('should store task result and set status to completed', async () => { + const result = { + content: [{ type: 'text' as const, text: 'Success!' }] + }; + + await store.storeTaskResult(taskId, 'completed', result); + + const task = await store.getTask(taskId); + expect(task?.status).toBe('completed'); + + const storedResult = await store.getTaskResult(taskId); + expect(storedResult).toEqual(result); + }); + + it('should throw if task not found', async () => { + await expect(store.storeTaskResult('non-existent', 'completed', {})).rejects.toThrow('Task with ID non-existent not found'); + }); + + it('should reject storing result for task already in completed status', async () => { + // First complete the task + const firstResult = { + content: [{ type: 'text' as const, text: 'First result' }] + }; + await store.storeTaskResult(taskId, 'completed', firstResult); + + // Try to store result again (should fail) + const secondResult = { + content: [{ type: 'text' as const, text: 'Second result' }] + }; + + await expect(store.storeTaskResult(taskId, 'completed', secondResult)).rejects.toThrow('Cannot store result for task'); + }); + + it('should store result with failed status', async () => { + const result = { + content: [{ type: 'text' as const, text: 'Error details' }], + isError: true + }; + + await store.storeTaskResult(taskId, 'failed', result); + + const task = await store.getTask(taskId); + expect(task?.status).toBe('failed'); + + const storedResult = await store.getTaskResult(taskId); + expect(storedResult).toEqual(result); + }); + + it('should reject storing result for task already in failed status', async () => { + // First fail the task + const firstResult = { + content: [{ type: 'text' as const, text: 'First error' }], + isError: true + }; + await store.storeTaskResult(taskId, 'failed', firstResult); + + // Try to store result again (should fail) + const secondResult = { + content: [{ type: 'text' as const, text: 'Second error' }], + isError: true + }; + + await expect(store.storeTaskResult(taskId, 'failed', secondResult)).rejects.toThrow('Cannot store result for task'); + }); + + it('should reject storing result for cancelled task', async () => { + // Mark task as cancelled + await store.updateTaskStatus(taskId, 'cancelled'); + + // Try to store result (should fail) + const result = { + content: [{ type: 'text' as const, text: 'Cancellation result' }] + }; + + await expect(store.storeTaskResult(taskId, 'completed', result)).rejects.toThrow('Cannot store result for task'); + }); + + it('should allow storing result from input_required status', async () => { + await store.updateTaskStatus(taskId, 'input_required'); + + const result = { + content: [{ type: 'text' as const, text: 'Success!' }] + }; + + await store.storeTaskResult(taskId, 'completed', result); + + const task = await store.getTask(taskId); + expect(task?.status).toBe('completed'); + }); + }); + + describe('getTaskResult', () => { + it('should throw if task not found', async () => { + await expect(store.getTaskResult('non-existent')).rejects.toThrow('Task with ID non-existent not found'); + }); + + it('should throw if task has no result stored', async () => { + const taskParams: TaskCreationParams = {}; + const createdTask = await store.createTask(taskParams, 444, { + method: 'tools/call', + params: {} + }); + + await expect(store.getTaskResult(createdTask.taskId)).rejects.toThrow(`Task ${createdTask.taskId} has no result stored`); + }); + + it('should return stored result', async () => { + const taskParams: TaskCreationParams = {}; + const createdTask = await store.createTask(taskParams, 555, { + method: 'tools/call', + params: {} + }); + + const result = { + content: [{ type: 'text' as const, text: 'Result data' }] + }; + await store.storeTaskResult(createdTask.taskId, 'completed', result); + + const retrieved = await store.getTaskResult(createdTask.taskId); + expect(retrieved).toEqual(result); + }); + }); + + describe('ttl cleanup', () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it('should cleanup task after ttl duration', async () => { + const taskParams: TaskCreationParams = { + ttl: 1000 + }; + const createdTask = await store.createTask(taskParams, 666, { + method: 'tools/call', + params: {} + }); + + // Task should exist initially + let task = await store.getTask(createdTask.taskId); + expect(task).toBeDefined(); + + // Fast-forward past ttl + vi.advanceTimersByTime(1001); + + // Task should be cleaned up + task = await store.getTask(createdTask.taskId); + expect(task).toBeNull(); + }); + + it('should reset cleanup timer when result is stored', async () => { + const taskParams: TaskCreationParams = { + ttl: 1000 + }; + const createdTask = await store.createTask(taskParams, 777, { + method: 'tools/call', + params: {} + }); + + // Fast-forward 500ms + vi.advanceTimersByTime(500); + + // Store result (should reset timer) + await store.storeTaskResult(createdTask.taskId, 'completed', { + content: [{ type: 'text' as const, text: 'Done' }] + }); + + // Fast-forward another 500ms (total 1000ms since creation, but timer was reset) + vi.advanceTimersByTime(500); + + // Task should still exist + const task = await store.getTask(createdTask.taskId); + expect(task).toBeDefined(); + + // Fast-forward remaining time + vi.advanceTimersByTime(501); + + // Now task should be cleaned up + const cleanedTask = await store.getTask(createdTask.taskId); + expect(cleanedTask).toBeNull(); + }); + + it('should not cleanup tasks without ttl', async () => { + const taskParams: TaskCreationParams = {}; + const createdTask = await store.createTask(taskParams, 888, { + method: 'tools/call', + params: {} + }); + + // Fast-forward a long time + vi.advanceTimersByTime(100000); + + // Task should still exist + const task = await store.getTask(createdTask.taskId); + expect(task).toBeDefined(); + }); + + it('should start cleanup timer when task reaches terminal state', async () => { + const taskParams: TaskCreationParams = { + ttl: 1000 + }; + const createdTask = await store.createTask(taskParams, 999, { + method: 'tools/call', + params: {} + }); + + // Task in non-terminal state, fast-forward + vi.advanceTimersByTime(1001); + + // Task should be cleaned up + let task = await store.getTask(createdTask.taskId); + expect(task).toBeNull(); + + // Create another task + const taskParams2: TaskCreationParams = { + ttl: 2000 + }; + const createdTask2 = await store.createTask(taskParams2, 1000, { + method: 'tools/call', + params: {} + }); + + // Update to terminal state + await store.updateTaskStatus(createdTask2.taskId, 'completed'); + + // Fast-forward past original ttl + vi.advanceTimersByTime(2001); + + // Task should be cleaned up + task = await store.getTask(createdTask2.taskId); + expect(task).toBeNull(); + }); + + it('should return actual TTL in task response', async () => { + // Test that the TaskStore returns the actual TTL it will use + // This implementation uses the requested TTL as-is, but implementations + // MAY override it (e.g., enforce maximum TTL limits) + const requestedTtl = 5000; + const taskParams: TaskCreationParams = { + ttl: requestedTtl + }; + const createdTask = await store.createTask(taskParams, 1111, { + method: 'tools/call', + params: {} + }); + + // The returned task should include the actual TTL that will be used + expect(createdTask.ttl).toBe(requestedTtl); + + // Verify the task is cleaned up after the actual TTL + vi.advanceTimersByTime(requestedTtl + 1); + const task = await store.getTask(createdTask.taskId); + expect(task).toBeNull(); + }); + + it('should support null TTL for unlimited lifetime', async () => { + // Test that null TTL means unlimited lifetime + const taskParams: TaskCreationParams = { + ttl: null + }; + const createdTask = await store.createTask(taskParams, 2222, { + method: 'tools/call', + params: {} + }); + + // The returned task should have null TTL + expect(createdTask.ttl).toBeNull(); + + // Task should not be cleaned up even after a long time + vi.advanceTimersByTime(100000); + const task = await store.getTask(createdTask.taskId); + expect(task).toBeDefined(); + expect(task?.taskId).toBe(createdTask.taskId); + }); + + it('should cleanup tasks regardless of status', async () => { + // Test that TTL cleanup happens regardless of task status + const taskParams: TaskCreationParams = { + ttl: 1000 + }; + + // Create tasks in different statuses + const workingTask = await store.createTask(taskParams, 3333, { + method: 'tools/call', + params: {} + }); + + const completedTask = await store.createTask(taskParams, 4444, { + method: 'tools/call', + params: {} + }); + await store.storeTaskResult(completedTask.taskId, 'completed', { + content: [{ type: 'text' as const, text: 'Done' }] + }); + + const failedTask = await store.createTask(taskParams, 5555, { + method: 'tools/call', + params: {} + }); + await store.storeTaskResult(failedTask.taskId, 'failed', { + content: [{ type: 'text' as const, text: 'Error' }] + }); + + // Fast-forward past TTL + vi.advanceTimersByTime(1001); + + // All tasks should be cleaned up regardless of status + expect(await store.getTask(workingTask.taskId)).toBeNull(); + expect(await store.getTask(completedTask.taskId)).toBeNull(); + expect(await store.getTask(failedTask.taskId)).toBeNull(); + }); + }); + + describe('getAllTasks', () => { + it('should return all tasks', async () => { + await store.createTask({}, 1, { + method: 'tools/call', + params: {} + }); + await store.createTask({}, 2, { + method: 'tools/call', + params: {} + }); + await store.createTask({}, 3, { + method: 'tools/call', + params: {} + }); + + const tasks = store.getAllTasks(); + expect(tasks).toHaveLength(3); + // Verify all tasks have unique IDs + const taskIds = tasks.map(t => t.taskId); + expect(new Set(taskIds).size).toBe(3); + }); + + it('should return empty array when no tasks', () => { + const tasks = store.getAllTasks(); + expect(tasks).toEqual([]); + }); + }); + + describe('listTasks', () => { + it('should return empty list when no tasks', async () => { + const result = await store.listTasks(); + expect(result.tasks).toEqual([]); + expect(result.nextCursor).toBeUndefined(); + }); + + it('should return all tasks when less than page size', async () => { + await store.createTask({}, 1, { + method: 'tools/call', + params: {} + }); + await store.createTask({}, 2, { + method: 'tools/call', + params: {} + }); + await store.createTask({}, 3, { + method: 'tools/call', + params: {} + }); + + const result = await store.listTasks(); + expect(result.tasks).toHaveLength(3); + expect(result.nextCursor).toBeUndefined(); + }); + + it('should paginate when more than page size', async () => { + // Create 15 tasks (page size is 10) + for (let i = 1; i <= 15; i++) { + await store.createTask({}, i, { + method: 'tools/call', + params: {} + }); + } + + // Get first page + const page1 = await store.listTasks(); + expect(page1.tasks).toHaveLength(10); + expect(page1.nextCursor).toBeDefined(); + + // Get second page using cursor + const page2 = await store.listTasks(page1.nextCursor); + expect(page2.tasks).toHaveLength(5); + expect(page2.nextCursor).toBeUndefined(); + }); + + it('should throw error for invalid cursor', async () => { + await store.createTask({}, 1, { + method: 'tools/call', + params: {} + }); + + await expect(store.listTasks('non-existent-cursor')).rejects.toThrow('Invalid cursor: non-existent-cursor'); + }); + + it('should continue from cursor correctly', async () => { + // Create 5 tasks + for (let i = 1; i <= 5; i++) { + await store.createTask({}, i, { + method: 'tools/call', + params: {} + }); + } + + // Get first 3 tasks + const allTaskIds = Array.from(store.getAllTasks().map(t => t.taskId)); + const result = await store.listTasks(allTaskIds[2]); + + // Should get tasks after the third task + expect(result.tasks).toHaveLength(2); + }); + }); + + describe('cleanup', () => { + it('should clear all timers and tasks', async () => { + await store.createTask({ ttl: 1000 }, 1, { + method: 'tools/call', + params: {} + }); + await store.createTask({ ttl: 2000 }, 2, { + method: 'tools/call', + params: {} + }); + + expect(store.getAllTasks()).toHaveLength(2); + + store.cleanup(); + + expect(store.getAllTasks()).toHaveLength(0); + }); + }); +}); + +describe('InMemoryTaskMessageQueue', () => { + let queue: InMemoryTaskMessageQueue; + + beforeEach(() => { + queue = new InMemoryTaskMessageQueue(); + }); + + describe('enqueue and dequeue', () => { + it('should enqueue and dequeue request messages', async () => { + const requestMessage: QueuedMessage = { + type: 'request', + message: { + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { name: 'test-tool', arguments: {} } + }, + timestamp: Date.now() + }; + + await queue.enqueue('task-1', requestMessage); + const dequeued = await queue.dequeue('task-1'); + + expect(dequeued).toEqual(requestMessage); + }); + + it('should enqueue and dequeue notification messages', async () => { + const notificationMessage: QueuedMessage = { + type: 'notification', + message: { + jsonrpc: '2.0', + method: 'notifications/progress', + params: { progress: 50, total: 100 } + }, + timestamp: Date.now() + }; + + await queue.enqueue('task-2', notificationMessage); + const dequeued = await queue.dequeue('task-2'); + + expect(dequeued).toEqual(notificationMessage); + }); + + it('should enqueue and dequeue response messages', async () => { + const responseMessage: QueuedMessage = { + type: 'response', + message: { + jsonrpc: '2.0', + id: 42, + result: { content: [{ type: 'text', text: 'Success' }] } + }, + timestamp: Date.now() + }; + + await queue.enqueue('task-3', responseMessage); + const dequeued = await queue.dequeue('task-3'); + + expect(dequeued).toEqual(responseMessage); + }); + + it('should return undefined when dequeuing from empty queue', async () => { + const dequeued = await queue.dequeue('task-empty'); + expect(dequeued).toBeUndefined(); + }); + + it('should maintain FIFO order for mixed message types', async () => { + const request: QueuedMessage = { + type: 'request', + message: { + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: {} + }, + timestamp: 1000 + }; + + const notification: QueuedMessage = { + type: 'notification', + message: { + jsonrpc: '2.0', + method: 'notifications/progress', + params: {} + }, + timestamp: 2000 + }; + + const response: QueuedMessage = { + type: 'response', + message: { + jsonrpc: '2.0', + id: 1, + result: {} + }, + timestamp: 3000 + }; + + await queue.enqueue('task-fifo', request); + await queue.enqueue('task-fifo', notification); + await queue.enqueue('task-fifo', response); + + expect(await queue.dequeue('task-fifo')).toEqual(request); + expect(await queue.dequeue('task-fifo')).toEqual(notification); + expect(await queue.dequeue('task-fifo')).toEqual(response); + expect(await queue.dequeue('task-fifo')).toBeUndefined(); + }); + }); + + describe('dequeueAll', () => { + it('should dequeue all messages including responses', async () => { + const request: QueuedMessage = { + type: 'request', + message: { + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: {} + }, + timestamp: 1000 + }; + + const response: QueuedMessage = { + type: 'response', + message: { + jsonrpc: '2.0', + id: 1, + result: {} + }, + timestamp: 2000 + }; + + const notification: QueuedMessage = { + type: 'notification', + message: { + jsonrpc: '2.0', + method: 'notifications/progress', + params: {} + }, + timestamp: 3000 + }; + + await queue.enqueue('task-all', request); + await queue.enqueue('task-all', response); + await queue.enqueue('task-all', notification); + + const all = await queue.dequeueAll('task-all'); + + expect(all).toHaveLength(3); + expect(all[0]).toEqual(request); + expect(all[1]).toEqual(response); + expect(all[2]).toEqual(notification); + }); + + it('should return empty array for non-existent task', async () => { + const all = await queue.dequeueAll('non-existent'); + expect(all).toEqual([]); + }); + + it('should clear the queue after dequeueAll', async () => { + const message: QueuedMessage = { + type: 'request', + message: { + jsonrpc: '2.0', + id: 1, + method: 'test', + params: {} + }, + timestamp: Date.now() + }; + + await queue.enqueue('task-clear', message); + await queue.dequeueAll('task-clear'); + + const dequeued = await queue.dequeue('task-clear'); + expect(dequeued).toBeUndefined(); + }); + }); + + describe('queue size limits', () => { + it('should throw when maxSize is exceeded', async () => { + const message: QueuedMessage = { + type: 'request', + message: { + jsonrpc: '2.0', + id: 1, + method: 'test', + params: {} + }, + timestamp: Date.now() + }; + + await queue.enqueue('task-limit', message, undefined, 2); + await queue.enqueue('task-limit', message, undefined, 2); + + await expect(queue.enqueue('task-limit', message, undefined, 2)).rejects.toThrow('Task message queue overflow'); + }); + + it('should allow enqueue when under maxSize', async () => { + const message: QueuedMessage = { + type: 'response', + message: { + jsonrpc: '2.0', + id: 1, + result: {} + }, + timestamp: Date.now() + }; + + await expect(queue.enqueue('task-ok', message, undefined, 5)).resolves.toBeUndefined(); + }); + }); + + describe('task isolation', () => { + it('should isolate messages between different tasks', async () => { + const message1: QueuedMessage = { + type: 'request', + message: { + jsonrpc: '2.0', + id: 1, + method: 'test1', + params: {} + }, + timestamp: 1000 + }; + + const message2: QueuedMessage = { + type: 'response', + message: { + jsonrpc: '2.0', + id: 2, + result: {} + }, + timestamp: 2000 + }; + + await queue.enqueue('task-a', message1); + await queue.enqueue('task-b', message2); + + expect(await queue.dequeue('task-a')).toEqual(message1); + expect(await queue.dequeue('task-b')).toEqual(message2); + expect(await queue.dequeue('task-a')).toBeUndefined(); + expect(await queue.dequeue('task-b')).toBeUndefined(); + }); + }); + + describe('response message error handling', () => { + it('should handle response messages with errors', async () => { + const errorResponse: QueuedMessage = { + type: 'error', + message: { + jsonrpc: '2.0', + id: 1, + error: { + code: -32600, + message: 'Invalid Request' + } + }, + timestamp: Date.now() + }; + + await queue.enqueue('task-error', errorResponse); + const dequeued = await queue.dequeue('task-error'); + + expect(dequeued).toEqual(errorResponse); + expect(dequeued?.type).toBe('error'); + }); + }); +}); diff --git a/src/examples/shared/inMemoryTaskStore.ts b/src/examples/shared/inMemoryTaskStore.ts new file mode 100644 index 000000000..0e3716bdf --- /dev/null +++ b/src/examples/shared/inMemoryTaskStore.ts @@ -0,0 +1,284 @@ +import { Task, Request, RequestId, Result } from '../../types.js'; +import { TaskStore, isTerminal, TaskMessageQueue, QueuedMessage, CreateTaskOptions } from '../../shared/task.js'; +import { randomBytes } from 'crypto'; + +interface StoredTask { + task: Task; + request: Request; + requestId: RequestId; + result?: Result; +} + +/** + * A simple in-memory implementation of TaskStore for demonstration purposes. + * + * This implementation stores all tasks in memory and provides automatic cleanup + * based on the ttl duration specified in the task creation parameters. + * + * Note: This is not suitable for production use as all data is lost on restart. + * For production, consider implementing TaskStore with a database or distributed cache. + */ +export class InMemoryTaskStore implements TaskStore { + private tasks = new Map(); + private cleanupTimers = new Map>(); + + /** + * Generates a unique task ID. + * Uses 16 bytes of random data encoded as hex (32 characters). + */ + private generateTaskId(): string { + return randomBytes(16).toString('hex'); + } + + async createTask(taskParams: CreateTaskOptions, requestId: RequestId, request: Request, _sessionId?: string): Promise { + // Generate a unique task ID + const taskId = this.generateTaskId(); + + // Ensure uniqueness + if (this.tasks.has(taskId)) { + throw new Error(`Task with ID ${taskId} already exists`); + } + + const actualTtl = taskParams.ttl ?? null; + + // Create task with generated ID and timestamps + const createdAt = new Date().toISOString(); + const task: Task = { + taskId, + status: 'working', + ttl: actualTtl, + createdAt, + lastUpdatedAt: createdAt, + pollInterval: taskParams.pollInterval ?? 1000 + }; + + this.tasks.set(taskId, { + task, + request, + requestId + }); + + // Schedule cleanup if ttl is specified + // Cleanup occurs regardless of task status + if (actualTtl) { + const timer = setTimeout(() => { + this.tasks.delete(taskId); + this.cleanupTimers.delete(taskId); + }, actualTtl); + + this.cleanupTimers.set(taskId, timer); + } + + return task; + } + + async getTask(taskId: string, _sessionId?: string): Promise { + const stored = this.tasks.get(taskId); + return stored ? { ...stored.task } : null; + } + + async storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result, _sessionId?: string): Promise { + const stored = this.tasks.get(taskId); + if (!stored) { + throw new Error(`Task with ID ${taskId} not found`); + } + + // Don't allow storing results for tasks already in terminal state + if (isTerminal(stored.task.status)) { + throw new Error( + `Cannot store result for task ${taskId} in terminal status '${stored.task.status}'. Task results can only be stored once.` + ); + } + + stored.result = result; + stored.task.status = status; + stored.task.lastUpdatedAt = new Date().toISOString(); + + // Reset cleanup timer to start from now (if ttl is set) + if (stored.task.ttl) { + const existingTimer = this.cleanupTimers.get(taskId); + if (existingTimer) { + clearTimeout(existingTimer); + } + + const timer = setTimeout(() => { + this.tasks.delete(taskId); + this.cleanupTimers.delete(taskId); + }, stored.task.ttl); + + this.cleanupTimers.set(taskId, timer); + } + } + + async getTaskResult(taskId: string, _sessionId?: string): Promise { + const stored = this.tasks.get(taskId); + if (!stored) { + throw new Error(`Task with ID ${taskId} not found`); + } + + if (!stored.result) { + throw new Error(`Task ${taskId} has no result stored`); + } + + return stored.result; + } + + async updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string, _sessionId?: string): Promise { + const stored = this.tasks.get(taskId); + if (!stored) { + throw new Error(`Task with ID ${taskId} not found`); + } + + // Don't allow transitions from terminal states + if (isTerminal(stored.task.status)) { + throw new Error( + `Cannot update task ${taskId} from terminal status '${stored.task.status}' to '${status}'. Terminal states (completed, failed, cancelled) cannot transition to other states.` + ); + } + + stored.task.status = status; + if (statusMessage) { + stored.task.statusMessage = statusMessage; + } + + stored.task.lastUpdatedAt = new Date().toISOString(); + + // If task is in a terminal state and has ttl, start cleanup timer + if (isTerminal(status) && stored.task.ttl) { + const existingTimer = this.cleanupTimers.get(taskId); + if (existingTimer) { + clearTimeout(existingTimer); + } + + const timer = setTimeout(() => { + this.tasks.delete(taskId); + this.cleanupTimers.delete(taskId); + }, stored.task.ttl); + + this.cleanupTimers.set(taskId, timer); + } + } + + async listTasks(cursor?: string, _sessionId?: string): Promise<{ tasks: Task[]; nextCursor?: string }> { + const PAGE_SIZE = 10; + const allTaskIds = Array.from(this.tasks.keys()); + + let startIndex = 0; + if (cursor) { + const cursorIndex = allTaskIds.indexOf(cursor); + if (cursorIndex >= 0) { + startIndex = cursorIndex + 1; + } else { + // Invalid cursor - throw error + throw new Error(`Invalid cursor: ${cursor}`); + } + } + + const pageTaskIds = allTaskIds.slice(startIndex, startIndex + PAGE_SIZE); + const tasks = pageTaskIds.map(taskId => { + const stored = this.tasks.get(taskId)!; + return { ...stored.task }; + }); + + const nextCursor = startIndex + PAGE_SIZE < allTaskIds.length ? pageTaskIds[pageTaskIds.length - 1] : undefined; + + return { tasks, nextCursor }; + } + + /** + * Cleanup all timers (useful for testing or graceful shutdown) + */ + cleanup(): void { + for (const timer of this.cleanupTimers.values()) { + clearTimeout(timer); + } + this.cleanupTimers.clear(); + this.tasks.clear(); + } + + /** + * Get all tasks (useful for debugging) + */ + getAllTasks(): Task[] { + return Array.from(this.tasks.values()).map(stored => ({ ...stored.task })); + } +} + +/** + * A simple in-memory implementation of TaskMessageQueue for demonstration purposes. + * + * This implementation stores messages in memory, organized by task ID and optional session ID. + * Messages are stored in FIFO queues per task. + * + * Note: This is not suitable for production use in distributed systems. + * For production, consider implementing TaskMessageQueue with Redis or other distributed queues. + */ +export class InMemoryTaskMessageQueue implements TaskMessageQueue { + private queues = new Map(); + + /** + * Generates a queue key from taskId. + * SessionId is intentionally ignored because taskIds are globally unique + * and tasks need to be accessible across HTTP requests/sessions. + */ + private getQueueKey(taskId: string, _sessionId?: string): string { + return taskId; + } + + /** + * Gets or creates a queue for the given task and session. + */ + private getQueue(taskId: string, sessionId?: string): QueuedMessage[] { + const key = this.getQueueKey(taskId, sessionId); + let queue = this.queues.get(key); + if (!queue) { + queue = []; + this.queues.set(key, queue); + } + return queue; + } + + /** + * Adds a message to the end of the queue for a specific task. + * Atomically checks queue size and throws if maxSize would be exceeded. + * @param taskId The task identifier + * @param message The message to enqueue + * @param sessionId Optional session ID for binding the operation to a specific session + * @param maxSize Optional maximum queue size - if specified and queue is full, throws an error + * @throws Error if maxSize is specified and would be exceeded + */ + async enqueue(taskId: string, message: QueuedMessage, sessionId?: string, maxSize?: number): Promise { + const queue = this.getQueue(taskId, sessionId); + + // Atomically check size and enqueue + if (maxSize !== undefined && queue.length >= maxSize) { + throw new Error(`Task message queue overflow: queue size (${queue.length}) exceeds maximum (${maxSize})`); + } + + queue.push(message); + } + + /** + * Removes and returns the first message from the queue for a specific task. + * @param taskId The task identifier + * @param sessionId Optional session ID for binding the query to a specific session + * @returns The first message, or undefined if the queue is empty + */ + async dequeue(taskId: string, sessionId?: string): Promise { + const queue = this.getQueue(taskId, sessionId); + return queue.shift(); + } + + /** + * Removes and returns all messages from the queue for a specific task. + * @param taskId The task identifier + * @param sessionId Optional session ID for binding the query to a specific session + * @returns Array of all messages that were in the queue + */ + async dequeueAll(taskId: string, sessionId?: string): Promise { + const key = this.getQueueKey(taskId, sessionId); + const queue = this.queues.get(key) ?? []; + this.queues.delete(key); + return queue; + } +} diff --git a/src/integration-tests/taskLifecycle.test.ts b/src/integration-tests/taskLifecycle.test.ts new file mode 100644 index 000000000..3aba46b07 --- /dev/null +++ b/src/integration-tests/taskLifecycle.test.ts @@ -0,0 +1,1588 @@ +import { createServer, type Server } from 'node:http'; +import { AddressInfo } from 'node:net'; +import { randomUUID } from 'node:crypto'; +import { Client } from '../client/index.js'; +import { StreamableHTTPClientTransport } from '../client/streamableHttp.js'; +import { McpServer } from '../server/mcp.js'; +import { StreamableHTTPServerTransport } from '../server/streamableHttp.js'; +import { CallToolResultSchema, CreateTaskResultSchema, ElicitRequestSchema, ElicitResultSchema, TaskSchema } from '../types.js'; +import { z } from 'zod'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.js'; +import type { TaskRequestOptions } from '../shared/protocol.js'; + +describe('Task Lifecycle Integration Tests', () => { + let server: Server; + let mcpServer: McpServer; + let serverTransport: StreamableHTTPServerTransport; + let baseUrl: URL; + let taskStore: InMemoryTaskStore; + + beforeEach(async () => { + // Create task store + taskStore = new InMemoryTaskStore(); + + // Create MCP server with task support + mcpServer = new McpServer( + { name: 'test-server', version: '1.0.0' }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + }, + list: {}, + cancel: {} + } + }, + taskStore, + taskMessageQueue: new InMemoryTaskMessageQueue() + } + ); + + // Register a long-running tool using registerToolTask + mcpServer.registerToolTask( + 'long-task', + { + title: 'Long Running Task', + description: 'A tool that takes time to complete', + inputSchema: { + duration: z.number().describe('Duration in milliseconds').default(1000), + shouldFail: z.boolean().describe('Whether the task should fail').default(false) + } + }, + { + async createTask({ duration, shouldFail }, extra) { + const task = await extra.taskStore.createTask({ + ttl: 60000, + pollInterval: 100 + }); + + // Simulate async work + (async () => { + await new Promise(resolve => setTimeout(resolve, duration)); + + try { + if (shouldFail) { + await extra.taskStore.storeTaskResult(task.taskId, 'failed', { + content: [{ type: 'text', text: 'Task failed as requested' }], + isError: true + }); + } else { + await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: `Completed after ${duration}ms` }] + }); + } + } catch { + // Task may have been cleaned up if test ended + } + })(); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + // Register a tool that requires input via elicitation + mcpServer.registerToolTask( + 'input-task', + { + title: 'Input Required Task', + description: 'A tool that requires user input', + inputSchema: { + userName: z.string().describe('User name').optional() + } + }, + { + async createTask({ userName }, extra) { + const task = await extra.taskStore.createTask({ + ttl: 60000, + pollInterval: 100 + }); + + // Perform async work that requires elicitation + (async () => { + await new Promise(resolve => setTimeout(resolve, 100)); + + // If userName not provided, request it via elicitation + if (!userName) { + const elicitationResult = await extra.sendRequest( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: 'What is your name?', + requestedSchema: { + type: 'object', + properties: { + userName: { type: 'string' } + }, + required: ['userName'] + } + } + }, + ElicitResultSchema, + { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions + ); + + // Complete with the elicited name + const name = + elicitationResult.action === 'accept' && elicitationResult.content + ? elicitationResult.content.userName + : 'Unknown'; + try { + await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: `Hello, ${name}!` }] + }); + } catch { + // Task may have been cleaned up if test ended + } + } else { + // Complete immediately if userName was provided + try { + await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: `Hello, ${userName}!` }] + }); + } catch { + // Task may have been cleaned up if test ended + } + } + })(); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + // Create transport + serverTransport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID() + }); + + await mcpServer.connect(serverTransport); + + // Create HTTP server + server = createServer(async (req, res) => { + await serverTransport.handleRequest(req, res); + }); + + // Start server + baseUrl = await new Promise(resolve => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + resolve(new URL(`http://127.0.0.1:${addr.port}`)); + }); + }); + }); + + afterEach(async () => { + taskStore.cleanup(); + await mcpServer.close().catch(() => {}); + await serverTransport.close().catch(() => {}); + server.close(); + }); + + describe('Task Creation and Completion', () => { + it('should create a task and return CreateTaskResult', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 500, + shouldFail: false + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + // Verify CreateTaskResult structure + expect(createResult).toHaveProperty('task'); + expect(createResult.task).toHaveProperty('taskId'); + expect(createResult.task.status).toBe('working'); + expect(createResult.task.ttl).toBe(60000); + expect(createResult.task.createdAt).toBeDefined(); + expect(createResult.task.pollInterval).toBe(100); + + // Verify task is stored in taskStore + const taskId = createResult.task.taskId; + const storedTask = await taskStore.getTask(taskId); + expect(storedTask).toBeDefined(); + expect(storedTask?.taskId).toBe(taskId); + expect(storedTask?.status).toBe('working'); + + // Wait for completion + await new Promise(resolve => setTimeout(resolve, 600)); + + // Verify task completed + const completedTask = await taskStore.getTask(taskId); + expect(completedTask?.status).toBe('completed'); + + // Verify result is stored + const result = await taskStore.getTaskResult(taskId); + expect(result).toBeDefined(); + expect(result.content).toEqual([{ type: 'text', text: 'Completed after 500ms' }]); + + await transport.close(); + }); + + it('should handle task failure correctly', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task that will fail + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 300, + shouldFail: true + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Wait for failure + await new Promise(resolve => setTimeout(resolve, 400)); + + // Verify task failed + const task = await taskStore.getTask(taskId); + expect(task?.status).toBe('failed'); + + // Verify error result is stored + const result = await taskStore.getTaskResult(taskId); + expect(result.content).toEqual([{ type: 'text', text: 'Task failed as requested' }]); + expect(result.isError).toBe(true); + + await transport.close(); + }); + }); + + describe('Task Cancellation', () => { + it('should cancel a working task', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a long-running task + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 5000 + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Verify task is working + let task = await taskStore.getTask(taskId); + expect(task?.status).toBe('working'); + + // Cancel the task + await taskStore.updateTaskStatus(taskId, 'cancelled'); + + // Verify task is cancelled + task = await taskStore.getTask(taskId); + expect(task?.status).toBe('cancelled'); + + await transport.close(); + }); + + it('should reject cancellation of completed task', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a quick task + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 100 + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Wait for completion + await new Promise(resolve => setTimeout(resolve, 200)); + + // Verify task is completed + const task = await taskStore.getTask(taskId); + expect(task?.status).toBe('completed'); + + // Try to cancel (should fail) + await expect(taskStore.updateTaskStatus(taskId, 'cancelled')).rejects.toThrow(); + + await transport.close(); + }); + }); + + describe('Multiple Queued Messages', () => { + it('should deliver multiple queued messages in order', async () => { + // Register a tool that sends multiple server requests during execution + mcpServer.registerToolTask( + 'multi-request-task', + { + title: 'Multi Request Task', + description: 'A tool that sends multiple server requests', + inputSchema: { + requestCount: z.number().describe('Number of requests to send').default(3) + } + }, + { + async createTask({ requestCount }, extra) { + const task = await extra.taskStore.createTask({ + ttl: 60000, + pollInterval: 100 + }); + + // Perform async work that sends multiple requests + (async () => { + await new Promise(resolve => setTimeout(resolve, 100)); + + const responses: string[] = []; + + // Send multiple elicitation requests + for (let i = 0; i < requestCount; i++) { + const elicitationResult = await extra.sendRequest( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: `Request ${i + 1} of ${requestCount}`, + requestedSchema: { + type: 'object', + properties: { + response: { type: 'string' } + }, + required: ['response'] + } + } + }, + ElicitResultSchema, + { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions + ); + + if (elicitationResult.action === 'accept' && elicitationResult.content) { + responses.push(elicitationResult.content.response as string); + } + } + + // Complete with all responses + try { + await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: `Received responses: ${responses.join(', ')}` }] + }); + } catch { + // Task may have been cleaned up if test ended + } + })(); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + const receivedMessages: Array<{ method: string; message: string }> = []; + + // Set up elicitation handler on client to track message order + client.setRequestHandler(ElicitRequestSchema, async request => { + // Track the message + receivedMessages.push({ + method: request.method, + message: request.params.message + }); + + // Extract the request number from the message + const match = request.params.message.match(/Request (\d+) of (\d+)/); + const requestNum = match ? match[1] : 'unknown'; + + // Respond with the request number + return { + action: 'accept' as const, + content: { + response: `Response ${requestNum}` + } + }; + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task that will send 3 requests + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'multi-request-task', + arguments: { + requestCount: 3 + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Wait for messages to be queued + await new Promise(resolve => setTimeout(resolve, 200)); + + // Call tasks/result to receive all queued messages + // This should deliver all 3 elicitation requests in order + const result = await client.request( + { + method: 'tasks/result', + params: { taskId } + }, + CallToolResultSchema + ); + + // Verify all messages were delivered in order + expect(receivedMessages.length).toBe(3); + expect(receivedMessages[0].message).toBe('Request 1 of 3'); + expect(receivedMessages[1].message).toBe('Request 2 of 3'); + expect(receivedMessages[2].message).toBe('Request 3 of 3'); + + // Verify final result includes all responses + expect(result.content).toEqual([{ type: 'text', text: 'Received responses: Response 1, Response 2, Response 3' }]); + + // Verify task is completed + const task = await client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task.status).toBe('completed'); + + await transport.close(); + }, 10000); + }); + + describe('Input Required Flow', () => { + it('should handle elicitation during tool execution', async () => { + const elicitClient = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + // Set up elicitation handler on client + elicitClient.setRequestHandler(ElicitRequestSchema, async request => { + // Verify elicitation request structure + expect(request.params.message).toBe('What is your name?'); + expect(request.params.requestedSchema).toHaveProperty('properties'); + + // Respond with user input + const response = { + action: 'accept' as const, + content: { + userName: 'Alice' + } + }; + return response; + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await elicitClient.connect(transport); + + // Create a task without userName (will trigger elicitation) + const createResult = await elicitClient.request( + { + method: 'tools/call', + params: { + name: 'input-task', + arguments: {}, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Wait for elicitation to occur + await new Promise(resolve => setTimeout(resolve, 200)); + + // Check if the elicitation request was queued + + // Call tasks/result to receive the queued elicitation request + // This should deliver the elicitation request via the side-channel + // and then deliver the final result after the client responds + const result = await elicitClient.request( + { + method: 'tasks/result', + params: { taskId } + }, + CallToolResultSchema + ); + + // Verify final result is delivered correctly + expect(result.content).toEqual([{ type: 'text', text: 'Hello, Alice!' }]); + + // Verify task is now completed + const task = await elicitClient.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task.status).toBe('completed'); + + await transport.close(); + }, 10000); // Increase timeout to 10 seconds for debugging + + it('should complete immediately when input is provided upfront', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task with userName provided (no elicitation needed) + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'input-task', + arguments: { + userName: 'Bob' + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Wait for completion + await new Promise(resolve => setTimeout(resolve, 300)); + + // Verify task completed without elicitation + const task = await client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task.status).toBe('completed'); + + // Get result + const result = await client.request( + { + method: 'tasks/result', + params: { taskId } + }, + CallToolResultSchema + ); + + expect(result.content).toEqual([{ type: 'text', text: 'Hello, Bob!' }]); + + await transport.close(); + }); + }); + + describe('Task Listing and Pagination', () => { + it('should list tasks', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create multiple tasks + const taskIds: string[] = []; + for (let i = 0; i < 3; i++) { + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 1000 + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + taskIds.push(createResult.task.taskId); + } + + // List tasks using taskStore + const listResult = await taskStore.listTasks(); + + expect(listResult.tasks.length).toBeGreaterThanOrEqual(3); + expect(listResult.tasks.some(t => taskIds.includes(t.taskId))).toBe(true); + + await transport.close(); + }); + + it('should handle pagination with large datasets', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create 15 tasks (more than page size of 10) + for (let i = 0; i < 15; i++) { + await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 5000 + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + } + + // Get first page using taskStore + const page1 = await taskStore.listTasks(); + + expect(page1.tasks.length).toBe(10); + expect(page1.nextCursor).toBeDefined(); + + // Get second page + const page2 = await taskStore.listTasks(page1.nextCursor); + + expect(page2.tasks.length).toBeGreaterThanOrEqual(5); + + await transport.close(); + }); + }); + + describe('Error Handling', () => { + it('should return null for non-existent task', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Try to get non-existent task + const task = await taskStore.getTask('non-existent'); + expect(task).toBeNull(); + + await transport.close(); + }); + + it('should return error for invalid task operation', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create and complete a task + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 100 + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Wait for completion + await new Promise(resolve => setTimeout(resolve, 200)); + + // Try to cancel completed task (should fail) + await expect(taskStore.updateTaskStatus(taskId, 'cancelled')).rejects.toThrow(); + + await transport.close(); + }); + }); + + describe('TTL and Cleanup', () => { + it('should respect TTL in task creation', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task with specific TTL + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 100 + }, + task: { + ttl: 5000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Verify TTL is set correctly + expect(createResult.task.ttl).toBe(60000); // The task store uses 60000 as default + + // Task should exist + const task = await client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task).toBeDefined(); + expect(task.ttl).toBe(60000); + + await transport.close(); + }); + }); + + describe('Task Cancellation with Queued Messages', () => { + it('should clear queue and deliver no messages when task is cancelled before tasks/result', async () => { + // Register a tool that queues messages but doesn't complete immediately + mcpServer.registerToolTask( + 'cancellable-task', + { + title: 'Cancellable Task', + description: 'A tool that queues messages and can be cancelled', + inputSchema: { + messageCount: z.number().describe('Number of messages to queue').default(2) + } + }, + { + async createTask({ messageCount }, extra) { + const task = await extra.taskStore.createTask({ + ttl: 60000, + pollInterval: 100 + }); + + // Perform async work that queues messages + (async () => { + try { + await new Promise(resolve => setTimeout(resolve, 100)); + + // Queue multiple elicitation requests + for (let i = 0; i < messageCount; i++) { + // Send request but don't await - let it queue + extra + .sendRequest( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: `Message ${i + 1} of ${messageCount}`, + requestedSchema: { + type: 'object', + properties: { + response: { type: 'string' } + }, + required: ['response'] + } + } + }, + ElicitResultSchema, + { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions + ) + .catch(() => { + // Ignore errors from cancelled requests + }); + } + + // Don't complete - let the task be cancelled + // Wait indefinitely (or until cancelled) + await new Promise(() => {}); + } catch { + // Ignore errors - task was cancelled + } + })().catch(() => { + // Catch any unhandled errors from the async execution + }); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + let elicitationCallCount = 0; + + // Set up elicitation handler to track if any messages are delivered + client.setRequestHandler(ElicitRequestSchema, async () => { + elicitationCallCount++; + return { + action: 'accept' as const, + content: { + response: 'Should not be called' + } + }; + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task that will queue messages + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'cancellable-task', + arguments: { + messageCount: 2 + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Wait for messages to be queued + await new Promise(resolve => setTimeout(resolve, 200)); + + // Verify task is working and messages are queued + let task = await client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task.status).toBe('working'); + + // Cancel the task before calling tasks/result using the proper tasks/cancel request + // This will trigger queue cleanup via _clearTaskQueue in the handler + await client.request( + { + method: 'tasks/cancel', + params: { taskId } + }, + z.object({ _meta: z.record(z.unknown()).optional() }) + ); + + // Verify task is cancelled + task = await client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task.status).toBe('cancelled'); + + // Attempt to call tasks/result + // When a task is cancelled, the system needs to clear the message queue + // and reject any pending message delivery promises, meaning no further + // messages should be delivered for a cancelled task. + try { + await client.request( + { + method: 'tasks/result', + params: { taskId } + }, + CallToolResultSchema + ); + } catch { + // tasks/result might throw an error for cancelled tasks without a result + // This is acceptable behavior + } + + // Verify no elicitation messages were delivered, as the queue should be cleared immediately on cancellation + expect(elicitationCallCount).toBe(0); + + // Verify queue remains cleared on subsequent calls + try { + await client.request( + { + method: 'tasks/result', + params: { taskId } + }, + CallToolResultSchema + ); + } catch { + // Expected - task is cancelled + } + + // Still no messages should have been delivered + expect(elicitationCallCount).toBe(0); + + await transport.close(); + }, 10000); + }); + + describe('Continuous Message Delivery', () => { + it('should deliver messages immediately while tasks/result is blocking', async () => { + // Register a tool that queues messages over time + mcpServer.registerToolTask( + 'streaming-task', + { + title: 'Streaming Task', + description: 'A tool that sends messages over time', + inputSchema: { + messageCount: z.number().describe('Number of messages to send').default(3), + delayBetweenMessages: z.number().describe('Delay between messages in ms').default(200) + } + }, + { + async createTask({ messageCount, delayBetweenMessages }, extra) { + const task = await extra.taskStore.createTask({ + ttl: 60000, + pollInterval: 100 + }); + + // Perform async work that sends messages over time + (async () => { + try { + // Wait a bit before starting to send messages + await new Promise(resolve => setTimeout(resolve, 100)); + + const responses: string[] = []; + + // Send messages with delays between them + for (let i = 0; i < messageCount; i++) { + const elicitationResult = await extra.sendRequest( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: `Streaming message ${i + 1} of ${messageCount}`, + requestedSchema: { + type: 'object', + properties: { + response: { type: 'string' } + }, + required: ['response'] + } + } + }, + ElicitResultSchema, + { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions + ); + + if (elicitationResult.action === 'accept' && elicitationResult.content) { + responses.push(elicitationResult.content.response as string); + } + + // Wait before sending next message (if not the last one) + if (i < messageCount - 1) { + await new Promise(resolve => setTimeout(resolve, delayBetweenMessages)); + } + } + + // Complete with all responses + try { + await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: `Received all responses: ${responses.join(', ')}` }] + }); + } catch { + // Task may have been cleaned up if test ended + } + } catch (error) { + // Handle errors + try { + await extra.taskStore.storeTaskResult(task.taskId, 'failed', { + content: [{ type: 'text', text: `Error: ${error}` }], + isError: true + }); + } catch { + // Task may have been cleaned up if test ended + } + } + })(); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + const receivedMessages: Array<{ message: string; timestamp: number }> = []; + let tasksResultStartTime = 0; + + // Set up elicitation handler to track when messages arrive + client.setRequestHandler(ElicitRequestSchema, async request => { + const timestamp = Date.now(); + receivedMessages.push({ + message: request.params.message, + timestamp + }); + + // Extract the message number + const match = request.params.message.match(/Streaming message (\d+) of (\d+)/); + const messageNum = match ? match[1] : 'unknown'; + + // Respond immediately + return { + action: 'accept' as const, + content: { + response: `Response ${messageNum}` + } + }; + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task that will send messages over time + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'streaming-task', + arguments: { + messageCount: 3, + delayBetweenMessages: 300 + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Verify task is in working status + let task = await client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task.status).toBe('working'); + + // Call tasks/result immediately (before messages are queued) + // This should block and deliver messages as they arrive + tasksResultStartTime = Date.now(); + const resultPromise = client.request( + { + method: 'tasks/result', + params: { taskId } + }, + CallToolResultSchema + ); + + // Wait for the task to complete and get the result + const result = await resultPromise; + + // Verify all 3 messages were delivered + expect(receivedMessages.length).toBe(3); + expect(receivedMessages[0].message).toBe('Streaming message 1 of 3'); + expect(receivedMessages[1].message).toBe('Streaming message 2 of 3'); + expect(receivedMessages[2].message).toBe('Streaming message 3 of 3'); + + // Verify messages were delivered over time (not all at once) + // The delay between messages should be approximately 300ms + const timeBetweenFirstAndSecond = receivedMessages[1].timestamp - receivedMessages[0].timestamp; + const timeBetweenSecondAndThird = receivedMessages[2].timestamp - receivedMessages[1].timestamp; + + // Allow some tolerance for timing (messages should be at least 200ms apart) + expect(timeBetweenFirstAndSecond).toBeGreaterThan(200); + expect(timeBetweenSecondAndThird).toBeGreaterThan(200); + + // Verify messages were delivered while tasks/result was blocking + // (all messages should arrive after tasks/result was called) + for (const msg of receivedMessages) { + expect(msg.timestamp).toBeGreaterThanOrEqual(tasksResultStartTime); + } + + // Verify final result is correct + expect(result.content).toEqual([{ type: 'text', text: 'Received all responses: Response 1, Response 2, Response 3' }]); + + // Verify task is now completed + task = await client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task.status).toBe('completed'); + + await transport.close(); + }, 15000); // Increase timeout to 15 seconds to allow for message delays + }); + + describe('Terminal Task with Queued Messages', () => { + it('should deliver queued messages followed by final result for terminal task', async () => { + // Register a tool that completes quickly and queues messages before completion + mcpServer.registerToolTask( + 'quick-complete-task', + { + title: 'Quick Complete Task', + description: 'A tool that queues messages and completes quickly', + inputSchema: { + messageCount: z.number().describe('Number of messages to queue').default(2) + } + }, + { + async createTask({ messageCount }, extra) { + const task = await extra.taskStore.createTask({ + ttl: 60000, + pollInterval: 100 + }); + + // Perform async work that queues messages and completes quickly + (async () => { + try { + // Queue messages without waiting for responses + const pendingRequests: Promise[] = []; + + for (let i = 0; i < messageCount; i++) { + const requestPromise = extra.sendRequest( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: `Quick message ${i + 1} of ${messageCount}`, + requestedSchema: { + type: 'object', + properties: { + response: { type: 'string' } + }, + required: ['response'] + } + } + }, + ElicitResultSchema, + { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions + ); + pendingRequests.push(requestPromise); + } + + // Complete the task immediately (before responses are received) + // This creates a terminal task with queued messages + try { + await extra.taskStore.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: 'Task completed quickly' }] + }); + } catch { + // Task may have been cleaned up if test ended + } + + // Wait for all responses in the background + await Promise.all(pendingRequests.map(p => p.catch(() => {}))); + } catch (error) { + // Handle errors + try { + await extra.taskStore.storeTaskResult(task.taskId, 'failed', { + content: [{ type: 'text', text: `Error: ${error}` }], + isError: true + }); + } catch { + // Task may have been cleaned up if test ended + } + } + })(); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + const receivedMessages: Array<{ type: string; message?: string; content?: unknown }> = []; + + // Set up elicitation handler to track message order + client.setRequestHandler(ElicitRequestSchema, async request => { + receivedMessages.push({ + type: 'elicitation', + message: request.params.message + }); + + // Extract the message number + const match = request.params.message.match(/Quick message (\d+) of (\d+)/); + const messageNum = match ? match[1] : 'unknown'; + + return { + action: 'accept' as const, + content: { + response: `Response ${messageNum}` + } + }; + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task that will complete quickly with queued messages + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'quick-complete-task', + arguments: { + messageCount: 2 + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Wait for task to complete and messages to be queued + await new Promise(resolve => setTimeout(resolve, 200)); + + // Verify task is in terminal status (completed) + const task = await client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ); + expect(task.status).toBe('completed'); + + // Call tasks/result - should deliver queued messages followed by final result + const result = await client.request( + { + method: 'tasks/result', + params: { taskId } + }, + CallToolResultSchema + ); + + // Verify all queued messages were delivered before the final result + expect(receivedMessages.length).toBe(2); + expect(receivedMessages[0].message).toBe('Quick message 1 of 2'); + expect(receivedMessages[1].message).toBe('Quick message 2 of 2'); + + // Verify final result is correct + expect(result.content).toEqual([{ type: 'text', text: 'Task completed quickly' }]); + + // Verify queue is cleaned up - calling tasks/result again should only return the result + receivedMessages.length = 0; // Clear the array + + const result2 = await client.request( + { + method: 'tasks/result', + params: { taskId } + }, + CallToolResultSchema + ); + + // No messages should be delivered on second call (queue was cleaned up) + expect(receivedMessages.length).toBe(0); + expect(result2.content).toEqual([{ type: 'text', text: 'Task completed quickly' }]); + + await transport.close(); + }, 10000); + }); + + describe('Concurrent Operations', () => { + it('should handle multiple concurrent task creations', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create multiple tasks concurrently + const promises = Array.from({ length: 5 }, () => + client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 500 + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ) + ); + + const results = await Promise.all(promises); + + // Verify all tasks were created with unique IDs + const taskIds = results.map(r => r.task.taskId); + expect(new Set(taskIds).size).toBe(5); + + // Verify all tasks are in working status + for (const result of results) { + expect(result.task.status).toBe('working'); + } + + await transport.close(); + }); + + it('should handle concurrent operations on same task', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Create a task + const createResult = await client.request( + { + method: 'tools/call', + params: { + name: 'long-task', + arguments: { + duration: 2000 + }, + task: { + ttl: 60000 + } + } + }, + CreateTaskResultSchema + ); + + const taskId = createResult.task.taskId; + + // Perform multiple concurrent gets + const getPromises = Array.from({ length: 5 }, () => + client.request( + { + method: 'tasks/get', + params: { taskId } + }, + TaskSchema + ) + ); + + const tasks = await Promise.all(getPromises); + + // All should return the same task + for (const task of tasks) { + expect(task.taskId).toBe(taskId); + expect(task.status).toBe('working'); + } + + await transport.close(); + }); + }); +}); diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 6b301a4e8..b1fb8a77a 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -4,7 +4,9 @@ import { InMemoryTransport } from '../inMemory.js'; import type { Transport } from '../shared/transport.js'; import { CreateMessageRequestSchema, + CreateMessageResultSchema, ElicitRequestSchema, + ElicitResultSchema, ElicitationCompleteNotificationSchema, ErrorCode, LATEST_PROTOCOL_VERSION, @@ -12,13 +14,18 @@ import { ListResourcesRequestSchema, ListToolsRequestSchema, type LoggingMessageNotification, + McpError, NotificationSchema, RequestSchema, ResultSchema, SetLevelRequestSchema, - SUPPORTED_PROTOCOL_VERSIONS + SUPPORTED_PROTOCOL_VERSIONS, + CreateTaskResultSchema } from '../types.js'; import { Server } from './index.js'; +import { McpServer } from './mcp.js'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.js'; +import { CallToolRequestSchema, CallToolResultSchema } from '../types.js'; import type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator } from '../validation/types.js'; import type { AnyObjectSchema } from './zod-compat.js'; import * as z3 from 'zod/v3'; @@ -411,7 +418,7 @@ test('should respect client capabilities', async () => { ).resolves.not.toThrow(); // This should still throw because roots are not supported by the client - await expect(server.listRoots()).rejects.toThrow(/^Client does not support/); + await expect(server.listRoots()).rejects.toThrow(/Client does not support/); }); test('should respect client elicitation capabilities', async () => { @@ -578,7 +585,7 @@ test('should use elicitInput with mode: "form" by default for backwards compatib messages: [], maxTokens: 10 }) - ).rejects.toThrow(/^Client does not support/); + ).rejects.toThrow(/Client does not support/); }); test('should throw when elicitInput is called without client form capability', async () => { @@ -1451,8 +1458,8 @@ test('should handle server cancelling a request', async () => { ); controller.abort('Cancelled by test'); - // Request should be rejected - await expect(createMessagePromise).rejects.toBe('Cancelled by test'); + // Request should be rejected with an McpError + await expect(createMessagePromise).rejects.toThrow(McpError); }); test('should handle request timeout', async () => { @@ -1985,3 +1992,1064 @@ test('should respect log level for transport with sessionId', async () => { await server.sendLoggingMessage(warningParams, SESSION_ID); expect(clientTransport.onmessage).toHaveBeenCalled(); }); + +describe('Task-based execution', () => { + test('server with TaskStore should handle task-based tool execution', async () => { + const taskStore = new InMemoryTaskStore(); + + const server = new McpServer( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + // Register a tool using registerToolTask + server.registerToolTask( + 'test-tool', + { + description: 'A test tool', + inputSchema: {} + }, + { + async createTask(_args, extra) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + + // Simulate some async work + (async () => { + await new Promise(resolve => setTimeout(resolve, 10)); + const result = { + content: [{ type: 'text', text: 'Tool executed successfully!' }] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + })(); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Use callToolStream to create a task and capture the task ID + let taskId: string | undefined; + const stream = client.callToolStream({ name: 'test-tool', arguments: {} }, CallToolResultSchema, { + task: { + ttl: 60000 + } + }); + + for await (const message of stream) { + if (message.type === 'taskCreated') { + taskId = message.task.taskId; + } + } + + expect(taskId).toBeDefined(); + + // Wait for the task to complete + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify we can retrieve the task + const task = await client.getTask({ taskId: taskId! }); + expect(task).toBeDefined(); + expect(task.status).toBe('completed'); + + // Verify we can retrieve the result + const result = await client.getTaskResult({ taskId: taskId! }, CallToolResultSchema); + expect(result.content).toEqual([{ type: 'text', text: 'Tool executed successfully!' }]); + + // Cleanup + taskStore.cleanup(); + }); + + test('server without TaskStore should reject task-based requests', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + // No taskStore configured + } + ); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + return { + content: [{ type: 'text', text: 'Success!' }] + }; + } + throw new Error('Unknown tool'); + }); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + } + ] + })); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to get a task when server doesn't have TaskStore + // The server will return a "Method not found" error + await expect(client.getTask({ taskId: 'non-existent' })).rejects.toThrow('Method not found'); + }); + + test('should automatically attach related-task metadata to nested requests during tool execution', async () => { + const taskStore = new InMemoryTaskStore(); + + const server = new McpServer( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + } + } + ); + + // Track the elicitation request to verify related-task metadata + let capturedElicitRequest: z4.infer | null = null; + + // Set up client elicitation handler + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + let taskId: string | undefined; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + const createdTask = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + taskId = createdTask.taskId; + } + + // Capture the request to verify metadata later + capturedElicitRequest = request; + + return { + action: 'accept', + content: { + username: 'test-user' + } + }; + }); + + // Register a tool using registerToolTask that makes a nested elicitation request + server.registerToolTask( + 'collect-info', + { + description: 'Collects user info via elicitation', + inputSchema: {} + }, + { + async createTask(_args, extra) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + + // Perform async work that makes a nested request + (async () => { + // During tool execution, make a nested request to the client using extra.sendRequest + const elicitResult = await extra.sendRequest( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: 'Please provide your username', + requestedSchema: { + type: 'object', + properties: { + username: { type: 'string' } + }, + required: ['username'] + } + } + }, + ElicitResultSchema + ); + + const result = { + content: [ + { + type: 'text', + text: `Collected username: ${elicitResult.action === 'accept' && elicitResult.content ? (elicitResult.content as Record).username : 'none'}` + } + ] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + })(); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Call tool WITH task creation using callToolStream to capture task ID + let taskId: string | undefined; + const stream = client.callToolStream({ name: 'collect-info', arguments: {} }, CallToolResultSchema, { + task: { + ttl: 60000 + } + }); + + for await (const message of stream) { + if (message.type === 'taskCreated') { + taskId = message.task.taskId; + } + } + + expect(taskId).toBeDefined(); + + // Wait for completion + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify the nested elicitation request was made (related-task metadata is no longer automatically attached) + expect(capturedElicitRequest).toBeDefined(); + + // Verify tool result was correct + const result = await client.getTaskResult({ taskId: taskId! }, CallToolResultSchema); + expect(result.content).toEqual([ + { + type: 'text', + text: 'Collected username: test-user' + } + ]); + + // Cleanup + taskStore.cleanup(); + }); + + describe('Server calling client via elicitation', () => { + let clientTaskStore: InMemoryTaskStore; + + beforeEach(() => { + clientTaskStore = new InMemoryTaskStore(); + }); + + afterEach(() => { + clientTaskStore?.cleanup(); + }); + + test('should create task on client via elicitation', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + const result = { + action: 'accept', + content: { username: 'server-test-user', confirmed: true } + }; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; + } + + // Return ElicitResult for non-task requests + return result; + }); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Server creates task on client via elicitation + const createTaskResult = await server.request( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: 'Please provide your username', + requestedSchema: { + type: 'object', + properties: { + username: { type: 'string' }, + confirmed: { type: 'boolean' } + }, + required: ['username'] + } + } + }, + CreateTaskResultSchema, + { task: { ttl: 60000 } } + ); + + // Verify CreateTaskResult structure + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + const taskId = createTaskResult.task.taskId; + + // Verify task was created + const task = await server.getTask({ taskId }); + expect(task.status).toBe('completed'); + }); + + test('should query task from client using getTask', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + const result = { + action: 'accept', + content: { username: 'list-user' } + }; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; + } + + // Return ElicitResult for non-task requests + return result; + }); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create task + const createTaskResult = await server.request( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: 'Provide info', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + CreateTaskResultSchema, + { task: { ttl: 60000 } } + ); + + // Verify CreateTaskResult structure + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + const taskId = createTaskResult.task.taskId; + + // Query task + const task = await server.getTask({ taskId }); + expect(task).toBeDefined(); + expect(task.taskId).toBe(taskId); + expect(task.status).toBe('completed'); + }); + + test('should query task result from client using getTaskResult', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + const result = { + action: 'accept', + content: { username: 'result-user', confirmed: true } + }; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; + } + + // Return ElicitResult for non-task requests + return result; + }); + + const server = new Server({ + name: 'test-server', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create task + const createTaskResult = await server.request( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: 'Provide info', + requestedSchema: { + type: 'object', + properties: { + username: { type: 'string' }, + confirmed: { type: 'boolean' } + } + } + } + }, + CreateTaskResultSchema, + { task: { ttl: 60000 } } + ); + + // Verify CreateTaskResult structure + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + const taskId = createTaskResult.task.taskId; + + // Query result + const result = await server.getTaskResult({ taskId }, ElicitResultSchema); + expect(result.action).toBe('accept'); + expect(result.content).toEqual({ username: 'result-user', confirmed: true }); + }); + + test('should query task list from client using listTasks', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + const result = { + action: 'accept', + content: { username: 'list-user' } + }; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; + } + + // Return ElicitResult for non-task requests + return result; + }); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create multiple tasks + const createdTaskIds: string[] = []; + for (let i = 0; i < 2; i++) { + const createTaskResult = await server.request( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: 'Provide info', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + CreateTaskResultSchema, + { task: { ttl: 60000 } } + ); + + // Verify CreateTaskResult structure and capture taskId + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + createdTaskIds.push(createTaskResult.task.taskId); + } + + // Query task list + const taskList = await server.listTasks(); + expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); + for (const taskId of createdTaskIds) { + expect(taskList.tasks).toContainEqual( + expect.objectContaining({ + taskId, + status: 'completed' + }) + ); + } + }); + }); + + test('should handle multiple concurrent task-based tool calls', async () => { + const taskStore = new InMemoryTaskStore(); + + const server = new McpServer( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + // Register a tool using registerToolTask with variable delay + server.registerToolTask( + 'async-tool', + { + description: 'An async test tool', + inputSchema: { + delay: z4.number().optional().default(10), + taskNum: z4.number().optional() + } + }, + { + async createTask({ delay, taskNum }, extra) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + + // Simulate async work + (async () => { + await new Promise(resolve => setTimeout(resolve, delay)); + const result = { + content: [{ type: 'text', text: `Completed task ${taskNum || 'unknown'}` }] + }; + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + })(); + + return { task }; + }, + async getTask(_args, extra) { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error(`Task ${extra.taskId} not found`); + } + return task; + }, + async getTaskResult(_args, extra) { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as { content: Array<{ type: 'text'; text: string }> }; + } + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Create multiple tasks concurrently + const pendingRequests = Array.from({ length: 4 }, (_, index) => + client.callTool({ name: 'async-tool', arguments: { delay: 10 + index * 5, taskNum: index + 1 } }, CallToolResultSchema, { + task: { ttl: 60000 } + }) + ); + + // Wait for all tasks to complete + await Promise.all(pendingRequests); + + // Wait a bit more to ensure all tasks are completed + await new Promise(resolve => setTimeout(resolve, 50)); + + // Get all task IDs from the task list + const taskList = await client.listTasks(); + expect(taskList.tasks.length).toBeGreaterThanOrEqual(4); + const taskIds = taskList.tasks.map(t => t.taskId); + + // Verify all tasks completed successfully + for (let i = 0; i < taskIds.length; i++) { + const task = await client.getTask({ taskId: taskIds[i] }); + expect(task.status).toBe('completed'); + expect(task.taskId).toBe(taskIds[i]); + + const result = await client.getTaskResult({ taskId: taskIds[i] }, CallToolResultSchema); + expect(result.content).toEqual([{ type: 'text', text: `Completed task ${i + 1}` }]); + } + + // Verify listTasks returns all tasks + const finalTaskList = await client.listTasks(); + for (const taskId of taskIds) { + expect(finalTaskList.tasks).toContainEqual(expect.objectContaining({ taskId })); + } + + // Cleanup + taskStore.cleanup(); + }); + + describe('Error scenarios', () => { + let taskStore: InMemoryTaskStore; + let clientTaskStore: InMemoryTaskStore; + + beforeEach(() => { + taskStore = new InMemoryTaskStore(); + clientTaskStore = new InMemoryTaskStore(); + }); + + afterEach(() => { + taskStore?.cleanup(); + clientTaskStore?.cleanup(); + }); + + test('should throw error when client queries non-existent task from server', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to query a task that doesn't exist + await expect(client.getTask({ taskId: 'non-existent-task' })).rejects.toThrow(); + }); + + test('should throw error when server queries non-existent task from client', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { username: 'test' } + })); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Try to query a task that doesn't exist on client + await expect(server.getTask({ taskId: 'non-existent-task' })).rejects.toThrow(); + }); + }); +}); + +test('should respect client task capabilities', async () => { + const clientTaskStore = new InMemoryTaskStore(); + + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + sampling: {}, + elicitation: {}, + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + }, + taskStore: clientTaskStore + } + ); + + client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + const result = { + action: 'accept', + content: { username: 'test-user' } + }; + + // Check if task creation is requested + if (request.params.task && extra.taskStore) { + const task = await extra.taskStore.createTask({ + ttl: extra.taskRequestedTtl + }); + await extra.taskStore.storeTaskResult(task.taskId, 'completed', result); + // Return CreateTaskResult when task creation is requested + return { task }; + } + + // Return ElicitResult for non-task requests + return result; + }); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + }, + enforceStrictCapabilities: true + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Client supports task creation for elicitation/create and task methods + expect(server.getClientCapabilities()).toEqual({ + sampling: {}, + elicitation: { + form: {} + }, + tasks: { + requests: { + elicitation: { + create: {} + } + } + } + }); + + // These should work because client supports tasks + const createTaskResult = await server.request( + { + method: 'elicitation/create', + params: { + mode: 'form', + message: 'Test', + requestedSchema: { + type: 'object', + properties: { username: { type: 'string' } } + } + } + }, + CreateTaskResultSchema, + { task: { ttl: 60000 } } + ); + + // Verify CreateTaskResult structure + expect(createTaskResult.task).toBeDefined(); + expect(createTaskResult.task.taskId).toBeDefined(); + const taskId = createTaskResult.task.taskId; + + await expect(server.listTasks()).resolves.not.toThrow(); + await expect(server.getTask({ taskId })).resolves.not.toThrow(); + + // This should throw because client doesn't support task creation for sampling/createMessage + await expect( + server.request( + { + method: 'sampling/createMessage', + params: { + messages: [], + maxTokens: 10 + } + }, + CreateMessageResultSchema, + { task: { taskId: 'test-task-2', keepAlive: 60000 } } + ) + ).rejects.toThrow('Client does not support task creation for sampling/createMessage'); + + clientTaskStore.cleanup(); +}); diff --git a/src/server/index.ts b/src/server/index.ts index 975efb8e2..23061bf98 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -1,4 +1,5 @@ import { mergeCapabilities, Protocol, type NotificationOptions, type ProtocolOptions, type RequestOptions } from '../shared/protocol.js'; +import { ResponseMessage } from '../shared/responseMessage.js'; import { type ClientCapabilities, type CreateMessageRequest, @@ -32,10 +33,24 @@ import { SetLevelRequestSchema, SUPPORTED_PROTOCOL_VERSIONS, type ToolResultContent, - type ToolUseContent + type ToolUseContent, + CallToolRequestSchema, + CallToolResultSchema, + CreateTaskResultSchema } from '../types.js'; import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js'; import type { JsonSchemaType, jsonSchemaValidator } from '../validation/types.js'; +import { + AnyObjectSchema, + AnySchema, + getObjectShape, + isZ4Schema, + safeParse, + SchemaOutput, + type ZodV3Internal, + type ZodV4Internal +} from './zod-compat.js'; +import { RequestHandlerExtra } from '../shared/protocol.js'; export type ServerOptions = ProtocolOptions & { /** @@ -175,6 +190,87 @@ export class Server< this._capabilities = mergeCapabilities(this._capabilities, capabilities); } + /** + * Override request handler registration to enforce server-side validation for tools/call. + */ + public override setRequestHandler( + requestSchema: T, + handler: ( + request: SchemaOutput, + extra: RequestHandlerExtra + ) => ServerResult | ResultT | Promise + ): void { + const shape = getObjectShape(requestSchema); + const methodSchema = shape?.method; + if (!methodSchema) { + throw new Error('Schema is missing a method literal'); + } + + // Extract literal value using type-safe property access + let methodValue: unknown; + if (isZ4Schema(methodSchema)) { + const v4Schema = methodSchema as unknown as ZodV4Internal; + const v4Def = v4Schema._zod?.def; + methodValue = v4Def?.value ?? v4Schema.value; + } else { + const v3Schema = methodSchema as unknown as ZodV3Internal; + const legacyDef = v3Schema._def; + methodValue = legacyDef?.value ?? v3Schema.value; + } + + if (typeof methodValue !== 'string') { + throw new Error('Schema method literal must be a string'); + } + const method = methodValue; + + if (method === 'tools/call') { + const wrappedHandler = async ( + request: SchemaOutput, + extra: RequestHandlerExtra + ): Promise => { + const validatedRequest = safeParse(CallToolRequestSchema, request); + if (!validatedRequest.success) { + const errorMessage = + validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); + throw new McpError(ErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`); + } + + const { params } = validatedRequest.data; + + const result = await Promise.resolve(handler(request, extra)); + + // When task creation is requested, validate and return CreateTaskResult + if (params.task) { + const taskValidationResult = safeParse(CreateTaskResultSchema, result); + if (!taskValidationResult.success) { + const errorMessage = + taskValidationResult.error instanceof Error + ? taskValidationResult.error.message + : String(taskValidationResult.error); + throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); + } + return taskValidationResult.data; + } + + // For non-task requests, validate against CallToolResultSchema + const validationResult = safeParse(CallToolResultSchema, result); + if (!validationResult.success) { + const errorMessage = + validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); + throw new McpError(ErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`); + } + + return validationResult.data; + }; + + // Install the wrapped handler + return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler); + } + + // Other handlers use default behavior + return super.setRequestHandler(requestSchema, handler); + } + protected assertCapabilityForMethod(method: RequestT['method']): void { switch (method as ServerRequest['method']) { case 'sampling/createMessage': @@ -245,6 +341,12 @@ export class Server< } protected assertRequestHandlerCapability(method: string): void { + // Task handlers are registered in Protocol constructor before _capabilities is initialized + // Skip capability check for task methods during initialization + if (!this._capabilities) { + return; + } + switch (method) { case 'completion/complete': if (!this._capabilities.completions) { @@ -280,6 +382,15 @@ export class Server< } break; + case 'tasks/get': + case 'tasks/list': + case 'tasks/result': + case 'tasks/cancel': + if (!this._capabilities.tasks) { + throw new Error(`Server does not support tasks capability (required for ${method})`); + } + break; + case 'ping': case 'initialize': // No specific capability required for these methods @@ -287,6 +398,58 @@ export class Server< } } + protected assertTaskCapability(method: string): void { + if (!this._clientCapabilities?.tasks?.requests) { + throw new Error(`Client does not support task creation (required for ${method})`); + } + + const requests = this._clientCapabilities.tasks.requests; + + switch (method) { + case 'sampling/createMessage': + if (!requests.sampling?.createMessage) { + throw new Error(`Client does not support task creation for sampling/createMessage (required for ${method})`); + } + break; + + case 'elicitation/create': + if (!requests.elicitation?.create) { + throw new Error(`Client does not support task creation for elicitation/create (required for ${method})`); + } + break; + + default: + // Method doesn't support tasks, which is fine - no error + break; + } + } + + protected assertTaskHandlerCapability(method: string): void { + // Task handlers are registered in Protocol constructor before _capabilities is initialized + // Skip capability check for task methods during initialization + if (!this._capabilities) { + return; + } + + if (!this._capabilities.tasks?.requests) { + throw new Error(`Server does not support task creation (required for ${method})`); + } + + const requests = this._capabilities.tasks.requests; + + switch (method) { + case 'tools/call': + if (!requests.tools?.call) { + throw new Error(`Server does not support task creation for tools/call (required for ${method})`); + } + break; + + default: + // Method doesn't support tasks, which is fine - no error + break; + } + } + private async _oninitialize(request: InitializeRequest): Promise { const requestedVersion = request.params.protocolVersion; @@ -321,6 +484,47 @@ export class Server< return this._capabilities; } + /** + * Sends a request and returns an AsyncGenerator that yields response messages. + * The generator is guaranteed to end with either a 'result' or 'error' message. + * + * This method provides streaming access to request processing, allowing you to + * observe intermediate task status updates for task-augmented requests. + * + * @example + * ```typescript + * const stream = server.requestStream(request, resultSchema, options); + * for await (const message of stream) { + * switch (message.type) { + * case 'taskCreated': + * console.log('Task created:', message.task.taskId); + * break; + * case 'taskStatus': + * console.log('Task status:', message.task.status); + * break; + * case 'result': + * console.log('Final result:', message.result); + * break; + * case 'error': + * console.error('Error:', message.error); + * break; + * } + * } + * ``` + * + * @param request - The request to send + * @param resultSchema - Zod schema for validating the result + * @param options - Optional request options (timeout, signal, task creation params, etc.) + * @returns AsyncGenerator that yields ResponseMessage objects + */ + requestStream( + request: ServerRequest | RequestT, + resultSchema: T, + options?: RequestOptions + ): AsyncGenerator>, void, void> { + return super.requestStream(request, resultSchema, options); + } + async ping() { return this.request({ method: 'ping' }, EmptyResultSchema); } diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 776d0a129..bb25440ac 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -4,6 +4,7 @@ import { getDisplayName } from '../shared/metadataUtils.js'; import { UriTemplate } from '../shared/uriTemplate.js'; import { CallToolResultSchema, + type CallToolResult, CompleteResultSchema, ElicitRequestSchema, GetPromptResultSchema, @@ -20,8 +21,25 @@ import { } from '../types.js'; import { completable } from './completable.js'; import { McpServer, ResourceTemplate } from './mcp.js'; +import { InMemoryTaskStore } from '../examples/shared/inMemoryTaskStore.js'; import { zodTestMatrix, type ZodMatrixEntry } from '../shared/zodTestMatrix.js'; +function createLatch() { + let latch = false; + const waitForLatch = async () => { + while (!latch) { + await new Promise(resolve => setTimeout(resolve, 0)); + } + }; + + return { + releaseLatch: () => { + latch = true; + }, + waitForLatch + }; +} + describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const { z } = entry; @@ -3738,7 +3756,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Tool 1: Only name mcpServer.tool('tool_name_only', async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [{ type: 'text' as const, text: 'Response' }] })); // Tool 2: Name and annotations.title @@ -3749,7 +3767,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { title: 'Annotations Title' }, async () => ({ - content: [{ type: 'text', text: 'Response' }] + content: [{ type: 'text' as const, text: 'Response' }] }) ); @@ -4157,7 +4175,6 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { if (!available) { // Ask user if they want to try alternative dates const result = await mcpServer.server.elicitInput({ - mode: 'form', message: `No tables available at ${restaurant} on ${date}. Would you like to check alternative dates?`, requestedSchema: { type: 'object', @@ -4367,11 +4384,11 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { server.registerTool('contact', { inputSchema: unionSchema }, async args => { if (args.type === 'email') { return { - content: [{ type: 'text', text: `Email contact: ${args.email}` }] + content: [{ type: 'text' as const, text: `Email contact: ${args.email}` }] }; } else { return { - content: [{ type: 'text', text: `Phone contact: ${args.phone}` }] + content: [{ type: 'text' as const, text: `Phone contact: ${args.phone}` }] }; } }); @@ -4537,7 +4554,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { server.registerTool('union-test', { inputSchema: unionSchema }, async () => { return { - content: [{ type: 'text', text: 'Success' }] + content: [{ type: 'text' as const, text: 'Success' }] }; }); @@ -4562,24 +4579,1718 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }) ]) ); + }); + }); - const invalidDiscriminatorResult = await client.callTool({ - name: 'union-test', - arguments: { - type: 'c', - value: 'test' + describe('resource()', () => { + /*** + * Test: Resource Registration with URI and Read Callback + */ + test('should register resource with uri and readCallback', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource('test', 'test://resource', async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Test content' + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'resources/list' + }, + ListResourcesResultSchema + ); + + expect(result.resources).toHaveLength(1); + expect(result.resources[0].name).toBe('test'); + expect(result.resources[0].uri).toBe('test://resource'); + }); + + /*** + * Test: Update Resource with URI + */ + test('should update resource with uri', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register initial resource + const resource = mcpServer.resource('test', 'test://resource', async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Initial content' + } + ] + })); + + // Update the resource + resource.update({ + callback: async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Updated content' + } + ] + }) + }); + + // Updates before connection should not trigger notifications + expect(notifications).toHaveLength(0); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'resources/read', + params: { + uri: 'test://resource' + } + }, + ReadResourceResultSchema + ); + + expect(result.contents).toEqual([ + { + uri: 'test://resource', + text: 'Updated content' } + ]); + + // Now update again after connection + resource.update({ + callback: async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Another update' + } + ] + }) }); - expect(invalidDiscriminatorResult.isError).toBe(true); - expect(invalidDiscriminatorResult.content).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - type: 'text', - text: expect.stringContaining('Input validation error') + // Yield to event loop for notification to fly + await new Promise(process.nextTick); + + expect(notifications).toMatchObject([{ method: 'notifications/resources/list_changed' }]); + }); + + /*** + * Test: Resource Template Metadata Priority + */ + test('should prioritize individual resource metadata over template metadata', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{id}', { + list: async () => ({ + resources: [ + { + name: 'Resource 1', + uri: 'test://resource/1', + description: 'Individual resource description', + mimeType: 'text/plain' + }, + { + name: 'Resource 2', + uri: 'test://resource/2' + // This resource has no description or mimeType + } + ] }) - ]) + }), + { + description: 'Template description', + mimeType: 'application/json' + }, + async uri => ({ + contents: [ + { + uri: uri.href, + text: 'Test content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'resources/list' + }, + ListResourcesResultSchema + ); + + expect(result.resources).toHaveLength(2); + + // Resource 1 should have its own metadata + expect(result.resources[0].name).toBe('Resource 1'); + expect(result.resources[0].description).toBe('Individual resource description'); + expect(result.resources[0].mimeType).toBe('text/plain'); + + // Resource 2 should inherit template metadata + expect(result.resources[1].name).toBe('Resource 2'); + expect(result.resources[1].description).toBe('Template description'); + expect(result.resources[1].mimeType).toBe('application/json'); + }); + + /*** + * Test: Resource Template Metadata Overrides All Fields + */ + test('should allow resource to override all template metadata fields', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{id}', { + list: async () => ({ + resources: [ + { + name: 'Overridden Name', + uri: 'test://resource/1', + description: 'Overridden description', + mimeType: 'text/markdown' + // Add any other metadata fields if they exist + } + ] + }) + }), + { + title: 'Template Name', + description: 'Template description', + mimeType: 'application/json' + }, + async uri => ({ + contents: [ + { + uri: uri.href, + text: 'Test content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'resources/list' + }, + ListResourcesResultSchema ); + + expect(result.resources).toHaveLength(1); + + // All fields should be from the individual resource, not the template + expect(result.resources[0].name).toBe('Overridden Name'); + expect(result.resources[0].description).toBe('Overridden description'); + expect(result.resources[0].mimeType).toBe('text/markdown'); + }); + }); + + describe('Tool title precedence', () => { + test('should follow correct title precedence: title → annotations.title → name', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // Tool 1: Only name + mcpServer.tool('tool_name_only', async () => ({ + content: [{ type: 'text', text: 'Response' }] + })); + + // Tool 2: Name and annotations.title + mcpServer.tool( + 'tool_with_annotations_title', + 'Tool with annotations title', + { + title: 'Annotations Title' + }, + async () => ({ + content: [{ type: 'text', text: 'Response' }] + }) + ); + + // Tool 3: Name and title (using registerTool) + mcpServer.registerTool( + 'tool_with_title', + { + title: 'Regular Title', + description: 'Tool with regular title' + }, + async () => ({ + content: [{ type: 'text' as const, text: 'Response' }] + }) + ); + + // Tool 4: All three - title should win + mcpServer.registerTool( + 'tool_with_all_titles', + { + title: 'Regular Title Wins', + description: 'Tool with all titles', + annotations: { + title: 'Annotations Title Should Not Show' + } + }, + async () => ({ + content: [{ type: 'text' as const, text: 'Response' }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + + expect(result.tools).toHaveLength(4); + + // Tool 1: Only name - should display name + const tool1 = result.tools.find(t => t.name === 'tool_name_only'); + expect(tool1).toBeDefined(); + expect(getDisplayName(tool1!)).toBe('tool_name_only'); + + // Tool 2: Name and annotations.title - should display annotations.title + const tool2 = result.tools.find(t => t.name === 'tool_with_annotations_title'); + expect(tool2).toBeDefined(); + expect(tool2!.annotations?.title).toBe('Annotations Title'); + expect(getDisplayName(tool2!)).toBe('Annotations Title'); + + // Tool 3: Name and title - should display title + const tool3 = result.tools.find(t => t.name === 'tool_with_title'); + expect(tool3).toBeDefined(); + expect(tool3!.title).toBe('Regular Title'); + expect(getDisplayName(tool3!)).toBe('Regular Title'); + + // Tool 4: All three - title should take precedence + const tool4 = result.tools.find(t => t.name === 'tool_with_all_titles'); + expect(tool4).toBeDefined(); + expect(tool4!.title).toBe('Regular Title Wins'); + expect(tool4!.annotations?.title).toBe('Annotations Title Should Not Show'); + expect(getDisplayName(tool4!)).toBe('Regular Title Wins'); + }); + + test('getDisplayName unit tests for title precedence', () => { + // Test 1: Only name + expect(getDisplayName({ name: 'tool_name' })).toBe('tool_name'); + + // Test 2: Name and title - title wins + expect( + getDisplayName({ + name: 'tool_name', + title: 'Tool Title' + }) + ).toBe('Tool Title'); + + // Test 3: Name and annotations.title - annotations.title wins + expect( + getDisplayName({ + name: 'tool_name', + annotations: { title: 'Annotations Title' } + }) + ).toBe('Annotations Title'); + + // Test 4: All three - title wins (correct precedence) + expect( + getDisplayName({ + name: 'tool_name', + title: 'Regular Title', + annotations: { title: 'Annotations Title' } + }) + ).toBe('Regular Title'); + + // Test 5: Empty title should not be used + expect( + getDisplayName({ + name: 'tool_name', + title: '', + annotations: { title: 'Annotations Title' } + }) + ).toBe('Annotations Title'); + + // Test 6: Undefined vs null handling + expect( + getDisplayName({ + name: 'tool_name', + title: undefined, + annotations: { title: 'Annotations Title' } + }) + ).toBe('Annotations Title'); + }); + + test('should support resource template completion with resolved context', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.registerResource( + 'test', + new ResourceTemplate('github://repos/{owner}/{repo}', { + list: undefined, + complete: { + repo: (value, context) => { + if (context?.arguments?.['owner'] === 'org1') { + return ['project1', 'project2', 'project3'].filter(r => r.startsWith(value)); + } else if (context?.arguments?.['owner'] === 'org2') { + return ['repo1', 'repo2', 'repo3'].filter(r => r.startsWith(value)); + } + return []; + } + } + }), + { + title: 'GitHub Repository', + description: 'Repository information' + }, + async () => ({ + contents: [ + { + uri: 'github://repos/test/test', + text: 'Test content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Test with microsoft owner + const result1 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/resource', + uri: 'github://repos/{owner}/{repo}' + }, + argument: { + name: 'repo', + value: 'p' + }, + context: { + arguments: { + owner: 'org1' + } + } + } + }, + CompleteResultSchema + ); + + expect(result1.completion.values).toEqual(['project1', 'project2', 'project3']); + expect(result1.completion.total).toBe(3); + + // Test with facebook owner + const result2 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/resource', + uri: 'github://repos/{owner}/{repo}' + }, + argument: { + name: 'repo', + value: 'r' + }, + context: { + arguments: { + owner: 'org2' + } + } + } + }, + CompleteResultSchema + ); + + expect(result2.completion.values).toEqual(['repo1', 'repo2', 'repo3']); + expect(result2.completion.total).toBe(3); + + // Test with no resolved context + const result3 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/resource', + uri: 'github://repos/{owner}/{repo}' + }, + argument: { + name: 'repo', + value: 't' + } + } + }, + CompleteResultSchema + ); + + expect(result3.completion.values).toEqual([]); + expect(result3.completion.total).toBe(0); + }); + + test('should support prompt argument completion with resolved context', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.registerPrompt( + 'test-prompt', + { + title: 'Team Greeting', + description: 'Generate a greeting for team members', + argsSchema: { + department: completable(z.string(), value => { + return ['engineering', 'sales', 'marketing', 'support'].filter(d => d.startsWith(value)); + }), + name: completable(z.string(), (value, context) => { + const department = context?.arguments?.['department']; + if (department === 'engineering') { + return ['Alice', 'Bob', 'Charlie'].filter(n => n.startsWith(value)); + } else if (department === 'sales') { + return ['David', 'Eve', 'Frank'].filter(n => n.startsWith(value)); + } else if (department === 'marketing') { + return ['Grace', 'Henry', 'Iris'].filter(n => n.startsWith(value)); + } + return ['Guest'].filter(n => n.startsWith(value)); + }) + } + }, + async ({ department, name }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Hello ${name}, welcome to the ${department} team!` + } + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Test with engineering department + const result1 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/prompt', + name: 'test-prompt' + }, + argument: { + name: 'name', + value: 'A' + }, + context: { + arguments: { + department: 'engineering' + } + } + } + }, + CompleteResultSchema + ); + + expect(result1.completion.values).toEqual(['Alice']); + + // Test with sales department + const result2 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/prompt', + name: 'test-prompt' + }, + argument: { + name: 'name', + value: 'D' + }, + context: { + arguments: { + department: 'sales' + } + } + } + }, + CompleteResultSchema + ); + + expect(result2.completion.values).toEqual(['David']); + + // Test with marketing department + const result3 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/prompt', + name: 'test-prompt' + }, + argument: { + name: 'name', + value: 'G' + }, + context: { + arguments: { + department: 'marketing' + } + } + } + }, + CompleteResultSchema + ); + + expect(result3.completion.values).toEqual(['Grace']); + + // Test with no resolved context + const result4 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/prompt', + name: 'test-prompt' + }, + argument: { + name: 'name', + value: 'G' + } + } + }, + CompleteResultSchema + ); + + expect(result4.completion.values).toEqual(['Guest']); + }); + }); + + describe('elicitInput()', () => { + const checkAvailability = vi.fn().mockResolvedValue(false); + const findAlternatives = vi.fn().mockResolvedValue([]); + const makeBooking = vi.fn().mockResolvedValue('BOOKING-123'); + + let mcpServer: McpServer; + let client: Client; + + beforeEach(() => { + vi.clearAllMocks(); + + // Create server with restaurant booking tool + mcpServer = new McpServer({ + name: 'restaurant-booking-server', + version: '1.0.0' + }); + + // Register the restaurant booking tool from README example + mcpServer.tool( + 'book-restaurant', + { + restaurant: z.string(), + date: z.string(), + partySize: z.number() + }, + async ({ restaurant, date, partySize }) => { + // Check availability + const available = await checkAvailability(restaurant, date, partySize); + + if (!available) { + // Ask user if they want to try alternative dates + const result = await mcpServer.server.elicitInput({ + mode: 'form', + message: `No tables available at ${restaurant} on ${date}. Would you like to check alternative dates?`, + requestedSchema: { + type: 'object', + properties: { + checkAlternatives: { + type: 'boolean', + title: 'Check alternative dates', + description: 'Would you like me to check other dates?' + }, + flexibleDates: { + type: 'string', + title: 'Date flexibility', + description: 'How flexible are your dates?', + enum: ['next_day', 'same_week', 'next_week'], + enumNames: ['Next day', 'Same week', 'Next week'] + } + }, + required: ['checkAlternatives'] + } + }); + + if (result.action === 'accept' && result.content?.checkAlternatives) { + const alternatives = await findAlternatives( + restaurant, + date, + partySize, + result.content.flexibleDates as string + ); + return { + content: [ + { + type: 'text', + text: `Found these alternatives: ${alternatives.join(', ')}` + } + ] + }; + } + + return { + content: [ + { + type: 'text', + text: 'No booking made. Original date not available.' + } + ] + }; + } + + await makeBooking(restaurant, date, partySize); + return { + content: [ + { + type: 'text', + text: `Booked table for ${partySize} at ${restaurant} on ${date}` + } + ] + }; + } + ); + + // Create client with elicitation capability + client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + }); + + test('should successfully elicit additional information', async () => { + // Mock availability check to return false + checkAvailability.mockResolvedValue(false); + findAlternatives.mockResolvedValue(['2024-12-26', '2024-12-27', '2024-12-28']); + + // Set up client to accept alternative date checking + client.setRequestHandler(ElicitRequestSchema, async request => { + expect(request.params.message).toContain('No tables available at ABC Restaurant on 2024-12-25'); + return { + action: 'accept', + content: { + checkAlternatives: true, + flexibleDates: 'same_week' + } + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Call the tool + const result = await client.callTool({ + name: 'book-restaurant', + arguments: { + restaurant: 'ABC Restaurant', + date: '2024-12-25', + partySize: 2 + } + }); + + expect(checkAvailability).toHaveBeenCalledWith('ABC Restaurant', '2024-12-25', 2); + expect(findAlternatives).toHaveBeenCalledWith('ABC Restaurant', '2024-12-25', 2, 'same_week'); + expect(result.content).toEqual([ + { + type: 'text', + text: 'Found these alternatives: 2024-12-26, 2024-12-27, 2024-12-28' + } + ]); + }); + + test('should handle user declining to elicitation request', async () => { + // Mock availability check to return false + checkAvailability.mockResolvedValue(false); + + // Set up client to reject alternative date checking + client.setRequestHandler(ElicitRequestSchema, async () => { + return { + action: 'accept', + content: { + checkAlternatives: false + } + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Call the tool + const result = await client.callTool({ + name: 'book-restaurant', + arguments: { + restaurant: 'ABC Restaurant', + date: '2024-12-25', + partySize: 2 + } + }); + + expect(checkAvailability).toHaveBeenCalledWith('ABC Restaurant', '2024-12-25', 2); + expect(findAlternatives).not.toHaveBeenCalled(); + expect(result.content).toEqual([ + { + type: 'text', + text: 'No booking made. Original date not available.' + } + ]); + }); + + test('should handle user cancelling the elicitation', async () => { + // Mock availability check to return false + checkAvailability.mockResolvedValue(false); + + // Set up client to cancel the elicitation + client.setRequestHandler(ElicitRequestSchema, async () => { + return { + action: 'cancel' + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Call the tool + const result = await client.callTool({ + name: 'book-restaurant', + arguments: { + restaurant: 'ABC Restaurant', + date: '2024-12-25', + partySize: 2 + } + }); + + expect(checkAvailability).toHaveBeenCalledWith('ABC Restaurant', '2024-12-25', 2); + expect(findAlternatives).not.toHaveBeenCalled(); + expect(result.content).toEqual([ + { + type: 'text', + text: 'No booking made. Original date not available.' + } + ]); + }); + }); + + describe('Tools with union and intersection schemas', () => { + test('should support union schemas', async () => { + const server = new McpServer({ + name: 'test', + version: '1.0.0' + }); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const unionSchema = z.union([ + z.object({ type: z.literal('email'), email: z.string().email() }), + z.object({ type: z.literal('phone'), phone: z.string() }) + ]); + + server.registerTool('contact', { inputSchema: unionSchema }, async args => { + if (args.type === 'email') { + return { + content: [{ type: 'text', text: `Email contact: ${args.email}` }] + }; + } else { + return { + content: [{ type: 'text', text: `Phone contact: ${args.phone}` }] + }; + } + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await server.connect(serverTransport); + await client.connect(clientTransport); + + const emailResult = await client.callTool({ + name: 'contact', + arguments: { + type: 'email', + email: 'test@example.com' + } + }); + + expect(emailResult.content).toEqual([ + { + type: 'text', + text: 'Email contact: test@example.com' + } + ]); + + const phoneResult = await client.callTool({ + name: 'contact', + arguments: { + type: 'phone', + phone: '+1234567890' + } + }); + + expect(phoneResult.content).toEqual([ + { + type: 'text', + text: 'Phone contact: +1234567890' + } + ]); + }); + + test('should support intersection schemas', async () => { + const server = new McpServer({ + name: 'test', + version: '1.0.0' + }); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const baseSchema = z.object({ id: z.string() }); + const extendedSchema = z.object({ name: z.string(), age: z.number() }); + const intersectionSchema = z.intersection(baseSchema, extendedSchema); + + server.registerTool('user', { inputSchema: intersectionSchema }, async args => { + return { + content: [ + { + type: 'text', + text: `User: ${args.id}, ${args.name}, ${args.age} years old` + } + ] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await server.connect(serverTransport); + await client.connect(clientTransport); + + const result = await client.callTool({ + name: 'user', + arguments: { + id: '123', + name: 'John Doe', + age: 30 + } + }); + + expect(result.content).toEqual([ + { + type: 'text', + text: 'User: 123, John Doe, 30 years old' + } + ]); + }); + + test('should support complex nested schemas', async () => { + const server = new McpServer({ + name: 'test', + version: '1.0.0' + }); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const schema = z.object({ + items: z.array( + z.union([ + z.object({ type: z.literal('text'), content: z.string() }), + z.object({ type: z.literal('number'), value: z.number() }) + ]) + ) + }); + + server.registerTool('process', { inputSchema: schema }, async args => { + const processed = args.items.map(item => { + if (item.type === 'text') { + return item.content.toUpperCase(); + } else { + return item.value * 2; + } + }); + return { + content: [ + { + type: 'text', + text: `Processed: ${processed.join(', ')}` + } + ] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await server.connect(serverTransport); + await client.connect(clientTransport); + + const result = await client.callTool({ + name: 'process', + arguments: { + items: [ + { type: 'text', content: 'hello' }, + { type: 'number', value: 5 }, + { type: 'text', content: 'world' } + ] + } + }); + + expect(result.content).toEqual([ + { + type: 'text', + text: 'Processed: HELLO, 10, WORLD' + } + ]); + }); + + test('should validate union schema inputs correctly', async () => { + const server = new McpServer({ + name: 'test', + version: '1.0.0' + }); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const unionSchema = z.union([ + z.object({ type: z.literal('a'), value: z.string() }), + z.object({ type: z.literal('b'), value: z.number() }) + ]); + + server.registerTool('union-test', { inputSchema: unionSchema }, async () => { + return { + content: [{ type: 'text', text: 'Success' }] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await server.connect(serverTransport); + await client.connect(clientTransport); + + const invalidTypeResult = await client.callTool({ + name: 'union-test', + arguments: { + type: 'a', + value: 123 + } + }); + + expect(invalidTypeResult.isError).toBe(true); + expect(invalidTypeResult.content).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + type: 'text', + text: expect.stringContaining('Input validation error') + }) + ]) + ); + + const invalidDiscriminatorResult = await client.callTool({ + name: 'union-test', + arguments: { + type: 'c', + value: 'test' + } + }); + + expect(invalidDiscriminatorResult.isError).toBe(true); + expect(invalidDiscriminatorResult.content).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + type: 'text', + text: expect.stringContaining('Input validation error') + }) + ]) + ); + }); + }); + + describe('Tool-level task hints with automatic polling wrapper', () => { + test('should return error for tool with taskSupport "required" called without task augmentation', async () => { + const taskStore = new InMemoryTaskStore(); + + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + // Register a task-based tool with taskSupport "required" + mcpServer.registerToolTask( + 'long-running-task', + { + description: 'A long running task', + inputSchema: { + input: z.string() + }, + execution: { + taskSupport: 'required' + } + }, + { + createTask: async ({ input }, extra) => { + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); + + // Capture taskStore for use in setTimeout + const store = extra.taskStore; + + // Simulate async work + setTimeout(async () => { + await store.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text' as const, text: `Processed: ${input}` }] + }); + }, 200); + + return { task }; + }, + getTask: async (_args, extra) => { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error('Task not found'); + } + return task; + }, + getTaskResult: async (_input, extra) => { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as CallToolResult; + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the tool WITHOUT task augmentation - should return error + const result = await client.callTool( + { + name: 'long-running-task', + arguments: { input: 'test data' } + }, + CallToolResultSchema + ); + + // Should receive error result + expect(result.isError).toBe(true); + const content = result.content as TextContent[]; + expect(content[0].text).toContain('requires task augmentation'); + + taskStore.cleanup(); + }); + + test('should automatically poll and return CallToolResult for tool with taskSupport "optional" called without task augmentation', async () => { + const taskStore = new InMemoryTaskStore(); + const { releaseLatch, waitForLatch } = createLatch(); + + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + // Register a task-based tool with taskSupport "optional" + mcpServer.registerToolTask( + 'optional-task', + { + description: 'An optional task', + inputSchema: { + value: z.number() + }, + execution: { + taskSupport: 'optional' + } + }, + { + createTask: async ({ value }, extra) => { + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); + + // Capture taskStore for use in setTimeout + const store = extra.taskStore; + + // Simulate async work + setTimeout(async () => { + await store.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text' as const, text: `Result: ${value * 2}` }] + }); + releaseLatch(); + }, 150); + + return { task }; + }, + getTask: async (_args, extra) => { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error('Task not found'); + } + return task; + }, + getTaskResult: async (_value, extra) => { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as CallToolResult; + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the tool WITHOUT task augmentation + const result = await client.callTool( + { + name: 'optional-task', + arguments: { value: 21 } + }, + CallToolResultSchema + ); + + // Should receive CallToolResult directly, not CreateTaskResult + expect(result).toHaveProperty('content'); + expect(result.content).toEqual([{ type: 'text' as const, text: 'Result: 42' }]); + expect(result).not.toHaveProperty('task'); + + // Wait for async operations to complete + await waitForLatch(); + taskStore.cleanup(); + }); + + test('should return CreateTaskResult when tool with taskSupport "required" is called WITH task augmentation', async () => { + const taskStore = new InMemoryTaskStore(); + const { releaseLatch, waitForLatch } = createLatch(); + + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + // Register a task-based tool with taskSupport "required" + mcpServer.registerToolTask( + 'task-tool', + { + description: 'A task tool', + inputSchema: { + data: z.string() + }, + execution: { + taskSupport: 'required' + } + }, + { + createTask: async ({ data }, extra) => { + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); + + // Capture taskStore for use in setTimeout + const store = extra.taskStore; + + // Simulate async work + setTimeout(async () => { + await store.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text' as const, text: `Completed: ${data}` }] + }); + releaseLatch(); + }, 200); + + return { task }; + }, + getTask: async (_args, extra) => { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error('Task not found'); + } + return task; + }, + getTaskResult: async (_data, extra) => { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as CallToolResult; + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the tool WITH task augmentation + const result = await client.request( + { + method: 'tools/call', + params: { + name: 'task-tool', + arguments: { data: 'test' }, + task: { ttl: 60000 } + } + }, + z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.union([z.number(), z.null()]), + createdAt: z.string(), + pollInterval: z.number().optional() + }) + }) + ); + + // Should receive CreateTaskResult with task field + expect(result).toHaveProperty('task'); + expect(result.task).toHaveProperty('taskId'); + expect(result.task.status).toBe('working'); + + // Wait for async operations to complete + await waitForLatch(); + taskStore.cleanup(); + }); + + test('should handle task failures during automatic polling', async () => { + const taskStore = new InMemoryTaskStore(); + const { releaseLatch, waitForLatch } = createLatch(); + + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + // Register a task-based tool that fails + mcpServer.registerToolTask( + 'failing-task', + { + description: 'A failing task', + execution: { + taskSupport: 'optional' + } + }, + { + createTask: async extra => { + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); + + // Capture taskStore for use in setTimeout + const store = extra.taskStore; + + // Simulate async failure + setTimeout(async () => { + await store.storeTaskResult(task.taskId, 'failed', { + content: [{ type: 'text' as const, text: 'Error occurred' }], + isError: true + }); + releaseLatch(); + }, 150); + + return { task }; + }, + getTask: async extra => { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error('Task not found'); + } + return task; + }, + getTaskResult: async extra => { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as CallToolResult; + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the tool WITHOUT task augmentation + const result = await client.callTool( + { + name: 'failing-task', + arguments: {} + }, + CallToolResultSchema + ); + + // Should receive the error result + expect(result).toHaveProperty('content'); + expect(result.content).toEqual([{ type: 'text' as const, text: 'Error occurred' }]); + expect(result.isError).toBe(true); + + // Wait for async operations to complete + await waitForLatch(); + taskStore.cleanup(); + }); + + test('should handle task cancellation during automatic polling', async () => { + const taskStore = new InMemoryTaskStore(); + const { releaseLatch, waitForLatch } = createLatch(); + + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + tasks: { + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + // Register a task-based tool that gets cancelled + mcpServer.registerToolTask( + 'cancelled-task', + { + description: 'A task that gets cancelled', + execution: { + taskSupport: 'optional' + } + }, + { + createTask: async extra => { + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); + + // Capture taskStore for use in setTimeout + const store = extra.taskStore; + + // Simulate async cancellation + setTimeout(async () => { + await store.updateTaskStatus(task.taskId, 'cancelled', 'Task was cancelled'); + releaseLatch(); + }, 150); + + return { task }; + }, + getTask: async extra => { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error('Task not found'); + } + return task; + }, + getTaskResult: async extra => { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as CallToolResult; + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the tool WITHOUT task augmentation + const result = await client.callTool( + { + name: 'cancelled-task', + arguments: {} + }, + CallToolResultSchema + ); + + // Should receive an error since cancelled tasks don't have results + expect(result).toHaveProperty('content'); + expect(result.content).toEqual([{ type: 'text' as const, text: expect.stringContaining('has no result stored') }]); + + // Wait for async operations to complete + await waitForLatch(); + taskStore.cleanup(); + }); + + test('should raise error when registerToolTask is called with taskSupport "forbidden"', () => { + const taskStore = new InMemoryTaskStore(); + + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { + call: {} + } + } + } + }, + taskStore + } + ); + + // Attempt to register a task-based tool with taskSupport "forbidden" (cast to bypass type checking) + expect(() => { + mcpServer.registerToolTask( + 'invalid-task', + { + description: 'A task with forbidden support', + inputSchema: { + input: z.string() + }, + execution: { + taskSupport: 'forbidden' as unknown as 'required' + } + }, + { + createTask: async (_args, extra) => { + const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 }); + return { task }; + }, + getTask: async (_args, extra) => { + const task = await extra.taskStore.getTask(extra.taskId); + if (!task) { + throw new Error('Task not found'); + } + return task; + }, + getTaskResult: async (_args, extra) => { + const result = await extra.taskStore.getTaskResult(extra.taskId); + return result as CallToolResult; + } + } + ); + }).toThrow(); + + taskStore.cleanup(); }); }); }); diff --git a/src/server/mcp.ts b/src/server/mcp.ts index b9b6d5596..11707dc14 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -45,15 +45,21 @@ import { ServerNotification, ToolAnnotations, LoggingMessageNotification, + CreateTaskResult, + GetTaskResult, + Result, CompleteRequestPrompt, CompleteRequestResourceTemplate, assertCompleteRequestPrompt, - assertCompleteRequestResourceTemplate + assertCompleteRequestResourceTemplate, + CallToolRequest, + ToolExecution } from '../types.js'; import { isCompletable, getCompleter } from './completable.js'; import { UriTemplate, Variables } from '../shared/uriTemplate.js'; -import { RequestHandlerExtra } from '../shared/protocol.js'; +import { RequestHandlerExtra, RequestTaskStore } from '../shared/protocol.js'; import { Transport } from '../shared/transport.js'; + import { validateAndWarnToolName } from '../shared/toolNameValidation.js'; /** @@ -148,60 +154,53 @@ export class McpServer { }) ); - this.server.setRequestHandler(CallToolRequestSchema, async (request, extra): Promise => { - const tool = this._registeredTools[request.params.name]; - - let result: CallToolResult; - + this.server.setRequestHandler(CallToolRequestSchema, async (request, extra): Promise => { try { + const tool = this._registeredTools[request.params.name]; if (!tool) { throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} not found`); } - if (!tool.enabled) { throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} disabled`); } - if (tool.inputSchema) { - const cb = tool.callback as ToolCallback; - // Try to normalize to object schema first (for raw shapes and object schemas) - // If that fails, use the schema directly (for union/intersection/etc) - const inputObj = normalizeObjectSchema(tool.inputSchema); - const schemaToParse = inputObj ?? (tool.inputSchema as AnySchema); - const parseResult = await safeParseAsync(schemaToParse, request.params.arguments); - if (!parseResult.success) { - throw new McpError( - ErrorCode.InvalidParams, - `Input validation error: Invalid arguments for tool ${request.params.name}: ${getParseErrorMessage(parseResult.error)}` - ); - } + const isTaskRequest = !!request.params.task; + const taskSupport = tool.execution?.taskSupport; + const isTaskHandler = 'createTask' in (tool.handler as AnyToolHandler); - const args = parseResult.data; + // Validate task hint configuration + if ((taskSupport === 'required' || taskSupport === 'optional') && !isTaskHandler) { + throw new McpError( + ErrorCode.InternalError, + `Tool ${request.params.name} has taskSupport '${taskSupport}' but was not registered with registerToolTask` + ); + } - result = await Promise.resolve(cb(args, extra)); - } else { - const cb = tool.callback as ToolCallback; - result = await Promise.resolve(cb(extra)); + // Handle taskSupport 'required' without task augmentation + if (taskSupport === 'required' && !isTaskRequest) { + throw new McpError( + ErrorCode.MethodNotFound, + `Tool ${request.params.name} requires task augmentation (taskSupport: 'required')` + ); } - if (tool.outputSchema && !result.isError) { - if (!result.structuredContent) { - throw new McpError( - ErrorCode.InvalidParams, - `Output validation error: Tool ${request.params.name} has an output schema but no structured content was provided` - ); - } + // Handle taskSupport 'optional' without task augmentation - automatic polling + if (taskSupport === 'optional' && !isTaskRequest && isTaskHandler) { + return await this.handleAutomaticTaskPolling(tool, request, extra); + } - // if the tool has an output schema, validate structured content - const outputObj = normalizeObjectSchema(tool.outputSchema) as AnyObjectSchema; - const parseResult = await safeParseAsync(outputObj, result.structuredContent); - if (!parseResult.success) { - throw new McpError( - ErrorCode.InvalidParams, - `Output validation error: Invalid structured content for tool ${request.params.name}: ${getParseErrorMessage(parseResult.error)}` - ); - } + // Normal execution path + const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); + const result = await this.executeToolHandler(tool, args, extra); + + // Return CreateTaskResult immediately for task requests + if (isTaskRequest) { + return result; } + + // Validate output schema for non-task requests + await this.validateToolOutput(tool, result, request.params.name); + return result; } catch (error) { if (error instanceof McpError) { if (error.code === ErrorCode.UrlElicitationRequired) { @@ -210,8 +209,6 @@ export class McpServer { } return this.createToolError(error instanceof Error ? error.message : String(error)); } - - return result; }); this._toolHandlersInitialized = true; @@ -235,6 +232,151 @@ export class McpServer { }; } + /** + * Validates tool input arguments against the tool's input schema. + */ + private async validateToolInput< + Tool extends RegisteredTool, + Args extends Tool['inputSchema'] extends infer InputSchema + ? InputSchema extends AnySchema + ? SchemaOutput + : undefined + : undefined + >(tool: Tool, args: Args, toolName: string): Promise { + if (!tool.inputSchema) { + return undefined as Args; + } + + // Try to normalize to object schema first (for raw shapes and object schemas) + // If that fails, use the schema directly (for union/intersection/etc) + const inputObj = normalizeObjectSchema(tool.inputSchema); + const schemaToParse = inputObj ?? (tool.inputSchema as AnySchema); + const parseResult = await safeParseAsync(schemaToParse, args); + if (!parseResult.success) { + const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; + const errorMessage = getParseErrorMessage(error); + throw new McpError(ErrorCode.InvalidParams, `Input validation error: Invalid arguments for tool ${toolName}: ${errorMessage}`); + } + + return parseResult.data as unknown as Args; + } + + /** + * Validates tool output against the tool's output schema. + */ + private async validateToolOutput(tool: RegisteredTool, result: CallToolResult | CreateTaskResult, toolName: string): Promise { + if (!tool.outputSchema) { + return; + } + + // Only validate CallToolResult, not CreateTaskResult + if (!('content' in result)) { + return; + } + + if (result.isError) { + return; + } + + if (!result.structuredContent) { + throw new McpError( + ErrorCode.InvalidParams, + `Output validation error: Tool ${toolName} has an output schema but no structured content was provided` + ); + } + + // if the tool has an output schema, validate structured content + const outputObj = normalizeObjectSchema(tool.outputSchema) as AnyObjectSchema; + const parseResult = await safeParseAsync(outputObj, result.structuredContent); + if (!parseResult.success) { + const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; + const errorMessage = getParseErrorMessage(error); + throw new McpError( + ErrorCode.InvalidParams, + `Output validation error: Invalid structured content for tool ${toolName}: ${errorMessage}` + ); + } + } + + /** + * Executes a tool handler (either regular or task-based). + */ + private async executeToolHandler( + tool: RegisteredTool, + args: unknown, + extra: RequestHandlerExtra + ): Promise { + const handler = tool.handler as AnyToolHandler; + const isTaskHandler = 'createTask' in handler; + + if (isTaskHandler) { + if (!extra.taskStore) { + throw new Error('No task store provided.'); + } + const taskExtra = { ...extra, taskStore: extra.taskStore }; + + if (tool.inputSchema) { + const typedHandler = handler as ToolTaskHandler; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return await Promise.resolve(typedHandler.createTask(args as any, taskExtra)); + } else { + const typedHandler = handler as ToolTaskHandler; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return await Promise.resolve((typedHandler.createTask as any)(taskExtra)); + } + } + + if (tool.inputSchema) { + const typedHandler = handler as ToolCallback; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return await Promise.resolve(typedHandler(args as any, extra)); + } else { + const typedHandler = handler as ToolCallback; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return await Promise.resolve((typedHandler as any)(extra)); + } + } + + /** + * Handles automatic task polling for tools with taskSupport 'optional'. + */ + private async handleAutomaticTaskPolling( + tool: RegisteredTool, + request: RequestT, + extra: RequestHandlerExtra + ): Promise { + if (!extra.taskStore) { + throw new Error('No task store provided for task-capable tool.'); + } + + // Validate input and create task + const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); + const handler = tool.handler as ToolTaskHandler; + const taskExtra = { ...extra, taskStore: extra.taskStore }; + + const createTaskResult: CreateTaskResult = args // undefined only if tool.inputSchema is undefined + ? await Promise.resolve((handler as ToolTaskHandler).createTask(args, taskExtra)) + : // eslint-disable-next-line @typescript-eslint/no-explicit-any + await Promise.resolve(((handler as ToolTaskHandler).createTask as any)(taskExtra)); + + // Poll until completion + const taskId = createTaskResult.task.taskId; + let task = createTaskResult.task; + const pollInterval = task.pollInterval ?? 5000; + + while (task.status !== 'completed' && task.status !== 'failed' && task.status !== 'cancelled') { + await new Promise(resolve => setTimeout(resolve, pollInterval)); + const updatedTask = await extra.taskStore.getTask(taskId); + if (!updatedTask) { + throw new McpError(ErrorCode.InternalError, `Task ${taskId} not found during polling`); + } + task = updatedTask; + } + + // Return the final result + return (await extra.taskStore.getTaskResult(taskId)) as CallToolResult; + } + private _completionHandlerInitialized = false; private setCompletionRequestHandler() { @@ -447,10 +589,9 @@ export class McpServer { const argsObj = normalizeObjectSchema(prompt.argsSchema) as AnyObjectSchema; const parseResult = await safeParseAsync(argsObj, request.params.arguments); if (!parseResult.success) { - throw new McpError( - ErrorCode.InvalidParams, - `Invalid arguments for prompt ${request.params.name}: ${getParseErrorMessage(parseResult.error)}` - ); + const error = 'error' in parseResult ? parseResult.error : 'Unknown error'; + const errorMessage = getParseErrorMessage(error); + throw new McpError(ErrorCode.InvalidParams, `Invalid arguments for prompt ${request.params.name}: ${errorMessage}`); } const args = parseResult.data; @@ -458,7 +599,8 @@ export class McpServer { return await Promise.resolve(cb(args, extra)); } else { const cb = prompt.callback as PromptCallback; - return await Promise.resolve(cb(extra)); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return await Promise.resolve((cb as any)(extra)); } }); @@ -697,8 +839,9 @@ export class McpServer { inputSchema: ZodRawShapeCompat | AnySchema | undefined, outputSchema: ZodRawShapeCompat | AnySchema | undefined, annotations: ToolAnnotations | undefined, + execution: ToolExecution | undefined, _meta: Record | undefined, - callback: ToolCallback + handler: AnyToolHandler ): RegisteredTool { // Validate tool name according to SEP specification validateAndWarnToolName(name); @@ -709,8 +852,9 @@ export class McpServer { inputSchema: getZodSchemaObject(inputSchema), outputSchema: getZodSchemaObject(outputSchema), annotations, + execution, _meta, - callback, + handler: handler, enabled: true, disable: () => registeredTool.update({ enabled: false }), enable: () => registeredTool.update({ enabled: true }), @@ -726,7 +870,7 @@ export class McpServer { if (typeof updates.title !== 'undefined') registeredTool.title = updates.title; if (typeof updates.description !== 'undefined') registeredTool.description = updates.description; if (typeof updates.paramsSchema !== 'undefined') registeredTool.inputSchema = objectFromShape(updates.paramsSchema); - if (typeof updates.callback !== 'undefined') registeredTool.callback = updates.callback; + if (typeof updates.callback !== 'undefined') registeredTool.handler = updates.callback; if (typeof updates.annotations !== 'undefined') registeredTool.annotations = updates.annotations; if (typeof updates._meta !== 'undefined') registeredTool._meta = updates._meta; if (typeof updates.enabled !== 'undefined') registeredTool.enabled = updates.enabled; @@ -758,7 +902,7 @@ export class McpServer { * This unified overload handles both `tool(name, paramsSchema, cb)` and `tool(name, annotations, cb)` cases. * * Note: We use a union type for the second parameter because TypeScript cannot reliably disambiguate - * between ToolAnnotations and ZodRawShape during overload resolution, as both are plain object types. + * between ToolAnnotations and ZodRawShapeCompat during overload resolution, as both are plain object types. * @deprecated Use `registerTool` instead. */ tool( @@ -773,7 +917,7 @@ export class McpServer { * `tool(name, description, annotations, cb)` cases. * * Note: We use a union type for the third parameter because TypeScript cannot reliably disambiguate - * between ToolAnnotations and ZodRawShape during overload resolution, as both are plain object types. + * between ToolAnnotations and ZodRawShapeCompat during overload resolution, as both are plain object types. * @deprecated Use `registerTool` instead. */ tool( @@ -832,18 +976,18 @@ export class McpServer { // We have at least one more arg before the callback const firstArg = rest[0]; - if (isZodRawShape(firstArg)) { + if (isZodRawShapeCompat(firstArg)) { // We have a params schema as the first arg inputSchema = rest.shift() as ZodRawShapeCompat; // Check if the next arg is potentially annotations - if (rest.length > 1 && typeof rest[0] === 'object' && rest[0] !== null && !isZodRawShape(rest[0])) { + if (rest.length > 1 && typeof rest[0] === 'object' && rest[0] !== null && !isZodRawShapeCompat(rest[0])) { // Case: tool(name, paramsSchema, annotations, cb) // Or: tool(name, description, paramsSchema, annotations, cb) annotations = rest.shift() as ToolAnnotations; } } else if (typeof firstArg === 'object' && firstArg !== null) { - // Not a ZodRawShape, so must be annotations in this position + // Not a ZodRawShapeCompat, so must be annotations in this position // Case: tool(name, annotations, cb) // Or: tool(name, description, annotations, cb) annotations = rest.shift() as ToolAnnotations; @@ -851,7 +995,17 @@ export class McpServer { } const callback = rest[0] as ToolCallback; - return this._createRegisteredTool(name, undefined, description, inputSchema, outputSchema, annotations, undefined, callback); + return this._createRegisteredTool( + name, + undefined, + description, + inputSchema, + outputSchema, + annotations, + { taskSupport: 'forbidden' }, + undefined, + callback + ); } /** @@ -882,11 +1036,80 @@ export class McpServer { inputSchema, outputSchema, annotations, + { taskSupport: 'forbidden' }, _meta, cb as ToolCallback ); } + /** + * Registers a task-based tool with a config object and callback. + */ + registerToolTask( + name: string, + config: { + title?: string; + description?: string; + outputSchema?: OutputArgs; + annotations?: ToolAnnotations; + execution?: TaskToolExecution; + _meta?: Record; + }, + handler: ToolTaskHandler + ): RegisteredTool; + + /** + * Registers a task-based tool with a config object and callback. + */ + registerToolTask( + name: string, + config: { + title?: string; + description?: string; + inputSchema: InputArgs; + outputSchema?: OutputArgs; + annotations?: ToolAnnotations; + execution?: TaskToolExecution; + _meta?: Record; + }, + handler: ToolTaskHandler + ): RegisteredTool; + + registerToolTask< + InputArgs extends undefined | ZodRawShapeCompat | AnySchema, + OutputArgs extends undefined | ZodRawShapeCompat | AnySchema + >( + name: string, + config: { + title?: string; + description?: string; + inputSchema?: InputArgs; + outputSchema?: OutputArgs; + annotations?: ToolAnnotations; + execution?: TaskToolExecution; + _meta?: Record; + }, + handler: ToolTaskHandler + ): RegisteredTool { + // Validate that taskSupport is not 'forbidden' for task-based tools + const execution: ToolExecution = { taskSupport: 'required', ...config.execution }; + if (execution.taskSupport === 'forbidden') { + throw new Error(`Cannot register task-based tool '${name}' with taskSupport 'forbidden'. Use registerTool() instead.`); + } + + return this._createRegisteredTool( + name, + config.title, + config.description, + config.inputSchema, + config.outputSchema, + config.annotations, + execution, + config._meta, + handler + ); + } + /** * Registers a zero-argument prompt `name`, which will run the given function when the client calls it. * @deprecated Use `registerPrompt` instead. @@ -1076,6 +1299,16 @@ export class ResourceTemplate { } } +export type BaseToolCallback< + SendResultT extends Result, + Extra extends RequestHandlerExtra, + Args extends undefined | ZodRawShapeCompat | AnySchema +> = Args extends ZodRawShapeCompat + ? (args: ShapeOutput, extra: Extra) => SendResultT | Promise + : Args extends AnySchema + ? (args: SchemaOutput, extra: Extra) => SendResultT | Promise + : (extra: Extra) => SendResultT | Promise; + /** * Callback for a tool handler registered with Server.tool(). * @@ -1086,14 +1319,52 @@ export class ResourceTemplate { * - `content` if the tool does not have an outputSchema * - Both fields are optional but typically one should be provided */ -export type ToolCallback = Args extends ZodRawShapeCompat - ? (args: ShapeOutput, extra: RequestHandlerExtra) => CallToolResult | Promise - : Args extends AnySchema - ? ( - args: SchemaOutput, - extra: RequestHandlerExtra - ) => CallToolResult | Promise - : (extra: RequestHandlerExtra) => CallToolResult | Promise; +export type ToolCallback = BaseToolCallback< + CallToolResult, + RequestHandlerExtra, + Args +>; + +export interface CreateTaskRequestHandlerExtra extends RequestHandlerExtra { + taskStore: RequestTaskStore; +} + +export interface TaskRequestHandlerExtra extends RequestHandlerExtra { + taskId: string; + taskStore: RequestTaskStore; +} + +export type CreateTaskRequestHandler< + SendResultT extends Result, + Args extends undefined | ZodRawShapeCompat | AnySchema = undefined +> = BaseToolCallback; + +export type TaskRequestHandler< + SendResultT extends Result, + Args extends undefined | ZodRawShapeCompat | AnySchema = undefined +> = BaseToolCallback; + +export interface ToolTaskHandler { + createTask: CreateTaskRequestHandler; + getTask: TaskRequestHandler; + getTaskResult: TaskRequestHandler; +} + +/** + * Supertype for tool handler callbacks registered with Server.registerTool() and Server.registerToolTask(). + */ +export type AnyToolCallback = + | ToolCallback + | TaskRequestHandler; + +/** + * Supertype that can handle both regular tools (simple callback) and task-based tools (task handler object). + */ +export type AnyToolHandler = ToolCallback | ToolTaskHandler; + +export type TaskToolExecution = Omit & { + taskSupport: TaskSupport extends 'forbidden' | undefined ? never : TaskSupport; +}; export type RegisteredTool = { title?: string; @@ -1101,8 +1372,9 @@ export type RegisteredTool = { inputSchema?: AnySchema; outputSchema?: AnySchema; annotations?: ToolAnnotations; + execution?: ToolExecution; _meta?: Record; - callback: ToolCallback; + handler: AnyToolHandler; enabled: boolean; enable(): void; disable(): void; @@ -1126,7 +1398,7 @@ const EMPTY_OBJECT_JSON_SCHEMA = { }; // Helper to check if an object is a Zod schema (ZodRawShapeCompat) -function isZodRawShape(obj: unknown): obj is ZodRawShapeCompat { +function isZodRawShapeCompat(obj: unknown): obj is ZodRawShapeCompat { if (typeof obj !== 'object' || obj === null) return false; const isEmptyObject = Object.keys(obj).length === 0; @@ -1148,7 +1420,7 @@ function isZodTypeLike(value: unknown): value is AnySchema { } /** - * Converts a provided Zod schema to a Zod object if it is a ZodRawShape, + * Converts a provided Zod schema to a Zod object if it is a ZodRawShapeCompat, * otherwise returns the schema as is. */ function getZodSchemaObject(schema: ZodRawShapeCompat | AnySchema | undefined): AnySchema | undefined { @@ -1156,7 +1428,7 @@ function getZodSchemaObject(schema: ZodRawShapeCompat | AnySchema | undefined): return undefined; } - if (isZodRawShape(schema)) { + if (isZodRawShapeCompat(schema)) { return objectFromShape(schema); } diff --git a/src/shared/protocol-transport-handling.test.ts b/src/shared/protocol-transport-handling.test.ts index b463d6db4..a2473f7f8 100644 --- a/src/shared/protocol-transport-handling.test.ts +++ b/src/shared/protocol-transport-handling.test.ts @@ -37,6 +37,8 @@ describe('Protocol transport handling bug', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })(); transportA = new MockTransport('A'); diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index b47de8c55..9ec39c871 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -1,8 +1,36 @@ import { ZodType, z } from 'zod'; -import { ClientCapabilities, ErrorCode, McpError, Notification, Request, Result, ServerCapabilities } from '../types.js'; +import { + CallToolRequestSchema, + ClientCapabilities, + ErrorCode, + JSONRPCMessage, + McpError, + Notification, + RELATED_TASK_META_KEY, + Request, + RequestId, + Result, + ServerCapabilities, + Task, + TaskCreationParams +} from '../types.js'; import { Protocol, mergeCapabilities } from './protocol.js'; -import { Transport } from './transport.js'; -import { MockInstance } from 'vitest'; +import { Transport, TransportSendOptions } from './transport.js'; +import { TaskStore, TaskMessageQueue, QueuedMessage, QueuedNotification, QueuedRequest } from './task.js'; +import { MockInstance, vi } from 'vitest'; +import { JSONRPCResponse, JSONRPCRequest, JSONRPCError } from '../types.js'; +import { ErrorMessage, ResponseMessage, toArrayAsync } from './responseMessage.js'; +import { InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.js'; + +// Type helper for accessing private Protocol properties in tests +interface TestProtocol { + _taskMessageQueue?: TaskMessageQueue; + _requestResolvers: Map void>; + _responseHandlers: Map void>; + _taskProgressTokens: Map; + _clearTaskQueue: (taskId: string, sessionId?: string) => Promise; + requestTaskStore: (request: Request, authInfo: unknown) => TaskStore; +} // Mock Transport class class MockTransport implements Transport { @@ -14,7 +42,96 @@ class MockTransport implements Transport { async close(): Promise { this.onclose?.(); } - async send(_message: unknown): Promise {} + async send(_message: JSONRPCMessage, _options?: TransportSendOptions): Promise {} +} + +function createMockTaskStore(options?: { + onStatus?: (status: Task['status']) => void; + onList?: () => void; +}): TaskStore & { [K in keyof TaskStore]: MockInstance } { + const tasks: Record = {}; + return { + createTask: vi.fn((taskParams: TaskCreationParams, _1: RequestId, _2: Request) => { + // Generate a unique task ID + const taskId = `test-task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + const createdAt = new Date().toISOString(); + const task = (tasks[taskId] = { + taskId, + status: 'working', + ttl: taskParams.ttl ?? null, + createdAt, + lastUpdatedAt: createdAt, + pollInterval: taskParams.pollInterval ?? 1000 + }); + options?.onStatus?.('working'); + return Promise.resolve(task); + }), + getTask: vi.fn((taskId: string) => { + return Promise.resolve(tasks[taskId] ?? null); + }), + updateTaskStatus: vi.fn((taskId, status, statusMessage) => { + const task = tasks[taskId]; + if (task) { + task.status = status; + task.statusMessage = statusMessage; + options?.onStatus?.(task.status); + } + return Promise.resolve(); + }), + storeTaskResult: vi.fn((taskId: string, status: 'completed' | 'failed', result: Result) => { + const task = tasks[taskId]; + if (task) { + task.status = status; + task.result = result; + options?.onStatus?.(status); + } + return Promise.resolve(); + }), + getTaskResult: vi.fn((taskId: string) => { + const task = tasks[taskId]; + if (task?.result) { + return Promise.resolve(task.result); + } + throw new Error('Task result not found'); + }), + listTasks: vi.fn(() => { + const result = { + tasks: Object.values(tasks) + }; + options?.onList?.(); + return Promise.resolve(result); + }) + }; +} + +function createLatch() { + let latch = false; + const waitForLatch = async () => { + while (!latch) { + await new Promise(resolve => setTimeout(resolve, 0)); + } + }; + + return { + releaseLatch: () => { + latch = true; + }, + waitForLatch + }; +} + +function assertErrorResponse(o: ResponseMessage): asserts o is ErrorMessage { + expect(o.type).toBe('error'); +} + +function assertQueuedNotification(o?: QueuedMessage): asserts o is QueuedNotification { + expect(o).toBeDefined(); + expect(o?.type).toBe('notification'); +} + +function assertQueuedRequest(o?: QueuedMessage): asserts o is QueuedRequest { + expect(o).toBeDefined(); + expect(o?.type).toBe('request'); } describe('protocol tests', () => { @@ -29,6 +146,8 @@ describe('protocol tests', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })(); }); @@ -92,9 +211,14 @@ describe('protocol tests', () => { }); const onProgressMock = vi.fn(); - protocol.request(request, mockSchema, { - onprogress: onProgressMock - }); + // Start request but don't await - we're testing the sent message + void protocol + .request(request, mockSchema, { + onprogress: onProgressMock + }) + .catch(() => { + // May not complete, ignore error + }); expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ @@ -127,9 +251,14 @@ describe('protocol tests', () => { }); const onProgressMock = vi.fn(); - protocol.request(request, mockSchema, { - onprogress: onProgressMock - }); + // Start request but don't await - we're testing the sent message + void protocol + .request(request, mockSchema, { + onprogress: onProgressMock + }) + .catch(() => { + // May not complete, ignore error + }); expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ @@ -162,7 +291,10 @@ describe('protocol tests', () => { result: z.string() }); - protocol.request(request, mockSchema); + // Start request but don't await - we're testing the sent message + void protocol.request(request, mockSchema).catch(() => { + // May not complete, ignore error + }); expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ @@ -190,9 +322,14 @@ describe('protocol tests', () => { }); const onProgressMock = vi.fn(); - protocol.request(request, mockSchema, { - onprogress: onProgressMock - }); + // Start request but don't await - we're testing the sent message + void protocol + .request(request, mockSchema, { + onprogress: onProgressMock + }) + .catch(() => { + // May not complete, ignore error + }); expect(sendSpy).toHaveBeenCalledWith( expect.objectContaining({ @@ -483,6 +620,8 @@ describe('protocol tests', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced_with_params'] }); await protocol.connect(transport); @@ -504,6 +643,8 @@ describe('protocol tests', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced_with_options'] }); await protocol.connect(transport); @@ -523,6 +664,8 @@ describe('protocol tests', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); @@ -547,6 +690,8 @@ describe('protocol tests', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); @@ -574,6 +719,8 @@ describe('protocol tests', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); @@ -599,6 +746,8 @@ describe('protocol tests', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); // Configure for a different method await protocol.connect(transport); @@ -632,6 +781,8 @@ describe('protocol tests', () => { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} })({ debouncedNotificationMethods: ['test/debounced'] }); await protocol.connect(transport); @@ -656,6 +807,97 @@ describe('protocol tests', () => { }); }); +describe('InMemoryTaskMessageQueue', () => { + let queue: TaskMessageQueue; + const taskId = 'test-task-id'; + + beforeEach(() => { + queue = new InMemoryTaskMessageQueue(); + }); + + describe('enqueue/dequeue maintains FIFO order', () => { + it('should maintain FIFO order for multiple messages', async () => { + const msg1 = { + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test1' }, + timestamp: 1 + }; + const msg2 = { + type: 'request' as const, + message: { jsonrpc: '2.0' as const, id: 1, method: 'test2' }, + timestamp: 2 + }; + const msg3 = { + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test3' }, + timestamp: 3 + }; + + await queue.enqueue(taskId, msg1); + await queue.enqueue(taskId, msg2); + await queue.enqueue(taskId, msg3); + + expect(await queue.dequeue(taskId)).toEqual(msg1); + expect(await queue.dequeue(taskId)).toEqual(msg2); + expect(await queue.dequeue(taskId)).toEqual(msg3); + }); + + it('should return undefined when dequeuing from empty queue', async () => { + expect(await queue.dequeue(taskId)).toBeUndefined(); + }); + }); + + describe('dequeueAll operation', () => { + it('should return all messages in FIFO order', async () => { + const msg1 = { + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test1' }, + timestamp: 1 + }; + const msg2 = { + type: 'request' as const, + message: { jsonrpc: '2.0' as const, id: 1, method: 'test2' }, + timestamp: 2 + }; + const msg3 = { + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test3' }, + timestamp: 3 + }; + + await queue.enqueue(taskId, msg1); + await queue.enqueue(taskId, msg2); + await queue.enqueue(taskId, msg3); + + const allMessages = await queue.dequeueAll(taskId); + + expect(allMessages).toEqual([msg1, msg2, msg3]); + }); + + it('should return empty array for empty queue', async () => { + const allMessages = await queue.dequeueAll(taskId); + expect(allMessages).toEqual([]); + }); + + it('should clear queue after dequeueAll', async () => { + await queue.enqueue(taskId, { + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test1' }, + timestamp: 1 + }); + await queue.enqueue(taskId, { + type: 'notification' as const, + message: { jsonrpc: '2.0' as const, method: 'test2' }, + timestamp: 2 + }); + + await queue.dequeueAll(taskId); + + expect(await queue.dequeue(taskId)).toBeUndefined(); + }); + }); +}); + describe('mergeCapabilities', () => { it('should merge client capabilities', () => { const base: ClientCapabilities = { @@ -745,3 +987,4364 @@ describe('mergeCapabilities', () => { expect(merged).toEqual({}); }); }); + +describe('Task-based execution', () => { + let protocol: Protocol; + let transport: MockTransport; + let sendSpy: MockInstance; + + beforeEach(() => { + transport = new MockTransport(); + sendSpy = vi.spyOn(transport, 'send'); + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: createMockTaskStore(), taskMessageQueue: new InMemoryTaskMessageQueue() }); + }); + + describe('request with task metadata', () => { + it('should include task parameters at top level', async () => { + await protocol.connect(transport); + + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + content: z.array(z.object({ type: z.literal('text'), text: z.string() })) + }); + + void protocol + .request(request, resultSchema, { + task: { + ttl: 30000, + pollInterval: 1000 + } + }) + .catch(() => { + // May not complete, ignore error + }); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'tools/call', + params: { + name: 'test-tool', + task: { + ttl: 30000, + pollInterval: 1000 + } + } + }), + expect.any(Object) + ); + }); + + it('should preserve existing _meta and add task parameters at top level', async () => { + await protocol.connect(transport); + + const request = { + method: 'tools/call', + params: { + name: 'test-tool', + _meta: { + customField: 'customValue' + } + } + }; + + const resultSchema = z.object({ + content: z.array(z.object({ type: z.literal('text'), text: z.string() })) + }); + + void protocol + .request(request, resultSchema, { + task: { + ttl: 60000 + } + }) + .catch(() => { + // May not complete, ignore error + }); + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + params: { + name: 'test-tool', + _meta: { + customField: 'customValue' + }, + task: { + ttl: 60000 + } + } + }), + expect.any(Object) + ); + }); + + it('should return Promise for task-augmented request', async () => { + await protocol.connect(transport); + + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + content: z.array(z.object({ type: z.literal('text'), text: z.string() })) + }); + + const resultPromise = protocol.request(request, resultSchema, { + task: { + ttl: 30000 + } + }); + + expect(resultPromise).toBeDefined(); + expect(resultPromise).toBeInstanceOf(Promise); + }); + }); + + describe('relatedTask metadata', () => { + it('should inject relatedTask metadata into _meta field', async () => { + await protocol.connect(transport); + + const request = { + method: 'notifications/message', + params: { data: 'test' } + }; + + const resultSchema = z.object({}); + + // Start the request (don't await completion, just let it send) + void protocol + .request(request, resultSchema, { + relatedTask: { + taskId: 'parent-task-123' + } + }) + .catch(() => { + // May not complete, ignore error + }); + + // Wait a bit for the request to be queued + await new Promise(resolve => setTimeout(resolve, 10)); + + // Requests with relatedTask should be queued, not sent via transport + // This prevents duplicate delivery for bidirectional transports + expect(sendSpy).not.toHaveBeenCalled(); + + // Verify the message was queued + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + }); + + it('should work with notification method', async () => { + await protocol.connect(transport); + + await protocol.notification( + { + method: 'notifications/message', + params: { level: 'info', data: 'test message' } + }, + { + relatedTask: { + taskId: 'parent-task-456' + } + } + ); + + // Notifications with relatedTask should be queued, not sent via transport + // This prevents duplicate delivery for bidirectional transports + expect(sendSpy).not.toHaveBeenCalled(); + + // Verify the message was queued + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + const queuedMessage = await queue!.dequeue('parent-task-456'); + assertQueuedNotification(queuedMessage); + expect(queuedMessage.message.method).toBe('notifications/message'); + expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: 'parent-task-456' }); + }); + }); + + describe('task metadata combination', () => { + it('should combine task, relatedTask, and progress metadata', async () => { + await protocol.connect(transport); + + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + content: z.array(z.object({ type: z.literal('text'), text: z.string() })) + }); + + // Start the request (don't await completion, just let it send) + void protocol + .request(request, resultSchema, { + task: { + ttl: 60000, + pollInterval: 1000 + }, + relatedTask: { + taskId: 'parent-task' + }, + onprogress: vi.fn() + }) + .catch(() => { + // May not complete, ignore error + }); + + // Wait a bit for the request to be queued + await new Promise(resolve => setTimeout(resolve, 10)); + + // Requests with relatedTask should be queued, not sent via transport + // This prevents duplicate delivery for bidirectional transports + expect(sendSpy).not.toHaveBeenCalled(); + + // Verify the message was queued with all metadata combined + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + const queuedMessage = await queue!.dequeue('parent-task'); + assertQueuedRequest(queuedMessage); + expect(queuedMessage.message.params).toMatchObject({ + name: 'test-tool', + task: { + ttl: 60000, + pollInterval: 1000 + }, + _meta: { + [RELATED_TASK_META_KEY]: { + taskId: 'parent-task' + }, + progressToken: expect.any(Number) + } + }); + }); + }); + + describe('task status transitions', () => { + it('should be handled by tool implementors, not protocol layer', () => { + // Task status management is now the responsibility of tool implementors + expect(true).toBe(true); + }); + + it('should handle requests with task creation parameters in top-level task field', async () => { + // This test documents that task creation parameters are now in the top-level task field + // rather than in _meta, and that task management is handled by tool implementors + const mockTaskStore = createMockTaskStore(); + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + protocol.setRequestHandler(CallToolRequestSchema, async request => { + // Tool implementor can access task creation parameters from request.params.task + expect(request.params.task).toEqual({ + ttl: 60000, + pollInterval: 1000 + }); + return { result: 'success' }; + }); + + transport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { + name: 'test', + arguments: {}, + task: { + ttl: 60000, + pollInterval: 1000 + } + } + }); + + // Wait for the request to be processed + await new Promise(resolve => setTimeout(resolve, 10)); + }); + }); + + describe('listTasks', () => { + it('should handle tasks/list requests and return tasks from TaskStore', async () => { + const listedTasks = createLatch(); + const mockTaskStore = createMockTaskStore({ + onList: () => listedTasks.releaseLatch() + }); + const task1 = await mockTaskStore.createTask( + { + pollInterval: 500 + }, + 1, + { + method: 'test/method', + params: {} + } + ); + // Manually set status to completed for this test + await mockTaskStore.updateTaskStatus(task1.taskId, 'completed'); + + const task2 = await mockTaskStore.createTask( + { + ttl: 60000, + pollInterval: 1000 + }, + 2, + { + method: 'test/method', + params: {} + } + ); + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + // Simulate receiving a tasks/list request + transport.onmessage?.({ + jsonrpc: '2.0', + id: 3, + method: 'tasks/list', + params: {} + }); + + await listedTasks.waitForLatch(); + + expect(mockTaskStore.listTasks).toHaveBeenCalledWith(undefined, undefined); + const sentMessage = sendSpy.mock.calls[0][0]; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(3); + expect(sentMessage.result.tasks).toEqual([ + { + taskId: task1.taskId, + status: 'completed', + ttl: null, + createdAt: expect.any(String), + lastUpdatedAt: expect.any(String), + pollInterval: 500 + }, + { + taskId: task2.taskId, + status: 'working', + ttl: 60000, + createdAt: expect.any(String), + lastUpdatedAt: expect.any(String), + pollInterval: 1000 + } + ]); + expect(sentMessage.result._meta).toEqual({}); + }); + + it('should handle tasks/list requests with cursor for pagination', async () => { + const listedTasks = createLatch(); + const mockTaskStore = createMockTaskStore({ + onList: () => listedTasks.releaseLatch() + }); + const task3 = await mockTaskStore.createTask( + { + pollInterval: 500 + }, + 1, + { + method: 'test/method', + params: {} + } + ); + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + // Simulate receiving a tasks/list request with cursor + transport.onmessage?.({ + jsonrpc: '2.0', + id: 2, + method: 'tasks/list', + params: { + cursor: 'task-2' + } + }); + + await listedTasks.waitForLatch(); + + expect(mockTaskStore.listTasks).toHaveBeenCalledWith('task-2', undefined); + const sentMessage = sendSpy.mock.calls[0][0]; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(2); + expect(sentMessage.result.tasks).toEqual([ + { + taskId: task3.taskId, + status: 'working', + ttl: null, + createdAt: expect.any(String), + lastUpdatedAt: expect.any(String), + pollInterval: 500 + } + ]); + expect(sentMessage.result.nextCursor).toBeUndefined(); + expect(sentMessage.result._meta).toEqual({}); + }); + + it('should handle tasks/list requests with empty results', async () => { + const listedTasks = createLatch(); + const mockTaskStore = createMockTaskStore({ + onList: () => listedTasks.releaseLatch() + }); + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + // Simulate receiving a tasks/list request + transport.onmessage?.({ + jsonrpc: '2.0', + id: 3, + method: 'tasks/list', + params: {} + }); + + await listedTasks.waitForLatch(); + + expect(mockTaskStore.listTasks).toHaveBeenCalledWith(undefined, undefined); + const sentMessage = sendSpy.mock.calls[0][0]; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(3); + expect(sentMessage.result.tasks).toEqual([]); + expect(sentMessage.result.nextCursor).toBeUndefined(); + expect(sentMessage.result._meta).toEqual({}); + }); + + it('should return error for invalid cursor', async () => { + const mockTaskStore = createMockTaskStore(); + mockTaskStore.listTasks.mockRejectedValue(new Error('Invalid cursor: bad-cursor')); + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + + await protocol.connect(transport); + + // Simulate receiving a tasks/list request with invalid cursor + transport.onmessage?.({ + jsonrpc: '2.0', + id: 4, + method: 'tasks/list', + params: { + cursor: 'bad-cursor' + } + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(mockTaskStore.listTasks).toHaveBeenCalledWith('bad-cursor', undefined); + const sentMessage = sendSpy.mock.calls[0][0]; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(4); + expect(sentMessage.error).toBeDefined(); + expect(sentMessage.error.code).toBe(-32602); // InvalidParams error code + expect(sentMessage.error.message).toContain('Failed to list tasks'); + expect(sentMessage.error.message).toContain('Invalid cursor'); + }); + + it('should call listTasks method from client side', async () => { + await protocol.connect(transport); + + const listTasksPromise = protocol.listTasks(); + + // Simulate server response + setTimeout(() => { + transport.onmessage?.({ + jsonrpc: '2.0', + id: sendSpy.mock.calls[0][0].id, + result: { + tasks: [ + { + taskId: 'task-1', + status: 'completed', + ttl: null, + createdAt: '2024-01-01T00:00:00Z', + lastUpdatedAt: '2024-01-01T00:00:00Z', + pollInterval: 500 + } + ], + nextCursor: undefined, + _meta: {} + } + }); + }, 10); + + const result = await listTasksPromise; + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'tasks/list', + params: undefined + }), + expect.any(Object) + ); + expect(result.tasks).toHaveLength(1); + expect(result.tasks[0].taskId).toBe('task-1'); + }); + + it('should call listTasks with cursor from client side', async () => { + await protocol.connect(transport); + + const listTasksPromise = protocol.listTasks({ cursor: 'task-10' }); + + // Simulate server response + setTimeout(() => { + transport.onmessage?.({ + jsonrpc: '2.0', + id: sendSpy.mock.calls[0][0].id, + result: { + tasks: [ + { + taskId: 'task-11', + status: 'working', + ttl: 30000, + createdAt: '2024-01-01T00:00:00Z', + lastUpdatedAt: '2024-01-01T00:00:00Z', + pollInterval: 1000 + } + ], + nextCursor: 'task-11', + _meta: {} + } + }); + }, 10); + + const result = await listTasksPromise; + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'tasks/list', + params: { + cursor: 'task-10' + } + }), + expect.any(Object) + ); + expect(result.tasks).toHaveLength(1); + expect(result.tasks[0].taskId).toBe('task-11'); + expect(result.nextCursor).toBe('task-11'); + }); + }); + + describe('cancelTask', () => { + it('should handle tasks/cancel requests and update task status to cancelled', async () => { + const taskDeleted = createLatch(); + const mockTaskStore = createMockTaskStore(); + const task = await mockTaskStore.createTask({}, 1, { + method: 'test/method', + params: {} + }); + + mockTaskStore.getTask.mockResolvedValue(task); + mockTaskStore.updateTaskStatus.mockImplementation(async (taskId: string, status: string) => { + if (taskId === task.taskId && status === 'cancelled') { + taskDeleted.releaseLatch(); + return; + } + throw new Error('Task not found'); + }); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + const serverTransport = new MockTransport(); + const sendSpy = vi.spyOn(serverTransport, 'send'); + + await serverProtocol.connect(serverTransport); + + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 5, + method: 'tasks/cancel', + params: { + taskId: task.taskId + } + }); + + await taskDeleted.waitForLatch(); + + expect(mockTaskStore.getTask).toHaveBeenCalledWith(task.taskId, undefined); + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith( + task.taskId, + 'cancelled', + 'Client cancelled task execution.', + undefined + ); + const sentMessage = sendSpy.mock.calls[0][0] as unknown as JSONRPCResponse; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(5); + expect(sentMessage.result._meta).toBeDefined(); + }); + + it('should return error with code -32602 when task does not exist', async () => { + const taskDeleted = createLatch(); + const mockTaskStore = createMockTaskStore(); + + mockTaskStore.getTask.mockResolvedValue(null); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + const serverTransport = new MockTransport(); + const sendSpy = vi.spyOn(serverTransport, 'send'); + + await serverProtocol.connect(serverTransport); + + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 6, + method: 'tasks/cancel', + params: { + taskId: 'non-existent' + } + }); + + // Wait a bit for the async handler to complete + await new Promise(resolve => setTimeout(resolve, 10)); + taskDeleted.releaseLatch(); + + expect(mockTaskStore.getTask).toHaveBeenCalledWith('non-existent', undefined); + const sentMessage = sendSpy.mock.calls[0][0] as unknown as JSONRPCError; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(6); + expect(sentMessage.error).toBeDefined(); + expect(sentMessage.error.code).toBe(-32602); // InvalidParams error code + expect(sentMessage.error.message).toContain('Task not found'); + }); + + it('should return error with code -32602 when trying to cancel a task in terminal status', async () => { + const mockTaskStore = createMockTaskStore(); + const completedTask = await mockTaskStore.createTask({}, 1, { + method: 'test/method', + params: {} + }); + // Set task to completed status + await mockTaskStore.updateTaskStatus(completedTask.taskId, 'completed'); + completedTask.status = 'completed'; + + // Reset the mock so we can check it's not called during cancellation + mockTaskStore.updateTaskStatus.mockClear(); + mockTaskStore.getTask.mockResolvedValue(completedTask); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + const serverTransport = new MockTransport(); + const sendSpy = vi.spyOn(serverTransport, 'send'); + + await serverProtocol.connect(serverTransport); + + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 7, + method: 'tasks/cancel', + params: { + taskId: completedTask.taskId + } + }); + + // Wait a bit for the async handler to complete + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(mockTaskStore.getTask).toHaveBeenCalledWith(completedTask.taskId, undefined); + expect(mockTaskStore.updateTaskStatus).not.toHaveBeenCalled(); + const sentMessage = sendSpy.mock.calls[0][0] as unknown as JSONRPCError; + expect(sentMessage.jsonrpc).toBe('2.0'); + expect(sentMessage.id).toBe(7); + expect(sentMessage.error).toBeDefined(); + expect(sentMessage.error.code).toBe(-32602); // InvalidParams error code + expect(sentMessage.error.message).toContain('Cannot cancel task in terminal status'); + }); + + it('should call cancelTask method from client side', async () => { + await protocol.connect(transport); + + const deleteTaskPromise = protocol.cancelTask({ taskId: 'task-to-delete' }); + + // Simulate server response + setTimeout(() => { + transport.onmessage?.({ + jsonrpc: '2.0', + id: sendSpy.mock.calls[0][0].id, + result: { + _meta: {} + } + }); + }, 0); + + const result = await deleteTaskPromise; + + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'tasks/cancel', + params: { + taskId: 'task-to-delete' + } + }), + expect.any(Object) + ); + expect(result._meta).toBeDefined(); + }); + }); + + describe('task status notifications', () => { + it('should call getTask after updateTaskStatus to enable notification sending', async () => { + const mockTaskStore = createMockTaskStore(); + + // Create a task first + const task = await mockTaskStore.createTask({}, 1, { + method: 'test/method', + params: {} + }); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + const serverTransport = new MockTransport(); + + await serverProtocol.connect(serverTransport); + + // Simulate cancelling the task + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 2, + method: 'tasks/cancel', + params: { + taskId: task.taskId + } + }); + + // Wait for async processing + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify that updateTaskStatus was called + expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith( + task.taskId, + 'cancelled', + 'Client cancelled task execution.', + undefined + ); + + // Verify that getTask was called after updateTaskStatus + // This is done by the RequestTaskStore wrapper to get the updated task for the notification + const getTaskCalls = mockTaskStore.getTask.mock.calls; + const lastGetTaskCall = getTaskCalls[getTaskCalls.length - 1]; + expect(lastGetTaskCall[0]).toBe(task.taskId); + }); + }); + + describe('task metadata handling', () => { + it('should NOT include related-task metadata in tasks/get response', async () => { + const mockTaskStore = createMockTaskStore(); + + // Create a task first + const task = await mockTaskStore.createTask({}, 1, { + method: 'test/method', + params: {} + }); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + const serverTransport = new MockTransport(); + const sendSpy = vi.spyOn(serverTransport, 'send'); + + await serverProtocol.connect(serverTransport); + + // Request task status + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 2, + method: 'tasks/get', + params: { + taskId: task.taskId + } + }); + + // Wait for async processing + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify response does NOT include related-task metadata + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + result: expect.objectContaining({ + taskId: task.taskId, + status: 'working' + }) + }) + ); + + // Verify _meta is not present or doesn't contain RELATED_TASK_META_KEY + const response = sendSpy.mock.calls[0][0] as { result?: { _meta?: Record } }; + expect(response.result?._meta?.[RELATED_TASK_META_KEY]).toBeUndefined(); + }); + + it('should NOT include related-task metadata in tasks/list response', async () => { + const mockTaskStore = createMockTaskStore(); + + // Create a task first + await mockTaskStore.createTask({}, 1, { + method: 'test/method', + params: {} + }); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + const serverTransport = new MockTransport(); + const sendSpy = vi.spyOn(serverTransport, 'send'); + + await serverProtocol.connect(serverTransport); + + // Request task list + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 2, + method: 'tasks/list', + params: {} + }); + + // Wait for async processing + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify response does NOT include related-task metadata + const response = sendSpy.mock.calls[0][0] as { result?: { _meta?: Record } }; + expect(response.result?._meta).toEqual({}); + }); + + it('should NOT include related-task metadata in tasks/cancel response', async () => { + const mockTaskStore = createMockTaskStore(); + + // Create a task first + const task = await mockTaskStore.createTask({}, 1, { + method: 'test/method', + params: {} + }); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + const serverTransport = new MockTransport(); + const sendSpy = vi.spyOn(serverTransport, 'send'); + + await serverProtocol.connect(serverTransport); + + // Cancel the task + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 2, + method: 'tasks/cancel', + params: { + taskId: task.taskId + } + }); + + // Wait for async processing + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify response does NOT include related-task metadata + const response = sendSpy.mock.calls[0][0] as { result?: { _meta?: Record } }; + expect(response.result?._meta).toEqual({}); + }); + + it('should include related-task metadata in tasks/result response', async () => { + const mockTaskStore = createMockTaskStore(); + + // Create a task and complete it + const task = await mockTaskStore.createTask({}, 1, { + method: 'test/method', + params: {} + }); + + const testResult = { + content: [{ type: 'text', text: 'test result' }] + }; + + await mockTaskStore.storeTaskResult(task.taskId, 'completed', testResult); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + const serverTransport = new MockTransport(); + const sendSpy = vi.spyOn(serverTransport, 'send'); + + await serverProtocol.connect(serverTransport); + + // Request task result + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 2, + method: 'tasks/result', + params: { + taskId: task.taskId + } + }); + + // Wait for async processing + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify response DOES include related-task metadata + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + result: expect.objectContaining({ + content: testResult.content, + _meta: expect.objectContaining({ + [RELATED_TASK_META_KEY]: { + taskId: task.taskId + } + }) + }) + }) + ); + }); + + it('should propagate related-task metadata to handler sendRequest and sendNotification', async () => { + const mockTaskStore = createMockTaskStore(); + + const serverProtocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + + const serverTransport = new MockTransport(); + const sendSpy = vi.spyOn(serverTransport, 'send'); + + await serverProtocol.connect(serverTransport); + + // Set up a handler that uses sendRequest and sendNotification + serverProtocol.setRequestHandler(CallToolRequestSchema, async (_request, extra) => { + // Send a notification using the extra.sendNotification + await extra.sendNotification({ + method: 'notifications/message', + params: { level: 'info', data: 'test' } + }); + + return { + content: [{ type: 'text', text: 'done' }] + }; + }); + + // Send a request with related-task metadata + let handlerPromise: Promise | undefined; + const originalOnMessage = serverTransport.onmessage; + + serverTransport.onmessage = message => { + handlerPromise = Promise.resolve(originalOnMessage?.(message)); + return handlerPromise; + }; + + serverTransport.onmessage({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { + name: 'test-tool', + _meta: { + [RELATED_TASK_META_KEY]: { + taskId: 'parent-task-123' + } + } + } + }); + + // Wait for handler to complete + if (handlerPromise) { + await handlerPromise; + } + await new Promise(resolve => setTimeout(resolve, 100)); + + // Verify the notification was QUEUED (not sent via transport) + // Messages with relatedTask metadata should be queued for delivery via tasks/result + // to prevent duplicate delivery for bidirectional transports + const queue = (serverProtocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + const queuedMessage = await queue!.dequeue('parent-task-123'); + assertQueuedNotification(queuedMessage); + expect(queuedMessage.message.method).toBe('notifications/message'); + expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ + taskId: 'parent-task-123' + }); + + // Verify the notification was NOT sent via transport (should be queued instead) + const notificationCalls = sendSpy.mock.calls.filter(call => 'method' in call[0] && call[0].method === 'notifications/message'); + expect(notificationCalls).toHaveLength(0); + }); + }); +}); + +describe('Request Cancellation vs Task Cancellation', () => { + let protocol: Protocol; + let transport: MockTransport; + let taskStore: TaskStore; + + beforeEach(() => { + transport = new MockTransport(); + taskStore = createMockTaskStore(); + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); + }); + + describe('notifications/cancelled behavior', () => { + test('should abort request handler when notifications/cancelled is received', async () => { + await protocol.connect(transport); + + // Set up a request handler that checks if it was aborted + let wasAborted = false; + const TestRequestSchema = z.object({ + method: z.literal('test/longRunning'), + params: z.optional(z.record(z.unknown())) + }); + protocol.setRequestHandler(TestRequestSchema, async (_request, extra) => { + // Simulate a long-running operation + await new Promise(resolve => setTimeout(resolve, 100)); + wasAborted = extra.signal.aborted; + return { _meta: {} } as Result; + }); + + // Simulate an incoming request + const requestId = 123; + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: requestId, + method: 'test/longRunning', + params: {} + }); + } + + // Wait a bit for the handler to start + await new Promise(resolve => setTimeout(resolve, 10)); + + // Send cancellation notification + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/cancelled', + params: { + requestId: requestId, + reason: 'User cancelled' + } + }); + } + + // Wait for the handler to complete + await new Promise(resolve => setTimeout(resolve, 150)); + + // Verify the request was aborted + expect(wasAborted).toBe(true); + }); + + test('should NOT automatically cancel associated tasks when notifications/cancelled is received', async () => { + await protocol.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + method: 'test/method', + params: {} + }); + + // Send cancellation notification for the request + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/cancelled', + params: { + requestId: 'req-1', + reason: 'User cancelled' + } + }); + } + + // Wait a bit + await new Promise(resolve => setTimeout(resolve, 10)); + + // Verify the task status was NOT changed to cancelled + const updatedTask = await taskStore.getTask(task.taskId); + expect(updatedTask?.status).toBe('working'); + expect(taskStore.updateTaskStatus).not.toHaveBeenCalledWith(task.taskId, 'cancelled', expect.any(String)); + }); + }); + + describe('tasks/cancel behavior', () => { + test('should cancel task independently of request cancellation', async () => { + await protocol.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + method: 'test/method', + params: {} + }); + + // Cancel the task using tasks/cancel + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: 999, + method: 'tasks/cancel', + params: { + taskId: task.taskId + } + }); + } + + // Wait for the handler to complete + await new Promise(resolve => setTimeout(resolve, 10)); + + // Verify the task was cancelled + expect(taskStore.updateTaskStatus).toHaveBeenCalledWith( + task.taskId, + 'cancelled', + 'Client cancelled task execution.', + undefined + ); + }); + + test('should reject cancellation of terminal tasks', async () => { + await protocol.connect(transport); + const sendSpy = vi.spyOn(transport, 'send'); + + // Create a task and mark it as completed + const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + method: 'test/method', + params: {} + }); + await taskStore.updateTaskStatus(task.taskId, 'completed'); + + // Try to cancel the completed task + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: 999, + method: 'tasks/cancel', + params: { + taskId: task.taskId + } + }); + } + + // Wait for the handler to complete + await new Promise(resolve => setTimeout(resolve, 10)); + + // Verify an error was sent + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + jsonrpc: '2.0', + id: 999, + error: expect.objectContaining({ + code: ErrorCode.InvalidParams, + message: expect.stringContaining('Cannot cancel task in terminal status') + }) + }) + ); + }); + + test('should return error when task not found', async () => { + await protocol.connect(transport); + const sendSpy = vi.spyOn(transport, 'send'); + + // Try to cancel a non-existent task + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: 999, + method: 'tasks/cancel', + params: { + taskId: 'non-existent-task' + } + }); + } + + // Wait for the handler to complete + await new Promise(resolve => setTimeout(resolve, 10)); + + // Verify an error was sent + expect(sendSpy).toHaveBeenCalledWith( + expect.objectContaining({ + jsonrpc: '2.0', + id: 999, + error: expect.objectContaining({ + code: ErrorCode.InvalidParams, + message: expect.stringContaining('Task not found') + }) + }) + ); + }); + }); + + describe('separation of concerns', () => { + test('should allow request cancellation without affecting task', async () => { + await protocol.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + method: 'test/method', + params: {} + }); + + // Cancel the request (not the task) + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/cancelled', + params: { + requestId: 'req-1', + reason: 'User cancelled request' + } + }); + } + + await new Promise(resolve => setTimeout(resolve, 10)); + + // Verify task is still working + const updatedTask = await taskStore.getTask(task.taskId); + expect(updatedTask?.status).toBe('working'); + }); + + test('should allow task cancellation without affecting request', async () => { + await protocol.connect(transport); + + // Set up a request handler + let requestCompleted = false; + const TestMethodSchema = z.object({ + method: z.literal('test/method'), + params: z.optional(z.record(z.unknown())) + }); + protocol.setRequestHandler(TestMethodSchema, async () => { + await new Promise(resolve => setTimeout(resolve, 50)); + requestCompleted = true; + return { _meta: {} } as Result; + }); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { + method: 'test/method', + params: {} + }); + + // Start a request + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: 123, + method: 'test/method', + params: {} + }); + } + + // Cancel the task (not the request) + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: 999, + method: 'tasks/cancel', + params: { + taskId: task.taskId + } + }); + } + + // Wait for request to complete + await new Promise(resolve => setTimeout(resolve, 100)); + + // Verify request completed normally + expect(requestCompleted).toBe(true); + + // Verify task was cancelled + expect(taskStore.updateTaskStatus).toHaveBeenCalledWith( + task.taskId, + 'cancelled', + 'Client cancelled task execution.', + undefined + ); + }); + }); +}); + +describe('Progress notification support for tasks', () => { + let protocol: Protocol; + let transport: MockTransport; + let sendSpy: MockInstance; + + beforeEach(() => { + transport = new MockTransport(); + sendSpy = vi.spyOn(transport, 'send'); + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + }); + + it('should maintain progress token association after CreateTaskResult is returned', async () => { + const taskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); + + const transport = new MockTransport(); + const sendSpy = vi.spyOn(transport, 'send'); + await protocol.connect(transport); + + const progressCallback = vi.fn(); + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) + }); + + // Start a task-augmented request with progress callback + void protocol + .request(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }) + .catch(() => { + // May not complete, ignore error + }); + + // Wait a bit for the request to be sent + await new Promise(resolve => setTimeout(resolve, 10)); + + // Get the message ID from the sent request + const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; + const messageId = sentRequest.id; + const progressToken = sentRequest.params._meta.progressToken; + + expect(progressToken).toBe(messageId); + + // Simulate CreateTaskResult response + const taskId = 'test-task-123'; + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: messageId, + result: { + task: { + taskId, + status: 'working', + ttl: 60000, + createdAt: new Date().toISOString() + } + } + }); + } + + // Wait for response to be processed + await Promise.resolve(); + await Promise.resolve(); + + // Send a progress notification - should still work after CreateTaskResult + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 50, + total: 100 + } + }); + } + + // Wait for notification to be processed + await Promise.resolve(); + + // Verify progress callback was invoked + expect(progressCallback).toHaveBeenCalledWith({ + progress: 50, + total: 100 + }); + }); + + it('should stop progress notifications when task reaches terminal status (completed)', async () => { + const taskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); + + const transport = new MockTransport(); + const sendSpy = vi.spyOn(transport, 'send'); + await protocol.connect(transport); + + // Set up a request handler that will complete the task + protocol.setRequestHandler(CallToolRequestSchema, async (request, extra) => { + if (extra.taskStore) { + const task = await extra.taskStore.createTask({ ttl: 60000 }); + + // Simulate async work then complete the task + setTimeout(async () => { + await extra.taskStore!.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: 'Done' }] + }); + }, 50); + + return { task }; + } + return { content: [] }; + }); + + const progressCallback = vi.fn(); + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) + }); + + // Start a task-augmented request with progress callback + void protocol + .request(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }) + .catch(() => { + // May not complete, ignore error + }); + + // Wait a bit for the request to be sent + await new Promise(resolve => setTimeout(resolve, 10)); + + const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; + const messageId = sentRequest.id; + const progressToken = sentRequest.params._meta.progressToken; + + // Create a task in the mock store first so it exists when we try to get it later + const createdTask = await taskStore.createTask({ ttl: 60000 }, messageId, request); + const taskId = createdTask.taskId; + + // Simulate CreateTaskResult response + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: messageId, + result: { + task: createdTask + } + }); + } + + await Promise.resolve(); + await Promise.resolve(); + + // Progress notification should work while task is working + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 50, + total: 100 + } + }); + } + + await Promise.resolve(); + + expect(progressCallback).toHaveBeenCalledTimes(1); + + // Verify the task-progress association was created + const taskProgressTokens = (protocol as unknown as TestProtocol)._taskProgressTokens as Map; + expect(taskProgressTokens.has(taskId)).toBe(true); + expect(taskProgressTokens.get(taskId)).toBe(progressToken); + + // Simulate task completion by calling through the protocol's task store + // This will trigger the cleanup logic + const mockRequest = { jsonrpc: '2.0' as const, id: 999, method: 'test', params: {} }; + const requestTaskStore = (protocol as unknown as TestProtocol).requestTaskStore(mockRequest, undefined); + await requestTaskStore.storeTaskResult(taskId, 'completed', { content: [] }); + + // Wait for all async operations including notification sending to complete + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify the association was cleaned up + expect(taskProgressTokens.has(taskId)).toBe(false); + + // Try to send progress notification after task completion - should be ignored + progressCallback.mockClear(); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 100, + total: 100 + } + }); + } + + await Promise.resolve(); + + // Progress callback should NOT be invoked after task completion + expect(progressCallback).not.toHaveBeenCalled(); + }); + + it('should stop progress notifications when task reaches terminal status (failed)', async () => { + const taskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); + + const transport = new MockTransport(); + const sendSpy = vi.spyOn(transport, 'send'); + await protocol.connect(transport); + + const progressCallback = vi.fn(); + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) + }); + + void protocol.request(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }); + + const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; + const messageId = sentRequest.id; + const progressToken = sentRequest.params._meta.progressToken; + + // Simulate CreateTaskResult response + const taskId = 'test-task-456'; + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: messageId, + result: { + task: { + taskId, + status: 'working', + ttl: 60000, + createdAt: new Date().toISOString() + } + } + }); + } + + await new Promise(resolve => setTimeout(resolve, 10)); + + // Simulate task failure via storeTaskResult + await taskStore.storeTaskResult(taskId, 'failed', { + content: [], + isError: true + }); + + // Manually trigger the status notification + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/tasks/status', + params: { + task: { + taskId, + status: 'failed', + ttl: 60000, + createdAt: new Date().toISOString(), + statusMessage: 'Task failed' + } + } + }); + } + + await new Promise(resolve => setTimeout(resolve, 10)); + + // Try to send progress notification after task failure - should be ignored + progressCallback.mockClear(); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 75, + total: 100 + } + }); + } + + expect(progressCallback).not.toHaveBeenCalled(); + }); + + it('should stop progress notifications when task is cancelled', async () => { + const taskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); + + const transport = new MockTransport(); + const sendSpy = vi.spyOn(transport, 'send'); + await protocol.connect(transport); + + const progressCallback = vi.fn(); + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) + }); + + void protocol.request(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }); + + const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; + const messageId = sentRequest.id; + const progressToken = sentRequest.params._meta.progressToken; + + // Simulate CreateTaskResult response + const taskId = 'test-task-789'; + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: messageId, + result: { + task: { + taskId, + status: 'working', + ttl: 60000, + createdAt: new Date().toISOString() + } + } + }); + } + + await new Promise(resolve => setTimeout(resolve, 10)); + + // Simulate task cancellation via updateTaskStatus + await taskStore.updateTaskStatus(taskId, 'cancelled', 'User cancelled'); + + // Manually trigger the status notification + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/tasks/status', + params: { + task: { + taskId, + status: 'cancelled', + ttl: 60000, + createdAt: new Date().toISOString(), + statusMessage: 'User cancelled' + } + } + }); + } + + await new Promise(resolve => setTimeout(resolve, 10)); + + // Try to send progress notification after cancellation - should be ignored + progressCallback.mockClear(); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 25, + total: 100 + } + }); + } + + expect(progressCallback).not.toHaveBeenCalled(); + }); + + it('should use the same progressToken throughout task lifetime', async () => { + const taskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore }); + + const transport = new MockTransport(); + const sendSpy = vi.spyOn(transport, 'send'); + await protocol.connect(transport); + + const progressCallback = vi.fn(); + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) + }); + + void protocol.request(request, resultSchema, { + task: { ttl: 60000 }, + onprogress: progressCallback + }); + + const sentRequest = sendSpy.mock.calls[0][0] as { id: number; params: { _meta: { progressToken: number } } }; + const messageId = sentRequest.id; + const progressToken = sentRequest.params._meta.progressToken; + + // Simulate CreateTaskResult response + const taskId = 'test-task-consistency'; + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + id: messageId, + result: { + task: { + taskId, + status: 'working', + ttl: 60000, + createdAt: new Date().toISOString() + } + } + }); + } + + await Promise.resolve(); + await Promise.resolve(); + + // Send multiple progress notifications with the same token + const progressUpdates = [ + { progress: 25, total: 100 }, + { progress: 50, total: 100 }, + { progress: 75, total: 100 } + ]; + + for (const update of progressUpdates) { + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, // Same token for all notifications + ...update + } + }); + } + await Promise.resolve(); + } + + // Verify all progress notifications were received with the same token + expect(progressCallback).toHaveBeenCalledTimes(3); + expect(progressCallback).toHaveBeenNthCalledWith(1, { progress: 25, total: 100 }); + expect(progressCallback).toHaveBeenNthCalledWith(2, { progress: 50, total: 100 }); + expect(progressCallback).toHaveBeenNthCalledWith(3, { progress: 75, total: 100 }); + }); + + it('should maintain progressToken throughout task lifetime', async () => { + await protocol.connect(transport); + + const request = { + method: 'tools/call', + params: { name: 'long-running-tool' } + }; + + const resultSchema = z.object({ + content: z.array(z.object({ type: z.literal('text'), text: z.string() })) + }); + + const onProgressMock = vi.fn(); + + void protocol.request(request, resultSchema, { + task: { + ttl: 60000 + }, + onprogress: onProgressMock + }); + + const sentMessage = sendSpy.mock.calls[0][0]; + expect(sentMessage.params._meta.progressToken).toBeDefined(); + }); + + it('should support progress notifications with task-augmented requests', async () => { + await protocol.connect(transport); + + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + content: z.array(z.object({ type: z.literal('text'), text: z.string() })) + }); + + const onProgressMock = vi.fn(); + + void protocol.request(request, resultSchema, { + task: { + ttl: 30000 + }, + onprogress: onProgressMock + }); + + const sentMessage = sendSpy.mock.calls[0][0]; + const progressToken = sentMessage.params._meta.progressToken; + + // Simulate progress notification + transport.onmessage?.({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 50, + total: 100, + message: 'Processing...' + } + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 50, + total: 100, + message: 'Processing...' + }); + }); + + it('should continue progress notifications after CreateTaskResult', async () => { + await protocol.connect(transport); + + const request = { + method: 'tools/call', + params: { name: 'test-tool' } + }; + + const resultSchema = z.object({ + task: z.object({ + taskId: z.string(), + status: z.string(), + ttl: z.number().nullable(), + createdAt: z.string() + }) + }); + + const onProgressMock = vi.fn(); + + void protocol.request(request, resultSchema, { + task: { + ttl: 30000 + }, + onprogress: onProgressMock + }); + + const sentMessage = sendSpy.mock.calls[0][0]; + const progressToken = sentMessage.params._meta.progressToken; + + // Simulate CreateTaskResult response + setTimeout(() => { + transport.onmessage?.({ + jsonrpc: '2.0', + id: sentMessage.id, + result: { + task: { + taskId: 'task-123', + status: 'working', + ttl: 30000, + createdAt: new Date().toISOString() + } + } + }); + }, 5); + + // Progress notifications should still work + setTimeout(() => { + transport.onmessage?.({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken, + progress: 75, + total: 100 + } + }); + }, 10); + + await new Promise(resolve => setTimeout(resolve, 20)); + + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 75, + total: 100 + }); + }); +}); + +describe('Capability negotiation for tasks', () => { + it('should use empty objects for capability fields', () => { + const serverCapabilities = { + tasks: { + list: {}, + cancel: {}, + requests: { + tools: { + call: {} + } + } + } + }; + + expect(serverCapabilities.tasks.list).toEqual({}); + expect(serverCapabilities.tasks.cancel).toEqual({}); + expect(serverCapabilities.tasks.requests.tools.call).toEqual({}); + }); + + it('should include list and cancel in server capabilities', () => { + const serverCapabilities = { + tasks: { + list: {}, + cancel: {} + } + }; + + expect('list' in serverCapabilities.tasks).toBe(true); + expect('cancel' in serverCapabilities.tasks).toBe(true); + }); + + it('should include list and cancel in client capabilities', () => { + const clientCapabilities = { + tasks: { + list: {}, + cancel: {} + } + }; + + expect('list' in clientCapabilities.tasks).toBe(true); + expect('cancel' in clientCapabilities.tasks).toBe(true); + }); +}); + +describe('Message interception for task-related notifications', () => { + it('should queue notifications with io.modelcontextprotocol/related-task metadata', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + + await server.connect(transport); + + // Create a task first + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Send a notification with related task metadata + await server.notification( + { + method: 'notifications/message', + params: { level: 'info', data: 'test message' } + }, + { + relatedTask: { taskId: task.taskId } + } + ); + + // Access the private queue to verify the message was queued + const queue = (server as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + const queuedMessage = await queue!.dequeue(task.taskId); + assertQueuedNotification(queuedMessage); + expect(queuedMessage.message.method).toBe('notifications/message'); + expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: task.taskId }); + }); + + it('should not queue notifications without related-task metadata', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + + await server.connect(transport); + + // Send a notification without related task metadata + await server.notification({ + method: 'notifications/message', + params: { level: 'info', data: 'test message' } + }); + + // Verify message was not queued (notification without metadata goes through transport) + // We can't directly check the queue, but we know it wasn't queued because + // notifications without relatedTask metadata are sent via transport, not queued + }); + + // Test removed: _taskResultWaiters was removed in favor of polling-based task updates + // The functionality is still tested through integration tests that verify message queuing works + + it('should propagate queue overflow errors without failing the task', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 }); + + await server.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Fill the queue to max capacity (100 messages) + for (let i = 0; i < 100; i++) { + await server.notification( + { + method: 'notifications/message', + params: { level: 'info', data: `message ${i}` } + }, + { + relatedTask: { taskId: task.taskId } + } + ); + } + + // Try to add one more message - should throw an error + await expect( + server.notification( + { + method: 'notifications/message', + params: { level: 'info', data: 'overflow message' } + }, + { + relatedTask: { taskId: task.taskId } + } + ) + ).rejects.toThrow('overflow'); + + // Verify the task was NOT automatically failed by the Protocol + // (implementations can choose to fail tasks on overflow if they want) + expect(taskStore.updateTaskStatus).not.toHaveBeenCalledWith(task.taskId, 'failed', expect.anything(), expect.anything()); + }); + + it('should extract task ID correctly from metadata', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + + await server.connect(transport); + + const taskId = 'custom-task-id-123'; + + // Send a notification with custom task ID + await server.notification( + { + method: 'notifications/message', + params: { level: 'info', data: 'test message' } + }, + { + relatedTask: { taskId } + } + ); + + // Verify the message was queued under the correct task ID + const queue = (server as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + const queuedMessage = await queue!.dequeue(taskId); + expect(queuedMessage).toBeDefined(); + }); + + it('should preserve message order when queuing multiple notifications', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + + await server.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Send multiple notifications + for (let i = 0; i < 5; i++) { + await server.notification( + { + method: 'notifications/message', + params: { level: 'info', data: `message ${i}` } + }, + { + relatedTask: { taskId: task.taskId } + } + ); + } + + // Verify messages are in FIFO order + const queue = (server as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + for (let i = 0; i < 5; i++) { + const queuedMessage = await queue!.dequeue(task.taskId); + assertQueuedNotification(queuedMessage); + expect(queuedMessage.message.params!.data).toBe(`message ${i}`); + } + }); +}); + +describe('Message interception for task-related requests', () => { + it('should queue requests with io.modelcontextprotocol/related-task metadata', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + + await server.connect(transport); + + // Create a task first + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Send a request with related task metadata (don't await - we're testing queuing) + const requestPromise = server.request( + { + method: 'ping', + params: {} + }, + z.object({}), + { + relatedTask: { taskId: task.taskId } + } + ); + + // Access the private queue to verify the message was queued + const queue = (server as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + const queuedMessage = await queue!.dequeue(task.taskId); + assertQueuedRequest(queuedMessage); + expect(queuedMessage.message.method).toBe('ping'); + expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: task.taskId }); + + // Verify resolver is stored in _requestResolvers map (not in the message) + const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; + const resolvers = (server as unknown as TestProtocol)._requestResolvers; + expect(resolvers.has(requestId)).toBe(true); + + // Clean up - send a response to prevent hanging promise + transport.onmessage?.({ + jsonrpc: '2.0', + id: requestId, + result: {} + }); + + await requestPromise; + }); + + it('should not queue requests without related-task metadata', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + + await server.connect(transport); + + // Send a request without related task metadata + const requestPromise = server.request( + { + method: 'ping', + params: {} + }, + z.object({}) + ); + + // Verify queue exists (but we don't track size in the new API) + const queue = (server as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + // Clean up - send a response + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + result: {} + }); + + await requestPromise; + }); + + // Test removed: _taskResultWaiters was removed in favor of polling-based task updates + // The functionality is still tested through integration tests that verify message queuing works + + it('should store request resolver for response routing', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + + await server.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Send a request with related task metadata + const requestPromise = server.request( + { + method: 'ping', + params: {} + }, + z.object({}), + { + relatedTask: { taskId: task.taskId } + } + ); + + // Verify the resolver was stored + const resolvers = (server as unknown as TestProtocol)._requestResolvers; + expect(resolvers.size).toBe(1); + + // Get the request ID from the queue + const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queuedMessage = await queue!.dequeue(task.taskId); + const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; + + expect(resolvers.has(requestId)).toBe(true); + + // Send a response to trigger resolver + transport.onmessage?.({ + jsonrpc: '2.0', + id: requestId, + result: {} + }); + + await requestPromise; + + // Verify resolver was cleaned up after response + expect(resolvers.has(requestId)).toBe(false); + }); + + it('should route responses to side-channeled requests', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const queue = new InMemoryTaskMessageQueue(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore, taskMessageQueue: queue }); + + await server.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Send a request with related task metadata + const requestPromise = server.request( + { + method: 'ping', + params: {} + }, + z.object({ message: z.string() }), + { + relatedTask: { taskId: task.taskId } + } + ); + + // Get the request ID from the queue + const queuedMessage = await queue.dequeue(task.taskId); + const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; + + // Enqueue a response message to the queue (simulating client sending response back) + await queue.enqueue(task.taskId, { + type: 'response', + message: { + jsonrpc: '2.0', + id: requestId, + result: { message: 'pong' } + }, + timestamp: Date.now() + }); + + // Simulate a client calling tasks/result which will process the response + // This is done by creating a mock request handler that will trigger the GetTaskPayloadRequest handler + const mockRequestId = 999; + transport.onmessage?.({ + jsonrpc: '2.0', + id: mockRequestId, + method: 'tasks/result', + params: { taskId: task.taskId } + }); + + // Wait for the response to be processed + await new Promise(resolve => setTimeout(resolve, 50)); + + // Mark task as completed + await taskStore.updateTaskStatus(task.taskId, 'completed'); + await taskStore.storeTaskResult(task.taskId, 'completed', { _meta: {} }); + + // Verify the response was routed correctly + const result = await requestPromise; + expect(result).toEqual({ message: 'pong' }); + }); + + it('should log error when resolver is missing for side-channeled request', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + + const errors: Error[] = []; + server.onerror = (error: Error) => { + errors.push(error); + }; + + await server.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Send a request with related task metadata + void server.request( + { + method: 'ping', + params: {} + }, + z.object({ message: z.string() }), + { + relatedTask: { taskId: task.taskId } + } + ); + + // Get the request ID from the queue + const queue = (server as unknown as TestProtocol)._taskMessageQueue; + const queuedMessage = await queue!.dequeue(task.taskId); + const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; + + // Manually delete the resolver to simulate missing resolver + (server as unknown as TestProtocol)._requestResolvers.delete(requestId); + + // Enqueue a response message - this should trigger the error logging when processed + await queue!.enqueue(task.taskId, { + type: 'response', + message: { + jsonrpc: '2.0', + id: requestId, + result: { message: 'pong' } + }, + timestamp: Date.now() + }); + + // Simulate a client calling tasks/result which will process the response + const mockRequestId = 888; + transport.onmessage?.({ + jsonrpc: '2.0', + id: mockRequestId, + method: 'tasks/result', + params: { taskId: task.taskId } + }); + + // Wait for the response to be processed + await new Promise(resolve => setTimeout(resolve, 50)); + + // Mark task as completed + await taskStore.updateTaskStatus(task.taskId, 'completed'); + await taskStore.storeTaskResult(task.taskId, 'completed', { _meta: {} }); + + // Wait a bit more for error to be logged + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify error was logged + expect(errors.length).toBeGreaterThanOrEqual(1); + expect(errors.some(e => e.message.includes('Response handler missing for request'))).toBe(true); + }); + + it('should propagate queue overflow errors for requests without failing the task', async () => { + const taskStore = createMockTaskStore(); + const transport = new MockTransport(); + const server = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 }); + + await server.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); + + // Fill the queue to max capacity (100 messages) + const promises: Promise[] = []; + for (let i = 0; i < 100; i++) { + const promise = server + .request( + { + method: 'ping', + params: {} + }, + z.object({}), + { + relatedTask: { taskId: task.taskId } + } + ) + .catch(() => { + // Requests will remain pending until task completes or fails + }); + promises.push(promise); + } + + // Try to add one more request - should throw an error + await expect( + server.request( + { + method: 'ping', + params: {} + }, + z.object({}), + { + relatedTask: { taskId: task.taskId } + } + ) + ).rejects.toThrow('overflow'); + + // Verify the task was NOT automatically failed by the Protocol + // (implementations can choose to fail tasks on overflow if they want) + expect(taskStore.updateTaskStatus).not.toHaveBeenCalledWith(task.taskId, 'failed', expect.anything(), expect.anything()); + }); +}); + +describe('Message Interception', () => { + let protocol: Protocol; + let transport: MockTransport; + let mockTaskStore: TaskStore & { [K in keyof TaskStore]: MockInstance }; + + beforeEach(() => { + transport = new MockTransport(); + mockTaskStore = createMockTaskStore(); + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + }); + + describe('messages with relatedTask metadata are queued', () => { + it('should queue notifications with relatedTask metadata', async () => { + await protocol.connect(transport); + + // Send a notification with relatedTask metadata + await protocol.notification( + { + method: 'notifications/message', + params: { level: 'info', data: 'test message' } + }, + { + relatedTask: { + taskId: 'task-123' + } + } + ); + + // Access the private _taskMessageQueue to verify the message was queued + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + const queuedMessage = await queue!.dequeue('task-123'); + assertQueuedNotification(queuedMessage); + expect(queuedMessage!.message.method).toBe('notifications/message'); + }); + + it('should queue requests with relatedTask metadata', async () => { + await protocol.connect(transport); + + const mockSchema = z.object({ result: z.string() }); + + // Send a request with relatedTask metadata + const requestPromise = protocol.request( + { + method: 'test/request', + params: { data: 'test' } + }, + mockSchema, + { + relatedTask: { + taskId: 'task-456' + } + } + ); + + // Access the private _taskMessageQueue to verify the message was queued + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + const queuedMessage = await queue!.dequeue('task-456'); + assertQueuedRequest(queuedMessage); + expect(queuedMessage.message.method).toBe('test/request'); + + // Verify resolver is stored in _requestResolvers map (not in the message) + const requestId = queuedMessage.message.id as RequestId; + const resolvers = (protocol as unknown as TestProtocol)._requestResolvers; + expect(resolvers.has(requestId)).toBe(true); + + // Clean up the pending request + transport.onmessage?.({ + jsonrpc: '2.0', + id: requestId, + result: { result: 'success' } + }); + await requestPromise; + }); + }); + + describe('messages without metadata bypass the queue', () => { + it('should not queue notifications without relatedTask metadata', async () => { + await protocol.connect(transport); + + // Send a notification without relatedTask metadata + await protocol.notification({ + method: 'notifications/message', + params: { level: 'info', data: 'test message' } + }); + + // Access the private _taskMessageQueue to verify no messages were queued + // Since we can't check if queues exist without messages, we verify that + // attempting to dequeue returns undefined (no messages queued) + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + }); + + it('should not queue requests without relatedTask metadata', async () => { + await protocol.connect(transport); + + const mockSchema = z.object({ result: z.string() }); + const sendSpy = vi.spyOn(transport, 'send'); + + // Send a request without relatedTask metadata + const requestPromise = protocol.request( + { + method: 'test/request', + params: { data: 'test' } + }, + mockSchema + ); + + // Access the private _taskMessageQueue to verify no messages were queued + // Since we can't check if queues exist without messages, we verify that + // attempting to dequeue returns undefined (no messages queued) + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + // Clean up the pending request + const requestId = (sendSpy.mock.calls[0][0] as JSONRPCResponse).id; + transport.onmessage?.({ + jsonrpc: '2.0', + id: requestId, + result: { result: 'success' } + }); + await requestPromise; + }); + }); + + describe('task ID extraction from metadata', () => { + it('should extract correct task ID from relatedTask metadata for notifications', async () => { + await protocol.connect(transport); + + const taskId = 'extracted-task-789'; + + // Send a notification with relatedTask metadata + await protocol.notification( + { + method: 'notifications/message', + params: { data: 'test' } + }, + { + relatedTask: { + taskId: taskId + } + } + ); + + // Verify the message was queued under the correct task ID + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + // Verify a message was queued for this task + const queuedMessage = await queue!.dequeue(taskId); + assertQueuedNotification(queuedMessage); + expect(queuedMessage.message.method).toBe('notifications/message'); + }); + + it('should extract correct task ID from relatedTask metadata for requests', async () => { + await protocol.connect(transport); + + const taskId = 'extracted-task-999'; + const mockSchema = z.object({ result: z.string() }); + + // Send a request with relatedTask metadata + const requestPromise = protocol.request( + { + method: 'test/request', + params: { data: 'test' } + }, + mockSchema, + { + relatedTask: { + taskId: taskId + } + } + ); + + // Verify the message was queued under the correct task ID + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + // Clean up the pending request + const queuedMessage = await queue!.dequeue(taskId); + assertQueuedRequest(queuedMessage); + expect(queuedMessage.message.method).toBe('test/request'); + transport.onmessage?.({ + jsonrpc: '2.0', + id: queuedMessage.message.id, + result: { result: 'success' } + }); + await requestPromise; + }); + + it('should handle multiple messages for different task IDs', async () => { + await protocol.connect(transport); + + // Send messages for different tasks + await protocol.notification({ method: 'test1', params: {} }, { relatedTask: { taskId: 'task-A' } }); + await protocol.notification({ method: 'test2', params: {} }, { relatedTask: { taskId: 'task-B' } }); + await protocol.notification({ method: 'test3', params: {} }, { relatedTask: { taskId: 'task-A' } }); + + // Verify messages are queued under correct task IDs + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + // Verify two messages for task-A + const msg1A = await queue!.dequeue('task-A'); + const msg2A = await queue!.dequeue('task-A'); + const msg3A = await queue!.dequeue('task-A'); // Should be undefined + expect(msg1A).toBeDefined(); + expect(msg2A).toBeDefined(); + expect(msg3A).toBeUndefined(); + + // Verify one message for task-B + const msg1B = await queue!.dequeue('task-B'); + const msg2B = await queue!.dequeue('task-B'); // Should be undefined + expect(msg1B).toBeDefined(); + expect(msg2B).toBeUndefined(); + }); + }); + + describe('queue creation on first message', () => { + it('should queue messages for a task', async () => { + await protocol.connect(transport); + + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + // Send first message for a task + await protocol.notification({ method: 'test', params: {} }, { relatedTask: { taskId: 'new-task' } }); + + // Verify message was queued + const msg = await queue!.dequeue('new-task'); + assertQueuedNotification(msg); + expect(msg.message.method).toBe('test'); + }); + + it('should queue multiple messages for the same task', async () => { + await protocol.connect(transport); + + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + // Send first message + await protocol.notification({ method: 'test1', params: {} }, { relatedTask: { taskId: 'reuse-task' } }); + + // Send second message + await protocol.notification({ method: 'test2', params: {} }, { relatedTask: { taskId: 'reuse-task' } }); + + // Verify both messages were queued in order + const msg1 = await queue!.dequeue('reuse-task'); + const msg2 = await queue!.dequeue('reuse-task'); + assertQueuedNotification(msg1); + expect(msg1.message.method).toBe('test1'); + assertQueuedNotification(msg2); + expect(msg2.message.method).toBe('test2'); + }); + + it('should queue messages for different tasks separately', async () => { + await protocol.connect(transport); + + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + // Send messages for different tasks + await protocol.notification({ method: 'test1', params: {} }, { relatedTask: { taskId: 'task-1' } }); + await protocol.notification({ method: 'test2', params: {} }, { relatedTask: { taskId: 'task-2' } }); + + // Verify messages are queued separately + const msg1 = await queue!.dequeue('task-1'); + const msg2 = await queue!.dequeue('task-2'); + assertQueuedNotification(msg1); + expect(msg1?.message.method).toBe('test1'); + assertQueuedNotification(msg2); + expect(msg2?.message.method).toBe('test2'); + }); + }); + + describe('metadata preservation in queued messages', () => { + it('should preserve relatedTask metadata in queued notification', async () => { + await protocol.connect(transport); + + const relatedTask = { taskId: 'task-meta-123' }; + + await protocol.notification( + { + method: 'test/notification', + params: { data: 'test' } + }, + { relatedTask } + ); + + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queuedMessage = await queue!.dequeue('task-meta-123'); + + // Verify the metadata is preserved in the queued message + expect(queuedMessage).toBeDefined(); + assertQueuedNotification(queuedMessage); + expect(queuedMessage.message.params!._meta).toBeDefined(); + expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual(relatedTask); + }); + + it('should preserve relatedTask metadata in queued request', async () => { + await protocol.connect(transport); + + const relatedTask = { taskId: 'task-meta-456' }; + const mockSchema = z.object({ result: z.string() }); + + const requestPromise = protocol.request( + { + method: 'test/request', + params: { data: 'test' } + }, + mockSchema, + { relatedTask } + ); + + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queuedMessage = await queue!.dequeue('task-meta-456'); + + // Verify the metadata is preserved in the queued message + expect(queuedMessage).toBeDefined(); + assertQueuedRequest(queuedMessage); + expect(queuedMessage.message.params!._meta).toBeDefined(); + expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual(relatedTask); + + // Clean up + transport.onmessage?.({ + jsonrpc: '2.0', + id: (queuedMessage!.message as JSONRPCRequest).id, + result: { result: 'success' } + }); + await requestPromise; + }); + + it('should preserve existing _meta fields when adding relatedTask', async () => { + await protocol.connect(transport); + + await protocol.notification( + { + method: 'test/notification', + params: { + data: 'test', + _meta: { + customField: 'customValue', + anotherField: 123 + } + } + }, + { + relatedTask: { taskId: 'task-preserve-meta' } + } + ); + + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const queuedMessage = await queue!.dequeue('task-preserve-meta'); + + // Verify both existing and new metadata are preserved + expect(queuedMessage).toBeDefined(); + assertQueuedNotification(queuedMessage); + expect(queuedMessage.message.params!._meta!.customField).toBe('customValue'); + expect(queuedMessage.message.params!._meta!.anotherField).toBe(123); + expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ + taskId: 'task-preserve-meta' + }); + }); + }); +}); + +describe('Queue lifecycle management', () => { + let protocol: Protocol; + let transport: MockTransport; + let mockTaskStore: TaskStore & { [K in keyof TaskStore]: MockInstance }; + + beforeEach(() => { + transport = new MockTransport(); + mockTaskStore = createMockTaskStore(); + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); + }); + + describe('queue cleanup on task completion', () => { + it('should clear queue when task reaches completed status', async () => { + await protocol.connect(transport); + + // Create a task + const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); + const taskId = task.taskId; + + // Queue some messages for the task + await protocol.notification({ method: 'test/notification', params: { data: 'test1' } }, { relatedTask: { taskId } }); + await protocol.notification({ method: 'test/notification', params: { data: 'test2' } }, { relatedTask: { taskId } }); + + // Verify messages are queued + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + // Verify messages can be dequeued + const msg1 = await queue!.dequeue(taskId); + const msg2 = await queue!.dequeue(taskId); + expect(msg1).toBeDefined(); + expect(msg2).toBeDefined(); + + // Directly call the cleanup method (simulating what happens when task reaches terminal status) + (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + + // After cleanup, no more messages should be available + const msg3 = await queue!.dequeue(taskId); + expect(msg3).toBeUndefined(); + }); + + it('should clear queue after delivering messages on tasks/result for completed task', async () => { + await protocol.connect(transport); + + // Create a task + const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); + const taskId = task.taskId; + + // Queue a message + await protocol.notification({ method: 'test/notification', params: { data: 'test' } }, { relatedTask: { taskId } }); + + // Mark task as completed + const completedTask = { ...task, status: 'completed' as const }; + mockTaskStore.getTask.mockResolvedValue(completedTask); + mockTaskStore.getTaskResult.mockResolvedValue({ content: [{ type: 'text', text: 'done' }] }); + + // Simulate tasks/result request + const resultPromise = new Promise(resolve => { + transport.onmessage?.({ + jsonrpc: '2.0', + id: 100, + method: 'tasks/result', + params: { taskId } + }); + setTimeout(resolve, 50); + }); + + await resultPromise; + + // Verify queue is cleared after delivery (no messages available) + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const msg = await queue!.dequeue(taskId); + expect(msg).toBeUndefined(); + }); + }); + + describe('queue cleanup on task cancellation', () => { + it('should clear queue when task is cancelled', async () => { + await protocol.connect(transport); + + // Create a task + const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); + const taskId = task.taskId; + + // Queue some messages + await protocol.notification({ method: 'test/notification', params: { data: 'test1' } }, { relatedTask: { taskId } }); + + // Verify message is queued + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + const msg1 = await queue!.dequeue(taskId); + expect(msg1).toBeDefined(); + + // Re-queue the message for cancellation test + await protocol.notification({ method: 'test/notification', params: { data: 'test1' } }, { relatedTask: { taskId } }); + + // Mock task as non-terminal + mockTaskStore.getTask.mockResolvedValue(task); + + // Cancel the task + transport.onmessage?.({ + jsonrpc: '2.0', + id: 200, + method: 'tasks/cancel', + params: { taskId } + }); + + // Wait for cancellation to process + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify queue is cleared (no messages available) + const msg2 = await queue!.dequeue(taskId); + expect(msg2).toBeUndefined(); + }); + + it('should reject pending request resolvers when task is cancelled', async () => { + await protocol.connect(transport); + + // Create a task + const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); + const taskId = task.taskId; + + // Queue a request (catch rejection to avoid unhandled promise rejection) + const requestPromise = protocol + .request({ method: 'test/request', params: { data: 'test' } }, z.object({ result: z.string() }), { + relatedTask: { taskId } + }) + .catch(err => err); + + // Verify request is queued + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + // Mock task as non-terminal + mockTaskStore.getTask.mockResolvedValue(task); + + // Cancel the task + transport.onmessage?.({ + jsonrpc: '2.0', + id: 201, + method: 'tasks/cancel', + params: { taskId } + }); + + // Wait for cancellation to process + await new Promise(resolve => setTimeout(resolve, 50)); + + // Verify the request promise is rejected + const result = await requestPromise; + expect(result).toBeInstanceOf(McpError); + expect(result.message).toContain('Task cancelled or completed'); + + // Verify queue is cleared (no messages available) + const msg = await queue!.dequeue(taskId); + expect(msg).toBeUndefined(); + }); + }); + + describe('queue cleanup on task failure', () => { + it('should clear queue when task reaches failed status', async () => { + await protocol.connect(transport); + + // Create a task + const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); + const taskId = task.taskId; + + // Queue some messages + await protocol.notification({ method: 'test/notification', params: { data: 'test1' } }, { relatedTask: { taskId } }); + await protocol.notification({ method: 'test/notification', params: { data: 'test2' } }, { relatedTask: { taskId } }); + + // Verify messages are queued + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + // Verify messages can be dequeued + const msg1 = await queue!.dequeue(taskId); + const msg2 = await queue!.dequeue(taskId); + expect(msg1).toBeDefined(); + expect(msg2).toBeDefined(); + + // Directly call the cleanup method (simulating what happens when task reaches terminal status) + (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + + // After cleanup, no more messages should be available + const msg3 = await queue!.dequeue(taskId); + expect(msg3).toBeUndefined(); + }); + + it('should reject pending request resolvers when task fails', async () => { + await protocol.connect(transport); + + // Create a task + const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); + const taskId = task.taskId; + + // Queue a request (catch the rejection to avoid unhandled promise rejection) + const requestPromise = protocol + .request({ method: 'test/request', params: { data: 'test' } }, z.object({ result: z.string() }), { + relatedTask: { taskId } + }) + .catch(err => err); + + // Verify request is queued + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + // Directly call the cleanup method (simulating what happens when task reaches terminal status) + (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + + // Verify the request promise is rejected + const result = await requestPromise; + expect(result).toBeInstanceOf(McpError); + expect(result.message).toContain('Task cancelled or completed'); + + // Verify queue is cleared (no messages available) + const msg = await queue!.dequeue(taskId); + expect(msg).toBeUndefined(); + }); + }); + + describe('resolver rejection on cleanup', () => { + it('should reject all pending request resolvers when queue is cleared', async () => { + await protocol.connect(transport); + + // Create a task + const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); + const taskId = task.taskId; + + // Queue multiple requests (catch rejections to avoid unhandled promise rejections) + const request1Promise = protocol + .request({ method: 'test/request1', params: { data: 'test1' } }, z.object({ result: z.string() }), { + relatedTask: { taskId } + }) + .catch(err => err); + + const request2Promise = protocol + .request({ method: 'test/request2', params: { data: 'test2' } }, z.object({ result: z.string() }), { + relatedTask: { taskId } + }) + .catch(err => err); + + const request3Promise = protocol + .request({ method: 'test/request3', params: { data: 'test3' } }, z.object({ result: z.string() }), { + relatedTask: { taskId } + }) + .catch(err => err); + + // Verify requests are queued + const queue = (protocol as unknown as TestProtocol)._taskMessageQueue; + expect(queue).toBeDefined(); + + // Directly call the cleanup method (simulating what happens when task reaches terminal status) + (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + + // Verify all request promises are rejected + const result1 = await request1Promise; + const result2 = await request2Promise; + const result3 = await request3Promise; + + expect(result1).toBeInstanceOf(McpError); + expect(result1.message).toContain('Task cancelled or completed'); + expect(result2).toBeInstanceOf(McpError); + expect(result2.message).toContain('Task cancelled or completed'); + expect(result3).toBeInstanceOf(McpError); + expect(result3.message).toContain('Task cancelled or completed'); + + // Verify queue is cleared (no messages available) + const msg = await queue!.dequeue(taskId); + expect(msg).toBeUndefined(); + }); + + it('should clean up resolver mappings when rejecting requests', async () => { + await protocol.connect(transport); + + // Create a task + const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); + const taskId = task.taskId; + + // Queue a request (catch rejection to avoid unhandled promise rejection) + const requestPromise = protocol + .request({ method: 'test/request', params: { data: 'test' } }, z.object({ result: z.string() }), { + relatedTask: { taskId } + }) + .catch(err => err); + + // Get the request ID that was sent + const requestResolvers = (protocol as unknown as TestProtocol)._requestResolvers; + const initialResolverCount = requestResolvers.size; + expect(initialResolverCount).toBeGreaterThan(0); + + // Complete the task (triggers cleanup) + const completedTask = { ...task, status: 'completed' as const }; + mockTaskStore.getTask.mockResolvedValue(completedTask); + + // Directly call the cleanup method (simulating what happens when task reaches terminal status) + (protocol as unknown as TestProtocol)._clearTaskQueue(taskId); + + // Verify request promise is rejected + const result = await requestPromise; + expect(result).toBeInstanceOf(McpError); + expect(result.message).toContain('Task cancelled or completed'); + + // Verify resolver mapping is cleaned up + // The resolver should be removed from the map + expect(requestResolvers.size).toBeLessThan(initialResolverCount); + }); + }); +}); + +describe('requestStream() method', () => { + const CallToolResultSchema = z.object({ + content: z.array(z.object({ type: z.string(), text: z.string() })), + _meta: z.object({}).optional() + }); + + test('should yield result immediately for non-task requests', async () => { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + // Start the request stream + const streamPromise = (async () => { + const messages = []; + const stream = protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema); + for await (const message of stream) { + messages.push(message); + } + return messages; + })(); + + // Simulate server response + await new Promise(resolve => setTimeout(resolve, 10)); + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + result: { + content: [{ type: 'text', text: 'test result' }], + _meta: {} + } + }); + + const messages = await streamPromise; + + // Should yield exactly one result message + expect(messages).toHaveLength(1); + expect(messages[0].type).toBe('result'); + expect(messages[0]).toHaveProperty('result'); + }); + + test('should yield error message on request failure', async () => { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + // Start the request stream + const streamPromise = (async () => { + const messages = []; + const stream = protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema); + for await (const message of stream) { + messages.push(message); + } + return messages; + })(); + + // Simulate server error response + await new Promise(resolve => setTimeout(resolve, 10)); + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + error: { + code: ErrorCode.InternalError, + message: 'Test error' + } + }); + + const messages = await streamPromise; + + // Should yield exactly one error message + expect(messages).toHaveLength(1); + expect(messages[0].type).toBe('error'); + expect(messages[0]).toHaveProperty('error'); + if (messages[0].type === 'error') { + expect(messages[0].error.message).toContain('Test error'); + } + }); + + test('should handle cancellation via AbortSignal', async () => { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + const abortController = new AbortController(); + + // Abort immediately before starting the stream + abortController.abort('User cancelled'); + + // Start the request stream with already-aborted signal + const messages = []; + const stream = protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema, { + signal: abortController.signal + }); + for await (const message of stream) { + messages.push(message); + } + + // Should yield error message about cancellation + expect(messages).toHaveLength(1); + expect(messages[0].type).toBe('error'); + if (messages[0].type === 'error') { + expect(messages[0].error.message).toContain('cancelled'); + } + }); + + describe('Error responses', () => { + test('should yield error as terminal message for server error response', async () => { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + const messagesPromise = toArrayAsync( + protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema) + ); + + // Simulate server error response + await new Promise(resolve => setTimeout(resolve, 10)); + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + error: { + code: ErrorCode.InternalError, + message: 'Server error' + } + }); + + // Collect messages + const messages = await messagesPromise; + + // Verify error is terminal and last message + expect(messages.length).toBeGreaterThan(0); + const lastMessage = messages[messages.length - 1]; + assertErrorResponse(lastMessage); + expect(lastMessage.error).toBeDefined(); + expect(lastMessage.error.message).toContain('Server error'); + }); + + test('should yield error as terminal message for timeout', async () => { + vi.useFakeTimers(); + try { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + const messagesPromise = toArrayAsync( + protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema, { + timeout: 100 + }) + ); + + // Advance time to trigger timeout + await vi.advanceTimersByTimeAsync(101); + + // Collect messages + const messages = await messagesPromise; + + // Verify error is terminal and last message + expect(messages.length).toBeGreaterThan(0); + const lastMessage = messages[messages.length - 1]; + assertErrorResponse(lastMessage); + expect(lastMessage.error).toBeDefined(); + expect(lastMessage.error.code).toBe(ErrorCode.RequestTimeout); + } finally { + vi.useRealTimers(); + } + }); + + test('should yield error as terminal message for cancellation', async () => { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + const abortController = new AbortController(); + abortController.abort('User cancelled'); + + // Collect messages + const messages = await toArrayAsync( + protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema, { + signal: abortController.signal + }) + ); + + // Verify error is terminal and last message + expect(messages.length).toBeGreaterThan(0); + const lastMessage = messages[messages.length - 1]; + assertErrorResponse(lastMessage); + expect(lastMessage.error).toBeDefined(); + expect(lastMessage.error.message).toContain('cancelled'); + }); + + test('should not yield any messages after error message', async () => { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + const messagesPromise = toArrayAsync( + protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema) + ); + + // Simulate server error response + await new Promise(resolve => setTimeout(resolve, 10)); + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + error: { + code: ErrorCode.InternalError, + message: 'Test error' + } + }); + + // Collect messages + const messages = await messagesPromise; + + // Verify only one message (the error) was yielded + expect(messages).toHaveLength(1); + expect(messages[0].type).toBe('error'); + + // Try to send another message (should be ignored) + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + result: { + content: [{ type: 'text', text: 'should not appear' }] + } + }); + + await new Promise(resolve => setTimeout(resolve, 10)); + + // Verify no additional messages were yielded + expect(messages).toHaveLength(1); + }); + + test('should yield error as terminal message for task failure', async () => { + const transport = new MockTransport(); + const mockTaskStore = createMockTaskStore(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })({ taskStore: mockTaskStore }); + await protocol.connect(transport); + + const messagesPromise = toArrayAsync( + protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema) + ); + + // Simulate task creation response + await new Promise(resolve => setTimeout(resolve, 10)); + const taskId = 'test-task-123'; + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + result: { + _meta: { + task: { + taskId, + status: 'working', + createdAt: new Date().toISOString(), + pollInterval: 100 + } + } + } + }); + + // Wait for task creation to be processed + await new Promise(resolve => setTimeout(resolve, 20)); + + // Update task to failed status + const failedTask = { + taskId, + status: 'failed' as const, + createdAt: new Date().toISOString(), + pollInterval: 100, + ttl: null, + statusMessage: 'Task failed' + }; + mockTaskStore.getTask.mockResolvedValue(failedTask); + + // Collect messages + const messages = await messagesPromise; + + // Verify error is terminal and last message + expect(messages.length).toBeGreaterThan(0); + const lastMessage = messages[messages.length - 1]; + assertErrorResponse(lastMessage); + expect(lastMessage.error).toBeDefined(); + }); + + test('should yield error as terminal message for network error', async () => { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + // Override send to simulate network error + transport.send = vi.fn().mockRejectedValue(new Error('Network error')); + + const messages = await toArrayAsync( + protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema) + ); + + // Verify error is terminal and last message + expect(messages.length).toBeGreaterThan(0); + const lastMessage = messages[messages.length - 1]; + assertErrorResponse(lastMessage); + expect(lastMessage.error).toBeDefined(); + }); + + test('should ensure error is always the final message', async () => { + const transport = new MockTransport(); + const protocol = new (class extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + })(); + await protocol.connect(transport); + + const messagesPromise = toArrayAsync( + protocol.requestStream({ method: 'tools/call', params: { name: 'test', arguments: {} } }, CallToolResultSchema) + ); + + // Simulate server error response + await new Promise(resolve => setTimeout(resolve, 10)); + transport.onmessage?.({ + jsonrpc: '2.0', + id: 0, + error: { + code: ErrorCode.InternalError, + message: 'Test error' + } + }); + + // Collect messages + const messages = await messagesPromise; + + // Verify error is the last message + expect(messages.length).toBeGreaterThan(0); + const lastMessage = messages[messages.length - 1]; + expect(lastMessage.type).toBe('error'); + + // Verify all messages before the last are not terminal + for (let i = 0; i < messages.length - 1; i++) { + expect(messages[i].type).not.toBe('error'); + expect(messages[i].type).not.toBe('result'); + } + }); + }); +}); + +describe('Error handling for missing resolvers', () => { + let protocol: Protocol; + let transport: MockTransport; + let taskStore: TaskStore & { [K in keyof TaskStore]: MockInstance }; + let taskMessageQueue: TaskMessageQueue; + let errorHandler: MockInstance; + + beforeEach(() => { + taskStore = createMockTaskStore(); + taskMessageQueue = new InMemoryTaskMessageQueue(); + errorHandler = vi.fn(); + + protocol = new (class extends Protocol { + protected assertCapabilityForMethod(_method: string): void {} + protected assertNotificationCapability(_method: string): void {} + protected assertRequestHandlerCapability(_method: string): void {} + protected assertTaskCapability(_method: string): void {} + protected assertTaskHandlerCapability(_method: string): void {} + })({ + taskStore, + taskMessageQueue, + defaultTaskPollInterval: 100 + }); + + // @ts-expect-error deliberately overriding error handler with mock + protocol.onerror = errorHandler; + transport = new MockTransport(); + }); + + describe('Response routing with missing resolvers', () => { + it('should log error for unknown request ID without throwing', async () => { + await protocol.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + // Enqueue a response message without a corresponding resolver + await taskMessageQueue.enqueue(task.taskId, { + type: 'response', + message: { + jsonrpc: '2.0', + id: 999, // Non-existent request ID + result: { content: [] } + }, + timestamp: Date.now() + }); + + // Set up the GetTaskPayloadRequest handler to process the message + const testProtocol = protocol as unknown as TestProtocol; + + // Simulate dequeuing and processing the response + const queuedMessage = await taskMessageQueue.dequeue(task.taskId); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage?.type).toBe('response'); + + // Manually trigger the response handling logic + if (queuedMessage && queuedMessage.type === 'response') { + const responseMessage = queuedMessage.message as JSONRPCResponse; + const requestId = responseMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(requestId); + + if (!resolver) { + // This simulates what happens in the actual handler + protocol.onerror?.(new Error(`Response handler missing for request ${requestId}`)); + } + } + + // Verify error was logged + expect(errorHandler).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('Response handler missing for request 999') + }) + ); + }); + + it('should continue processing after missing resolver error', async () => { + await protocol.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + // Enqueue a response with missing resolver, then a valid notification + await taskMessageQueue.enqueue(task.taskId, { + type: 'response', + message: { + jsonrpc: '2.0', + id: 999, + result: { content: [] } + }, + timestamp: Date.now() + }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'notification', + message: { + jsonrpc: '2.0', + method: 'notifications/progress', + params: { progress: 50, total: 100 } + }, + timestamp: Date.now() + }); + + // Process first message (response with missing resolver) + const msg1 = await taskMessageQueue.dequeue(task.taskId); + expect(msg1?.type).toBe('response'); + + // Process second message (should work fine) + const msg2 = await taskMessageQueue.dequeue(task.taskId); + expect(msg2?.type).toBe('notification'); + expect(msg2?.message).toMatchObject({ + method: 'notifications/progress' + }); + }); + }); + + describe('Task cancellation with missing resolvers', () => { + it('should log error when resolver is missing during cleanup', async () => { + await protocol.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + // Enqueue a request without storing a resolver + await taskMessageQueue.enqueue(task.taskId, { + type: 'request', + message: { + jsonrpc: '2.0', + id: 42, + method: 'tools/call', + params: { name: 'test-tool', arguments: {} } + }, + timestamp: Date.now() + }); + + // Clear the task queue (simulating cancellation) + const testProtocol = protocol as unknown as TestProtocol; + await testProtocol._clearTaskQueue(task.taskId); + + // Verify error was logged for missing resolver + expect(errorHandler).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('Resolver missing for request 42') + }) + ); + }); + + it('should handle cleanup gracefully when resolver exists', async () => { + await protocol.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + const requestId = 42; + const resolverMock = vi.fn(); + + // Store a resolver + const testProtocol = protocol as unknown as TestProtocol; + testProtocol._requestResolvers.set(requestId, resolverMock); + + // Enqueue a request + await taskMessageQueue.enqueue(task.taskId, { + type: 'request', + message: { + jsonrpc: '2.0', + id: requestId, + method: 'tools/call', + params: { name: 'test-tool', arguments: {} } + }, + timestamp: Date.now() + }); + + // Clear the task queue + await testProtocol._clearTaskQueue(task.taskId); + + // Verify resolver was called with cancellation error + expect(resolverMock).toHaveBeenCalledWith(expect.any(McpError)); + + // Verify the error has the correct properties + const calledError = resolverMock.mock.calls[0][0]; + expect(calledError.code).toBe(ErrorCode.InternalError); + expect(calledError.message).toContain('Task cancelled or completed'); + + // Verify resolver was removed + expect(testProtocol._requestResolvers.has(requestId)).toBe(false); + }); + + it('should handle mixed messages during cleanup', async () => { + await protocol.connect(transport); + + // Create a task + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + const testProtocol = protocol as unknown as TestProtocol; + + // Enqueue multiple messages: request with resolver, request without, notification + const requestId1 = 42; + const resolverMock = vi.fn(); + testProtocol._requestResolvers.set(requestId1, resolverMock); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'request', + message: { + jsonrpc: '2.0', + id: requestId1, + method: 'tools/call', + params: { name: 'test-tool', arguments: {} } + }, + timestamp: Date.now() + }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'request', + message: { + jsonrpc: '2.0', + id: 43, // No resolver for this one + method: 'tools/call', + params: { name: 'test-tool', arguments: {} } + }, + timestamp: Date.now() + }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'notification', + message: { + jsonrpc: '2.0', + method: 'notifications/progress', + params: { progress: 50, total: 100 } + }, + timestamp: Date.now() + }); + + // Clear the task queue + await testProtocol._clearTaskQueue(task.taskId); + + // Verify resolver was called for first request + expect(resolverMock).toHaveBeenCalledWith(expect.any(McpError)); + + // Verify the error has the correct properties + const calledError = resolverMock.mock.calls[0][0]; + expect(calledError.code).toBe(ErrorCode.InternalError); + expect(calledError.message).toContain('Task cancelled or completed'); + + // Verify error was logged for second request + expect(errorHandler).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('Resolver missing for request 43') + }) + ); + + // Verify queue is empty + const remaining = await taskMessageQueue.dequeue(task.taskId); + expect(remaining).toBeUndefined(); + }); + }); + + describe('Side-channeled request error handling', () => { + it('should log error when response handler is missing for side-channeled request', async () => { + await protocol.connect(transport); + + const testProtocol = protocol as unknown as TestProtocol; + const messageId = 123; + + // Create a response resolver without a corresponding response handler + const responseResolver = (response: JSONRPCResponse | Error) => { + const handler = testProtocol._responseHandlers.get(messageId); + if (handler) { + handler(response); + } else { + protocol.onerror?.(new Error(`Response handler missing for side-channeled request ${messageId}`)); + } + }; + + // Simulate the resolver being called without a handler + const mockResponse: JSONRPCResponse = { + jsonrpc: '2.0', + id: messageId, + result: { content: [] } + }; + + responseResolver(mockResponse); + + // Verify error was logged + expect(errorHandler).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('Response handler missing for side-channeled request 123') + }) + ); + }); + }); + + describe('Error handling does not throw exceptions', () => { + it('should not throw when processing response with missing resolver', async () => { + await protocol.connect(transport); + + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'response', + message: { + jsonrpc: '2.0', + id: 999, + result: { content: [] } + }, + timestamp: Date.now() + }); + + // This should not throw + const processMessage = async () => { + const msg = await taskMessageQueue.dequeue(task.taskId); + if (msg && msg.type === 'response') { + const testProtocol = protocol as unknown as TestProtocol; + const responseMessage = msg.message as JSONRPCResponse; + const requestId = responseMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(requestId); + if (!resolver) { + protocol.onerror?.(new Error(`Response handler missing for request ${requestId}`)); + } + } + }; + + await expect(processMessage()).resolves.not.toThrow(); + }); + + it('should not throw during task cleanup with missing resolvers', async () => { + await protocol.connect(transport); + + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'request', + message: { + jsonrpc: '2.0', + id: 42, + method: 'tools/call', + params: { name: 'test-tool', arguments: {} } + }, + timestamp: Date.now() + }); + + const testProtocol = protocol as unknown as TestProtocol; + + // This should not throw + await expect(testProtocol._clearTaskQueue(task.taskId)).resolves.not.toThrow(); + }); + }); + + describe('Error message routing', () => { + it('should route error messages to resolvers correctly', async () => { + await protocol.connect(transport); + + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const requestId = 42; + const resolverMock = vi.fn(); + + // Store a resolver + const testProtocol = protocol as unknown as TestProtocol; + testProtocol._requestResolvers.set(requestId, resolverMock); + + // Enqueue an error message + await taskMessageQueue.enqueue(task.taskId, { + type: 'error', + message: { + jsonrpc: '2.0', + id: requestId, + error: { + code: ErrorCode.InvalidRequest, + message: 'Invalid request parameters' + } + }, + timestamp: Date.now() + }); + + // Simulate dequeuing and processing the error + const queuedMessage = await taskMessageQueue.dequeue(task.taskId); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage?.type).toBe('error'); + + // Manually trigger the error handling logic + if (queuedMessage && queuedMessage.type === 'error') { + const errorMessage = queuedMessage.message as JSONRPCError; + const reqId = errorMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(reqId); + + if (resolver) { + testProtocol._requestResolvers.delete(reqId); + const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + resolver(error); + } + } + + // Verify resolver was called with McpError + expect(resolverMock).toHaveBeenCalledWith(expect.any(McpError)); + const calledError = resolverMock.mock.calls[0][0]; + expect(calledError.code).toBe(ErrorCode.InvalidRequest); + expect(calledError.message).toContain('Invalid request parameters'); + + // Verify resolver was removed from map + expect(testProtocol._requestResolvers.has(requestId)).toBe(false); + }); + + it('should log error for unknown request ID in error messages', async () => { + await protocol.connect(transport); + + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + // Enqueue an error message without a corresponding resolver + await taskMessageQueue.enqueue(task.taskId, { + type: 'error', + message: { + jsonrpc: '2.0', + id: 999, + error: { + code: ErrorCode.InternalError, + message: 'Something went wrong' + } + }, + timestamp: Date.now() + }); + + // Simulate dequeuing and processing the error + const queuedMessage = await taskMessageQueue.dequeue(task.taskId); + expect(queuedMessage).toBeDefined(); + expect(queuedMessage?.type).toBe('error'); + + // Manually trigger the error handling logic + if (queuedMessage && queuedMessage.type === 'error') { + const testProtocol = protocol as unknown as TestProtocol; + const errorMessage = queuedMessage.message as JSONRPCError; + const requestId = errorMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(requestId); + + if (!resolver) { + protocol.onerror?.(new Error(`Error handler missing for request ${requestId}`)); + } + } + + // Verify error was logged + expect(errorHandler).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('Error handler missing for request 999') + }) + ); + }); + + it('should handle error messages with data field', async () => { + await protocol.connect(transport); + + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const requestId = 42; + const resolverMock = vi.fn(); + + // Store a resolver + const testProtocol = protocol as unknown as TestProtocol; + testProtocol._requestResolvers.set(requestId, resolverMock); + + // Enqueue an error message with data field + await taskMessageQueue.enqueue(task.taskId, { + type: 'error', + message: { + jsonrpc: '2.0', + id: requestId, + error: { + code: ErrorCode.InvalidParams, + message: 'Validation failed', + data: { field: 'userName', reason: 'required' } + } + }, + timestamp: Date.now() + }); + + // Simulate dequeuing and processing the error + const queuedMessage = await taskMessageQueue.dequeue(task.taskId); + + if (queuedMessage && queuedMessage.type === 'error') { + const errorMessage = queuedMessage.message as JSONRPCError; + const reqId = errorMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(reqId); + + if (resolver) { + testProtocol._requestResolvers.delete(reqId); + const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + resolver(error); + } + } + + // Verify resolver was called with McpError including data + expect(resolverMock).toHaveBeenCalledWith(expect.any(McpError)); + const calledError = resolverMock.mock.calls[0][0]; + expect(calledError.code).toBe(ErrorCode.InvalidParams); + expect(calledError.message).toContain('Validation failed'); + expect(calledError.data).toEqual({ field: 'userName', reason: 'required' }); + }); + + it('should not throw when processing error with missing resolver', async () => { + await protocol.connect(transport); + + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'error', + message: { + jsonrpc: '2.0', + id: 999, + error: { + code: ErrorCode.InternalError, + message: 'Error occurred' + } + }, + timestamp: Date.now() + }); + + // This should not throw + const processMessage = async () => { + const msg = await taskMessageQueue.dequeue(task.taskId); + if (msg && msg.type === 'error') { + const testProtocol = protocol as unknown as TestProtocol; + const errorMessage = msg.message as JSONRPCError; + const requestId = errorMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(requestId); + if (!resolver) { + protocol.onerror?.(new Error(`Error handler missing for request ${requestId}`)); + } + } + }; + + await expect(processMessage()).resolves.not.toThrow(); + }); + }); + + describe('Response and error message routing integration', () => { + it('should handle mixed response and error messages in queue', async () => { + await protocol.connect(transport); + + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const testProtocol = protocol as unknown as TestProtocol; + + // Set up resolvers for multiple requests + const resolver1 = vi.fn(); + const resolver2 = vi.fn(); + const resolver3 = vi.fn(); + + testProtocol._requestResolvers.set(1, resolver1); + testProtocol._requestResolvers.set(2, resolver2); + testProtocol._requestResolvers.set(3, resolver3); + + // Enqueue mixed messages: response, error, response + await taskMessageQueue.enqueue(task.taskId, { + type: 'response', + message: { + jsonrpc: '2.0', + id: 1, + result: { content: [{ type: 'text', text: 'Success' }] } + }, + timestamp: Date.now() + }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'error', + message: { + jsonrpc: '2.0', + id: 2, + error: { + code: ErrorCode.InvalidRequest, + message: 'Request failed' + } + }, + timestamp: Date.now() + }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'response', + message: { + jsonrpc: '2.0', + id: 3, + result: { content: [{ type: 'text', text: 'Another success' }] } + }, + timestamp: Date.now() + }); + + // Process all messages + let msg; + while ((msg = await taskMessageQueue.dequeue(task.taskId))) { + if (msg.type === 'response') { + const responseMessage = msg.message as JSONRPCResponse; + const requestId = responseMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(requestId); + if (resolver) { + testProtocol._requestResolvers.delete(requestId); + resolver(responseMessage); + } + } else if (msg.type === 'error') { + const errorMessage = msg.message as JSONRPCError; + const requestId = errorMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(requestId); + if (resolver) { + testProtocol._requestResolvers.delete(requestId); + const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + resolver(error); + } + } + } + + // Verify all resolvers were called correctly + expect(resolver1).toHaveBeenCalledWith(expect.objectContaining({ id: 1 })); + expect(resolver2).toHaveBeenCalledWith(expect.any(McpError)); + expect(resolver3).toHaveBeenCalledWith(expect.objectContaining({ id: 3 })); + + // Verify error has correct properties + const error = resolver2.mock.calls[0][0]; + expect(error.code).toBe(ErrorCode.InvalidRequest); + expect(error.message).toContain('Request failed'); + + // Verify all resolvers were removed + expect(testProtocol._requestResolvers.size).toBe(0); + }); + + it('should maintain FIFO order when processing responses and errors', async () => { + await protocol.connect(transport); + + const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); + const testProtocol = protocol as unknown as TestProtocol; + + const callOrder: number[] = []; + const resolver1 = vi.fn(() => callOrder.push(1)); + const resolver2 = vi.fn(() => callOrder.push(2)); + const resolver3 = vi.fn(() => callOrder.push(3)); + + testProtocol._requestResolvers.set(1, resolver1); + testProtocol._requestResolvers.set(2, resolver2); + testProtocol._requestResolvers.set(3, resolver3); + + // Enqueue in specific order + await taskMessageQueue.enqueue(task.taskId, { + type: 'response', + message: { jsonrpc: '2.0', id: 1, result: {} }, + timestamp: 1000 + }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'error', + message: { + jsonrpc: '2.0', + id: 2, + error: { code: -32600, message: 'Error' } + }, + timestamp: 2000 + }); + + await taskMessageQueue.enqueue(task.taskId, { + type: 'response', + message: { jsonrpc: '2.0', id: 3, result: {} }, + timestamp: 3000 + }); + + // Process all messages + let msg; + while ((msg = await taskMessageQueue.dequeue(task.taskId))) { + if (msg.type === 'response') { + const responseMessage = msg.message as JSONRPCResponse; + const requestId = responseMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(requestId); + if (resolver) { + testProtocol._requestResolvers.delete(requestId); + resolver(responseMessage); + } + } else if (msg.type === 'error') { + const errorMessage = msg.message as JSONRPCError; + const requestId = errorMessage.id as RequestId; + const resolver = testProtocol._requestResolvers.get(requestId); + if (resolver) { + testProtocol._requestResolvers.delete(requestId); + const error = new McpError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); + resolver(error); + } + } + } + + // Verify FIFO order was maintained + expect(callOrder).toEqual([1, 2, 3]); + }); + }); +}); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index add69163c..15d74fe7f 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -2,7 +2,17 @@ import { AnySchema, AnyObjectSchema, SchemaOutput, safeParse } from '../server/z import { CancelledNotificationSchema, ClientCapabilities, + CreateTaskResultSchema, ErrorCode, + GetTaskRequest, + GetTaskRequestSchema, + GetTaskResultSchema, + GetTaskPayloadRequest, + GetTaskPayloadRequestSchema, + ListTasksRequestSchema, + ListTasksResultSchema, + CancelTaskRequestSchema, + CancelTaskResultSchema, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, @@ -17,17 +27,27 @@ import { Progress, ProgressNotification, ProgressNotificationSchema, + RELATED_TASK_META_KEY, Request, RequestId, Result, ServerCapabilities, RequestMeta, MessageExtraInfo, - RequestInfo + RequestInfo, + GetTaskResult, + TaskCreationParams, + RelatedTaskMetadata, + CancelledNotification, + Task, + TaskStatusNotification, + TaskStatusNotificationSchema } from '../types.js'; import { Transport, TransportSendOptions } from './transport.js'; import { AuthInfo } from '../server/auth/types.js'; +import { isTerminal, TaskStore, TaskMessageQueue, QueuedMessage, CreateTaskOptions } from './task.js'; import { getMethodLiteral, parseWithCompat } from '../server/zod-json-schema-compat.js'; +import { ResponseMessage } from './responseMessage.js'; /** * Callback for progress notifications. @@ -53,6 +73,29 @@ export type ProtocolOptions = { * e.g., ['notifications/tools/list_changed'] */ debouncedNotificationMethods?: string[]; + /** + * Optional task storage implementation. If provided, enables task-related request handlers + * and provides task storage capabilities to request handlers. + */ + taskStore?: TaskStore; + /** + * Optional task message queue implementation for managing server-initiated messages + * that will be delivered through the tasks/result response stream. + */ + taskMessageQueue?: TaskMessageQueue; + /** + * Default polling interval (in milliseconds) for task status checks when no pollInterval + * is provided by the server. Defaults to 5000ms if not specified. + */ + defaultTaskPollInterval?: number; + /** + * Maximum number of messages that can be queued per task for side-channel delivery. + * If undefined, the queue size is unbounded. + * When the limit is exceeded, the TaskMessageQueue implementation's enqueue() method + * will throw an error. It's the implementation's responsibility to handle overflow + * appropriately (e.g., by failing the task, dropping messages, etc.). + */ + maxTaskQueueSize?: number; }; /** @@ -66,6 +109,8 @@ export const DEFAULT_REQUEST_TIMEOUT_MSEC = 60000; export type RequestOptions = { /** * If set, requests progress notifications from the remote end (if supported). When progress notifications are received, this callback will be invoked. + * + * For task-augmented requests: progress notifications continue after CreateTaskResult is returned and stop automatically when the task reaches a terminal status. */ onprogress?: ProgressCallback; @@ -94,6 +139,16 @@ export type RequestOptions = { * If not specified, there is no maximum total timeout. */ maxTotalTimeout?: number; + + /** + * If provided, augments the request with task creation parameters to enable call-now, fetch-later execution patterns. + */ + task?: TaskCreationParams; + + /** + * If provided, associates this request with a related task. + */ + relatedTask?: RelatedTaskMetadata; } & TransportSendOptions; /** @@ -104,8 +159,76 @@ export type NotificationOptions = { * May be used to indicate to the transport which incoming request to associate this outgoing notification with. */ relatedRequestId?: RequestId; + + /** + * If provided, associates this notification with a related task. + */ + relatedTask?: RelatedTaskMetadata; }; +/** + * Options that can be given per request. + */ +// relatedTask is excluded as the SDK controls if this is sent according to if the source is a task. +export type TaskRequestOptions = Omit; + +/** + * Request-scoped TaskStore interface. + */ +export interface RequestTaskStore { + /** + * Creates a new task with the given creation parameters. + * The implementation generates a unique taskId and createdAt timestamp. + * + * @param taskParams - The task creation parameters from the request + * @returns The created task object + */ + createTask(taskParams: CreateTaskOptions): Promise; + + /** + * Gets the current status of a task. + * + * @param taskId - The task identifier + * @returns The task object + * @throws If the task does not exist + */ + getTask(taskId: string): Promise; + + /** + * Stores the result of a task and sets its final status. + * + * @param taskId - The task identifier + * @param status - The final status: 'completed' for success, 'failed' for errors + * @param result - The result to store + */ + storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result): Promise; + + /** + * Retrieves the stored result of a task. + * + * @param taskId - The task identifier + * @returns The stored result + */ + getTaskResult(taskId: string): Promise; + + /** + * Updates a task's status (e.g., to 'cancelled', 'failed', 'completed'). + * + * @param taskId - The task identifier + * @param status - The new status + * @param statusMessage - Optional diagnostic message for failed tasks or other status information + */ + updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string): Promise; + + /** + * Lists tasks, optionally starting from a pagination cursor. + * + * @param cursor - Optional cursor for pagination + * @returns An object containing the tasks array and an optional nextCursor + */ + listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; +} + /** * Extra data given to request handlers. */ @@ -136,6 +259,12 @@ export type RequestHandlerExtra(request: SendRequestT, resultSchema: U, options?: RequestOptions) => Promise>; + sendRequest: (request: SendRequestT, resultSchema: U, options?: TaskRequestOptions) => Promise>; }; /** @@ -186,6 +315,14 @@ export abstract class Protocol = new Map(); private _pendingDebouncedNotifications = new Set(); + // Maps task IDs to progress tokens to keep handlers alive after CreateTaskResult + private _taskProgressTokens: Map = new Map(); + + private _taskStore?: TaskStore; + private _taskMessageQueue?: TaskMessageQueue; + + private _requestResolvers: Map void> = new Map(); + /** * Callback for when the connection is closed for any reason. * @@ -212,8 +349,7 @@ export abstract class Protocol { - const controller = this._requestHandlerAbortControllers.get(notification.params.requestId); - controller?.abort(notification.params.reason); + this._oncancel(notification); }); this.setNotificationHandler(ProgressNotificationSchema, notification => { @@ -225,6 +361,186 @@ export abstract class Protocol ({}) as SendResultT ); + + // Install task handlers if TaskStore is provided + this._taskStore = _options?.taskStore; + this._taskMessageQueue = _options?.taskMessageQueue; + if (this._taskStore) { + this.setRequestHandler(GetTaskRequestSchema, async (request, extra) => { + const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); + if (!task) { + throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); + } + + // Per spec: tasks/get responses SHALL NOT include related-task metadata + // as the taskId parameter is the source of truth + // @ts-expect-error SendResultT cannot contain GetTaskResult, but we include it in our derived types everywhere else + return { + ...task + } as SendResultT; + }); + + this.setRequestHandler(GetTaskPayloadRequestSchema, async (request, extra) => { + const handleTaskResult = async (): Promise => { + const taskId = request.params.taskId; + + // Deliver queued messages + if (this._taskMessageQueue) { + let queuedMessage: QueuedMessage | undefined; + while ((queuedMessage = await this._taskMessageQueue.dequeue(taskId, extra.sessionId))) { + // Handle response and error messages by routing them to the appropriate resolver + if (queuedMessage.type === 'response' || queuedMessage.type === 'error') { + const message = queuedMessage.message; + const requestId = message.id; + + // Lookup resolver in _requestResolvers map + const resolver = this._requestResolvers.get(requestId); + + if (resolver) { + // Remove resolver from map after invocation + this._requestResolvers.delete(requestId); + + // Invoke resolver with response or error + if (queuedMessage.type === 'response') { + resolver(message as JSONRPCResponse); + } else { + // Convert JSONRPCError to McpError + const errorMessage = message as JSONRPCError; + const error = new McpError( + errorMessage.error.code, + errorMessage.error.message, + errorMessage.error.data + ); + resolver(error); + } + } else { + // Handle missing resolver gracefully with error logging + const messageType = queuedMessage.type === 'response' ? 'Response' : 'Error'; + this._onerror(new Error(`${messageType} handler missing for request ${requestId}`)); + } + + // Continue to next message + continue; + } + + // At this point, message must be a request or notification (not a response) + // Strip relatedTask metadata when dequeuing for delivery + // The metadata was used for queuing, but shouldn't be sent to the client + const messageToSend = { ...queuedMessage.message }; + if (messageToSend.params?._meta?.[RELATED_TASK_META_KEY]) { + const metaCopy = { ...messageToSend.params._meta }; + delete metaCopy[RELATED_TASK_META_KEY]; + messageToSend.params = { + ...messageToSend.params, + _meta: metaCopy + }; + } + + // Send the message on the response stream by passing the relatedRequestId + // This tells the transport to write the message to the tasks/result response stream + await this._transport?.send(messageToSend, { relatedRequestId: extra.requestId }); + } + } + + // Now check task status + const task = await this._taskStore!.getTask(taskId, extra.sessionId); + if (!task) { + throw new McpError(ErrorCode.InvalidParams, `Task not found: ${taskId}`); + } + + // Block if task is not terminal (we've already delivered all queued messages above) + if (!isTerminal(task.status)) { + // Wait for status change or new messages + await this._waitForTaskUpdate(taskId, extra.signal); + + // After waking up, recursively call to deliver any new messages or result + return await handleTaskResult(); + } + + // If task is terminal, return the result + if (isTerminal(task.status)) { + const result = await this._taskStore!.getTaskResult(taskId, extra.sessionId); + + this._clearTaskQueue(taskId); + + return { + ...result, + _meta: { + ...result._meta, + [RELATED_TASK_META_KEY]: { + taskId: taskId + } + } + } as SendResultT; + } + + return await handleTaskResult(); + }; + + return await handleTaskResult(); + }); + + this.setRequestHandler(ListTasksRequestSchema, async (request, extra) => { + try { + const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor, extra.sessionId); + // @ts-expect-error SendResultT cannot contain ListTasksResult, but we include it in our derived types everywhere else + return { + tasks, + nextCursor, + _meta: {} + } as SendResultT; + } catch (error) { + throw new McpError( + ErrorCode.InvalidParams, + `Failed to list tasks: ${error instanceof Error ? error.message : String(error)}` + ); + } + }); + + this.setRequestHandler(CancelTaskRequestSchema, async (request, extra) => { + try { + // Get the current task to check if it's in a terminal state, in case the implementation is not atomic + const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId); + + if (!task) { + throw new McpError(ErrorCode.InvalidParams, `Task not found: ${request.params.taskId}`); + } + + // Reject cancellation of terminal tasks + if (isTerminal(task.status)) { + throw new McpError(ErrorCode.InvalidParams, `Cannot cancel task in terminal status: ${task.status}`); + } + + await this._taskStore!.updateTaskStatus( + request.params.taskId, + 'cancelled', + 'Client cancelled task execution.', + extra.sessionId + ); + + this._clearTaskQueue(request.params.taskId); + + return { + _meta: {} + } as SendResultT; + } catch (error) { + // Re-throw McpError as-is + if (error instanceof McpError) { + throw error; + } + throw new McpError( + ErrorCode.InvalidRequest, + `Failed to cancel task: ${error instanceof Error ? error.message : String(error)}` + ); + } + }); + } + } + + private async _oncancel(notification: CancelledNotification): Promise { + // Handle request cancellation + const controller = this._requestHandlerAbortControllers.get(notification.params.requestId); + controller?.abort(notification.params.reason); } private _setupTimeout( @@ -310,11 +626,14 @@ export abstract class Protocol = { signal: abortController.signal, sessionId: capturedTransport?.sessionId, _meta: request.params?._meta, - sendNotification: notification => this.notification(notification, { relatedRequestId: request.id }), - sendRequest: (r, resultSchema, options?) => this.request(r, resultSchema, { ...options, relatedRequestId: request.id }), + sendNotification: async notification => { + // Include related-task metadata if this request is part of a task + const notificationOptions: NotificationOptions = { relatedRequestId: request.id }; + if (relatedTaskId) { + notificationOptions.relatedTask = { taskId: relatedTaskId }; + } + await this.notification(notification, notificationOptions); + }, + sendRequest: async (r, resultSchema, options?) => { + // Include related-task metadata if this request is part of a task + const requestOptions: RequestOptions = { ...options, relatedRequestId: request.id }; + if (relatedTaskId && !requestOptions.relatedTask) { + requestOptions.relatedTask = { taskId: relatedTaskId }; + } + return await this.request(r, resultSchema, requestOptions); + }, authInfo: extra?.authInfo, requestId: request.id, - requestInfo: extra?.requestInfo + requestInfo: extra?.requestInfo, + taskId: relatedTaskId, + taskStore: taskStore, + taskRequestedTtl: taskCreationParams?.ttl }; // Starting with Promise.resolve() puts any synchronous errors into the monad as well. Promise.resolve() + .then(() => { + // If this request asked for task creation, check capability first + if (taskCreationParams) { + // Check if the request method supports task creation + this.assertTaskHandlerCapability(request.method); + } + }) .then(() => handler(request, fullExtra)) .then( - result => { + async result => { if (abortController.signal.aborted) { + // Request was cancelled return; } - return capturedTransport?.send({ + // Send the response + await capturedTransport?.send({ result, jsonrpc: '2.0', id: request.id }); }, - error => { + async error => { if (abortController.signal.aborted) { + // Request was cancelled return; } @@ -426,6 +778,10 @@ export abstract class Protocol; + if (result.task && typeof result.task === 'object') { + const task = result.task as Record; + if (typeof task.taskId === 'string') { + isTaskResponse = true; + this._taskProgressTokens.set(task.taskId, messageId); + } + } + } + + if (!isTaskResponse) { + this._progressHandlers.delete(messageId); + } + if (isJSONRPCResponse(response)) { handler(response); } else { @@ -487,21 +873,152 @@ export abstract class Protocol( + request: SendRequestT, + resultSchema: T, + options?: RequestOptions + ): AsyncGenerator>, void, void> { + const { task } = options ?? {}; + + // For non-task requests, just yield the result + if (!task) { + try { + const result = await this.request(request, resultSchema, options); + yield { type: 'result', result }; + } catch (error) { + yield { + type: 'error', + error: error instanceof McpError ? error : new McpError(ErrorCode.InternalError, String(error)) + }; + } + return; + } + + // For task-augmented requests, we need to poll for status + // First, make the request to create the task + let taskId: string | undefined; + try { + // Send the request and get the CreateTaskResult + const createResult = await this.request(request, CreateTaskResultSchema, options); + + // Extract taskId from the result + if (createResult.task) { + taskId = createResult.task.taskId; + yield { type: 'taskCreated', task: createResult.task }; + } else { + throw new McpError(ErrorCode.InternalError, 'Task creation did not return a task'); + } + + // Poll for task completion + while (true) { + // Get current task status + const task = await this.getTask({ taskId }, options); + yield { type: 'taskStatus', task }; + + // Check if task is terminal + if (isTerminal(task.status)) { + if (task.status === 'completed') { + // Get the final result + const result = await this.getTaskResult({ taskId }, resultSchema, options); + yield { type: 'result', result }; + } else if (task.status === 'failed') { + yield { + type: 'error', + error: new McpError(ErrorCode.InternalError, `Task ${taskId} failed`) + }; + } else if (task.status === 'cancelled') { + yield { + type: 'error', + error: new McpError(ErrorCode.InternalError, `Task ${taskId} was cancelled`) + }; + } + return; + } + + // Wait before polling again + const pollInterval = task.pollInterval ?? this._options?.defaultTaskPollInterval ?? 1000; + await new Promise(resolve => setTimeout(resolve, pollInterval)); + + // Check if cancelled + options?.signal?.throwIfAborted(); + } + } catch (error) { + yield { + type: 'error', + error: error instanceof McpError ? error : new McpError(ErrorCode.InternalError, String(error)) + }; + } + } + + /** + * Sends a request and waits for a response. * * Do not use this method to emit notifications! Use notification() instead. */ request(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { - const { relatedRequestId, resumptionToken, onresumptiontoken } = options ?? {}; + const { relatedRequestId, resumptionToken, onresumptiontoken, task, relatedTask } = options ?? {}; + + // Send the request + return new Promise>((resolve, reject) => { + const earlyReject = (error: unknown) => { + reject(error); + }; - return new Promise((resolve, reject) => { if (!this._transport) { - reject(new Error('Not connected')); + earlyReject(new Error('Not connected')); return; } if (this._options?.enforceStrictCapabilities === true) { - this.assertCapabilityForMethod(request.method); + try { + this.assertCapabilityForMethod(request.method); + + // If task creation is requested, also check task capabilities + if (task) { + this.assertTaskCapability(request.method); + } + } catch (e) { + earlyReject(e); + return; + } } options?.signal?.throwIfAborted(); @@ -524,6 +1041,25 @@ export abstract class Protocol { this._responseHandlers.delete(messageId); this._progressHandlers.delete(messageId); @@ -543,7 +1079,9 @@ export abstract class Protocol this._onerror(new Error(`Failed to send cancellation: ${error}`))); - reject(reason); + // Wrap the reason in an McpError if it isn't already + const error = reason instanceof McpError ? reason : new McpError(ErrorCode.RequestTimeout, String(reason)); + reject(error); }; this._responseHandlers.set(messageId, response => { @@ -577,13 +1115,78 @@ export abstract class Protocol { - this._cleanupTimeout(messageId); - reject(error); - }); + // Queue request if related to a task + const relatedTaskId = relatedTask?.taskId; + if (relatedTaskId) { + // Store the response resolver for this request so responses can be routed back + const responseResolver = (response: JSONRPCResponse | Error) => { + const handler = this._responseHandlers.get(messageId); + if (handler) { + handler(response); + } else { + // Log error when resolver is missing, but don't fail + this._onerror(new Error(`Response handler missing for side-channeled request ${messageId}`)); + } + }; + this._requestResolvers.set(messageId, responseResolver); + + this._enqueueTaskMessage(relatedTaskId, { + type: 'request', + message: jsonrpcRequest, + timestamp: Date.now() + }).catch(error => { + this._cleanupTimeout(messageId); + reject(error); + }); + + // Don't send through transport - queued messages are delivered via tasks/result only + // This prevents duplicate delivery for bidirectional transports + } else { + // No related task - send through transport normally + this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { + this._cleanupTimeout(messageId); + reject(error); + }); + } }); } + /** + * Gets the current status of a task. + */ + async getTask(params: GetTaskRequest['params'], options?: RequestOptions): Promise { + // @ts-expect-error SendRequestT cannot directly contain GetTaskRequest, but we ensure all type instantiations contain it anyways + return this.request({ method: 'tasks/get', params }, GetTaskResultSchema, options); + } + + /** + * Retrieves the result of a completed task. + */ + async getTaskResult( + params: GetTaskPayloadRequest['params'], + resultSchema: T, + options?: RequestOptions + ): Promise> { + // @ts-expect-error SendRequestT cannot directly contain GetTaskPayloadRequest, but we ensure all type instantiations contain it anyways + return this.request({ method: 'tasks/result', params }, resultSchema, options); + } + + /** + * Lists tasks, optionally starting from a pagination cursor. + */ + async listTasks(params?: { cursor?: string }, options?: RequestOptions): Promise> { + // @ts-expect-error SendRequestT cannot directly contain ListTasksRequest, but we ensure all type instantiations contain it anyways + return this.request({ method: 'tasks/list', params }, ListTasksResultSchema, options); + } + + /** + * Cancels a specific task. + */ + async cancelTask(params: { taskId: string }, options?: RequestOptions): Promise> { + // @ts-expect-error SendRequestT cannot directly contain CancelTaskRequest, but we ensure all type instantiations contain it anyways + return this.request({ method: 'tasks/cancel', params }, CancelTaskResultSchema, options); + } + /** * Emits a notification, which is a one-way message that does not expect a response. */ @@ -594,10 +1197,38 @@ export abstract class Protocol this._onerror(error)); @@ -632,11 +1278,25 @@ export abstract class Protocol { + // Task message queues are only used when taskStore is configured + if (!this._taskStore || !this._taskMessageQueue) { + throw new Error('Cannot enqueue task message: taskStore and taskMessageQueue are not configured'); + } + + const maxQueueSize = this._options?.maxTaskQueueSize; + await this._taskMessageQueue.enqueue(taskId, message, sessionId, maxQueueSize); + } + + /** + * Clears the message queue for a task and rejects any pending request resolvers. + * @param taskId The task ID whose queue should be cleared + * @param sessionId Optional session ID for binding the operation to a specific session + */ + private async _clearTaskQueue(taskId: string, sessionId?: string): Promise { + if (this._taskMessageQueue) { + // Reject any pending request resolvers + const messages = await this._taskMessageQueue.dequeueAll(taskId, sessionId); + for (const message of messages) { + if (message.type === 'request' && isJSONRPCRequest(message.message)) { + // Extract request ID from the message + const requestId = message.message.id as RequestId; + const resolver = this._requestResolvers.get(requestId); + if (resolver) { + resolver(new McpError(ErrorCode.InternalError, 'Task cancelled or completed')); + this._requestResolvers.delete(requestId); + } else { + // Log error when resolver is missing during cleanup for better observability + this._onerror(new Error(`Resolver missing for request ${requestId} during task ${taskId} cleanup`)); + } + } + } + } + } + + /** + * Waits for a task update (new messages or status change) with abort signal support. + * Uses polling to check for updates at the task's configured poll interval. + * @param taskId The task ID to wait for + * @param signal Abort signal to cancel the wait + * @returns Promise that resolves when an update occurs or rejects if aborted + */ + private async _waitForTaskUpdate(taskId: string, signal: AbortSignal): Promise { + // Get the task's poll interval, falling back to default + let interval = this._options?.defaultTaskPollInterval ?? 1000; + try { + const task = await this._taskStore?.getTask(taskId); + if (task?.pollInterval) { + interval = task.pollInterval; + } + } catch { + // Use default interval if task lookup fails + } + + return new Promise((resolve, reject) => { + if (signal.aborted) { + reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled')); + return; + } + + // Wait for the poll interval, then resolve so caller can check for updates + const timeoutId = setTimeout(resolve, interval); + + // Clean up timeout and reject if aborted + signal.addEventListener( + 'abort', + () => { + clearTimeout(timeoutId); + reject(new McpError(ErrorCode.InvalidRequest, 'Request cancelled')); + }, + { once: true } + ); + }); + } + + private requestTaskStore(request?: JSONRPCRequest, sessionId?: string): RequestTaskStore { + const taskStore = this._taskStore; + if (!taskStore) { + throw new Error('No task store configured'); + } + + return { + createTask: async taskParams => { + if (!request) { + throw new Error('No request provided'); + } + + return await taskStore.createTask( + taskParams, + request.id, + { + method: request.method, + params: request.params + }, + sessionId + ); + }, + getTask: async taskId => { + const task = await taskStore.getTask(taskId, sessionId); + if (!task) { + throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); + } + + return task; + }, + storeTaskResult: async (taskId, status, result) => { + await taskStore.storeTaskResult(taskId, status, result, sessionId); + + // Get updated task state and send notification + const task = await taskStore.getTask(taskId, sessionId); + if (task) { + const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ + method: 'notifications/tasks/status', + params: task + }); + await this.notification(notification as SendNotificationT); + + if (isTerminal(task.status)) { + this._cleanupTaskProgressHandler(taskId); + // Don't clear queue here - it will be cleared after delivery via tasks/result + } + } + }, + getTaskResult: taskId => { + return taskStore.getTaskResult(taskId, sessionId); + }, + updateTaskStatus: async (taskId, status, statusMessage) => { + try { + // Check if task is in terminal state before attempting to update + const task = await taskStore.getTask(taskId, sessionId); + if (!task) { + return; + } + + // Don't allow transitions from terminal states + if (isTerminal(task.status)) { + this._onerror( + new Error( + `Cannot update task "${taskId}" from terminal status "${task.status}" to "${status}". Terminal states (completed, failed, cancelled) cannot transition to other states.` + ) + ); + return; + } + + await taskStore.updateTaskStatus(taskId, status, statusMessage, sessionId); + + // Get updated task state and send notification + const updatedTask = await taskStore.getTask(taskId, sessionId); + if (updatedTask) { + const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ + method: 'notifications/tasks/status', + params: updatedTask + }); + await this.notification(notification as SendNotificationT); + + if (isTerminal(updatedTask.status)) { + this._cleanupTaskProgressHandler(taskId); + // Don't clear queue here - it will be cleared after delivery via tasks/result + } + } + } catch (error) { + throw new Error(`Failed to update status of task "${taskId}" to "${status}": ${error}`); + } + }, + listTasks: cursor => { + return taskStore.listTasks(cursor, sessionId); + } + }; + } } function isPlainObject(value: unknown): value is Record { diff --git a/src/shared/responseMessage.ts b/src/shared/responseMessage.ts new file mode 100644 index 000000000..6fefcf1f6 --- /dev/null +++ b/src/shared/responseMessage.ts @@ -0,0 +1,70 @@ +import { Result, Task, McpError } from '../types.js'; + +/** + * Base message type + */ +export interface BaseResponseMessage { + type: string; +} + +/** + * Task status update message + */ +export interface TaskStatusMessage extends BaseResponseMessage { + type: 'taskStatus'; + task: Task; +} + +/** + * Task created message (first message for task-augmented requests) + */ +export interface TaskCreatedMessage extends BaseResponseMessage { + type: 'taskCreated'; + task: Task; +} + +/** + * Final result message (terminal) + */ +export interface ResultMessage extends BaseResponseMessage { + type: 'result'; + result: T; +} + +/** + * Error message (terminal) + */ +export interface ErrorMessage extends BaseResponseMessage { + type: 'error'; + error: McpError; +} + +/** + * Union type representing all possible messages that can be yielded during request processing. + * Note: Progress notifications are handled through the existing onprogress callback mechanism. + * Side-channeled messages (server requests/notifications) are handled through registered handlers. + */ +export type ResponseMessage = TaskStatusMessage | TaskCreatedMessage | ResultMessage | ErrorMessage; + +export type AsyncGeneratorValue = T extends AsyncGenerator ? U : never; + +export async function toArrayAsync>(it: T): Promise[]> { + const arr: AsyncGeneratorValue[] = []; + for await (const o of it) { + arr.push(o as AsyncGeneratorValue); + } + + return arr; +} + +export async function takeResult>>(it: U): Promise { + for await (const o of it) { + if (o.type === 'result') { + return o.result; + } else if (o.type === 'error') { + throw o.error; + } + } + + throw new Error('No result in stream.'); +} diff --git a/src/shared/task-listing.test.ts b/src/shared/task-listing.test.ts new file mode 100644 index 000000000..975706070 --- /dev/null +++ b/src/shared/task-listing.test.ts @@ -0,0 +1,168 @@ +import { describe, it, expect, beforeEach, afterEach } from 'vitest'; +import { InMemoryTransport } from '../inMemory.js'; +import { Client } from '../client/index.js'; +import { Server } from '../server/index.js'; +import { InMemoryTaskStore, InMemoryTaskMessageQueue } from '../examples/shared/inMemoryTaskStore.js'; + +describe('Task Listing with Pagination', () => { + let client: Client; + let server: Server; + let taskStore: InMemoryTaskStore; + let clientTransport: InMemoryTransport; + let serverTransport: InMemoryTransport; + + beforeEach(async () => { + taskStore = new InMemoryTaskStore(); + + [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + list: {}, + requests: { + tools: { + call: {} + } + } + } + } + } + ); + + server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tasks: { + list: {}, + requests: { + tools: { + call: {} + } + } + } + }, + taskStore, + taskMessageQueue: new InMemoryTaskMessageQueue() + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + }); + + afterEach(async () => { + taskStore.cleanup(); + await client.close(); + await server.close(); + }); + + it('should return empty list when no tasks exist', async () => { + const result = await client.listTasks(); + + expect(result.tasks).toEqual([]); + expect(result.nextCursor).toBeUndefined(); + }); + + it('should return all tasks when less than page size', async () => { + // Create 3 tasks + for (let i = 0; i < 3; i++) { + await taskStore.createTask({}, i, { + method: 'tools/call', + params: { name: 'test-tool' } + }); + } + + const result = await client.listTasks(); + + expect(result.tasks).toHaveLength(3); + expect(result.nextCursor).toBeUndefined(); + }); + + it('should paginate when more than page size exists', async () => { + // Create 15 tasks (page size is 10 in InMemoryTaskStore) + for (let i = 0; i < 15; i++) { + await taskStore.createTask({}, i, { + method: 'tools/call', + params: { name: 'test-tool' } + }); + } + + // Get first page + const page1 = await client.listTasks(); + expect(page1.tasks).toHaveLength(10); + expect(page1.nextCursor).toBeDefined(); + + // Get second page using cursor + const page2 = await client.listTasks({ cursor: page1.nextCursor }); + expect(page2.tasks).toHaveLength(5); + expect(page2.nextCursor).toBeUndefined(); + }); + + it('should treat cursor as opaque token', async () => { + // Create 5 tasks + for (let i = 0; i < 5; i++) { + await taskStore.createTask({}, i, { + method: 'tools/call', + params: { name: 'test-tool' } + }); + } + + // Get all tasks to get a valid cursor + const allTasks = taskStore.getAllTasks(); + const validCursor = allTasks[2].taskId; + + // Use the cursor - should work even though we don't know its internal structure + const result = await client.listTasks({ cursor: validCursor }); + expect(result.tasks).toHaveLength(2); + }); + + it('should return error for invalid cursor', async () => { + await taskStore.createTask({}, 1, { + method: 'tools/call', + params: { name: 'test-tool' } + }); + + // Try to use an invalid cursor + await expect(client.listTasks({ cursor: 'invalid-cursor' })).rejects.toThrow(); + }); + + it('should ensure tasks accessible via tasks/get are also accessible via tasks/list', async () => { + // Create a task + const task = await taskStore.createTask({}, 1, { + method: 'tools/call', + params: { name: 'test-tool' } + }); + + // Verify it's accessible via tasks/get + const getResult = await client.getTask({ taskId: task.taskId }); + expect(getResult.taskId).toBe(task.taskId); + + // Verify it's also accessible via tasks/list + const listResult = await client.listTasks(); + expect(listResult.tasks).toHaveLength(1); + expect(listResult.tasks[0].taskId).toBe(task.taskId); + }); + + it('should not include related-task metadata in list response', async () => { + // Create a task + await taskStore.createTask({}, 1, { + method: 'tools/call', + params: { name: 'test-tool' } + }); + + const result = await client.listTasks(); + + // The response should have _meta but not include related-task metadata + expect(result._meta).toBeDefined(); + expect(result._meta?.['io.modelcontextprotocol/related-task']).toBeUndefined(); + }); +}); diff --git a/src/shared/task.test.ts b/src/shared/task.test.ts new file mode 100644 index 000000000..4d21e3dc3 --- /dev/null +++ b/src/shared/task.test.ts @@ -0,0 +1,117 @@ +import { describe, it, expect } from 'vitest'; +import { isTerminal } from './task.js'; +import type { Task } from '../types.js'; + +describe('Task utility functions', () => { + describe('isTerminal', () => { + it('should return true for completed status', () => { + expect(isTerminal('completed')).toBe(true); + }); + + it('should return true for failed status', () => { + expect(isTerminal('failed')).toBe(true); + }); + + it('should return true for cancelled status', () => { + expect(isTerminal('cancelled')).toBe(true); + }); + + it('should return false for working status', () => { + expect(isTerminal('working')).toBe(false); + }); + + it('should return false for input_required status', () => { + expect(isTerminal('input_required')).toBe(false); + }); + }); +}); + +describe('Task Schema Validation', () => { + it('should validate task with ttl field', () => { + const createdAt = new Date().toISOString(); + const task: Task = { + taskId: 'test-123', + status: 'working', + ttl: 60000, + createdAt, + lastUpdatedAt: createdAt, + pollInterval: 1000 + }; + + expect(task.ttl).toBe(60000); + expect(task.createdAt).toBeDefined(); + expect(typeof task.createdAt).toBe('string'); + }); + + it('should validate task with null ttl', () => { + const createdAt = new Date().toISOString(); + const task: Task = { + taskId: 'test-456', + status: 'completed', + ttl: null, + createdAt, + lastUpdatedAt: createdAt + }; + + expect(task.ttl).toBeNull(); + }); + + it('should validate task with statusMessage field', () => { + const createdAt = new Date().toISOString(); + const task: Task = { + taskId: 'test-789', + status: 'failed', + ttl: null, + createdAt, + lastUpdatedAt: createdAt, + statusMessage: 'Operation failed due to timeout' + }; + + expect(task.statusMessage).toBe('Operation failed due to timeout'); + }); + + it('should validate task with createdAt in ISO 8601 format', () => { + const now = new Date(); + const createdAt = now.toISOString(); + const task: Task = { + taskId: 'test-iso', + status: 'working', + ttl: 30000, + createdAt, + lastUpdatedAt: createdAt + }; + + expect(task.createdAt).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/); + expect(new Date(task.createdAt).getTime()).toBe(now.getTime()); + }); + + it('should validate task with lastUpdatedAt in ISO 8601 format', () => { + const now = new Date(); + const createdAt = now.toISOString(); + const task: Task = { + taskId: 'test-iso', + status: 'working', + ttl: 30000, + createdAt, + lastUpdatedAt: createdAt + }; + + expect(task.lastUpdatedAt).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/); + }); + + it('should validate all task statuses', () => { + const statuses: Task['status'][] = ['working', 'input_required', 'completed', 'failed', 'cancelled']; + + const createdAt = new Date().toISOString(); + statuses.forEach(status => { + const task: Task = { + taskId: `test-${status}`, + status, + ttl: null, + createdAt, + lastUpdatedAt: createdAt + }; + expect(task.status).toBe(status); + }); + }); +}); diff --git a/src/shared/task.ts b/src/shared/task.ts new file mode 100644 index 000000000..ae4517f6f --- /dev/null +++ b/src/shared/task.ts @@ -0,0 +1,189 @@ +import { Task, Request, RequestId, Result, JSONRPCRequest, JSONRPCNotification, JSONRPCResponse, JSONRPCError } from '../types.js'; + +/** + * Represents a message queued for side-channel delivery via tasks/result. + * + * This is a serializable data structure that can be stored in external systems. + * All fields are JSON-serializable. + */ +export type QueuedMessage = QueuedRequest | QueuedNotification | QueuedResponse | QueuedError; + +export interface BaseQueuedMessage { + /** Type of message */ + type: string; + /** When the message was queued (milliseconds since epoch) */ + timestamp: number; +} + +export interface QueuedRequest extends BaseQueuedMessage { + type: 'request'; + /** The actual JSONRPC request */ + message: JSONRPCRequest; +} + +export interface QueuedNotification extends BaseQueuedMessage { + type: 'notification'; + /** The actual JSONRPC notification */ + message: JSONRPCNotification; +} + +export interface QueuedResponse extends BaseQueuedMessage { + type: 'response'; + /** The actual JSONRPC response */ + message: JSONRPCResponse; +} + +export interface QueuedError extends BaseQueuedMessage { + type: 'error'; + /** The actual JSONRPC error */ + message: JSONRPCError; +} + +/** + * Interface for managing per-task FIFO message queues. + * + * Similar to TaskStore, this allows pluggable queue implementations + * (in-memory, Redis, other distributed queues, etc.). + * + * Each method accepts taskId and optional sessionId parameters to enable + * a single queue instance to manage messages for multiple tasks, with + * isolation based on task ID and session ID. + * + * All methods are async to support external storage implementations. + * All data in QueuedMessage must be JSON-serializable. + */ +export interface TaskMessageQueue { + /** + * Adds a message to the end of the queue for a specific task. + * Atomically checks queue size and throws if maxSize would be exceeded. + * @param taskId The task identifier + * @param message The message to enqueue + * @param sessionId Optional session ID for binding the operation to a specific session + * @param maxSize Optional maximum queue size - if specified and queue is full, throws an error + * @throws Error if maxSize is specified and would be exceeded + */ + enqueue(taskId: string, message: QueuedMessage, sessionId?: string, maxSize?: number): Promise; + + /** + * Removes and returns the first message from the queue for a specific task. + * @param taskId The task identifier + * @param sessionId Optional session ID for binding the query to a specific session + * @returns The first message, or undefined if the queue is empty + */ + dequeue(taskId: string, sessionId?: string): Promise; + + /** + * Removes and returns all messages from the queue for a specific task. + * Used when tasks are cancelled or failed to clean up pending messages. + * @param taskId The task identifier + * @param sessionId Optional session ID for binding the query to a specific session + * @returns Array of all messages that were in the queue + */ + dequeueAll(taskId: string, sessionId?: string): Promise; +} + +/** + * Task creation options. + */ +export interface CreateTaskOptions { + /** + * Time in milliseconds to keep task results available after completion. + * If null, the task has unlimited lifetime until manually cleaned up. + */ + ttl?: number | null; + + /** + * Time in milliseconds to wait between task status requests. + */ + pollInterval?: number; + + /** + * Additional context to pass to the task store. + */ + context?: Record; +} + +/** + * Interface for storing and retrieving task state and results. + * + * Similar to Transport, this allows pluggable task storage implementations + * (in-memory, database, distributed cache, etc.). + */ +export interface TaskStore { + /** + * Creates a new task with the given creation parameters and original request. + * The implementation must generate a unique taskId and createdAt timestamp. + * + * TTL Management: + * - The implementation receives the TTL suggested by the requestor via taskParams.ttl + * - The implementation MAY override the requested TTL (e.g., to enforce limits) + * - The actual TTL used MUST be returned in the Task object + * - Null TTL indicates unlimited task lifetime (no automatic cleanup) + * - Cleanup SHOULD occur automatically after TTL expires, regardless of task status + * + * @param taskParams - The task creation parameters from the request (ttl, pollInterval) + * @param requestId - The JSON-RPC request ID + * @param request - The original request that triggered task creation + * @param sessionId - Optional session ID for binding the task to a specific session + * @returns The created task object + */ + createTask(taskParams: CreateTaskOptions, requestId: RequestId, request: Request, sessionId?: string): Promise; + + /** + * Gets the current status of a task. + * + * @param taskId - The task identifier + * @param sessionId - Optional session ID for binding the query to a specific session + * @returns The task object, or null if it does not exist + */ + getTask(taskId: string, sessionId?: string): Promise; + + /** + * Stores the result of a task and sets its final status. + * + * @param taskId - The task identifier + * @param status - The final status: 'completed' for success, 'failed' for errors + * @param result - The result to store + * @param sessionId - Optional session ID for binding the operation to a specific session + */ + storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result, sessionId?: string): Promise; + + /** + * Retrieves the stored result of a task. + * + * @param taskId - The task identifier + * @param sessionId - Optional session ID for binding the query to a specific session + * @returns The stored result + */ + getTaskResult(taskId: string, sessionId?: string): Promise; + + /** + * Updates a task's status (e.g., to 'cancelled', 'failed', 'completed'). + * + * @param taskId - The task identifier + * @param status - The new status + * @param statusMessage - Optional diagnostic message for failed tasks or other status information + * @param sessionId - Optional session ID for binding the operation to a specific session + */ + updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string, sessionId?: string): Promise; + + /** + * Lists tasks, optionally starting from a pagination cursor. + * + * @param cursor - Optional cursor for pagination + * @param sessionId - Optional session ID for binding the query to a specific session + * @returns An object containing the tasks array and an optional nextCursor + */ + listTasks(cursor?: string, sessionId?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; +} + +/** + * Checks if a task status represents a terminal state. + * Terminal states are those where the task has finished and will not change. + * + * @param status - The task status to check + * @returns True if the status is terminal (completed, failed, or cancelled) + */ +export function isTerminal(status: Task['status']): boolean { + return status === 'completed' || status === 'failed' || status === 'cancelled'; +} diff --git a/src/spec.types.test.ts b/src/spec.types.test.ts index 1c0b6ab5d..14fb039d0 100644 --- a/src/spec.types.test.ts +++ b/src/spec.types.test.ts @@ -67,6 +67,11 @@ type FixSpecClientCapabilities = T extends { elicitation?: object } ? Omit & { elicitation?: Record } : T; +// Targeted fix: in spec, ServerCapabilities needs index signature to match SDK's passthrough +type FixSpecServerCapabilities = T & { [x: string]: unknown }; + +type FixSpecInitializeResult = T extends { capabilities: infer C } ? T & { capabilities: FixSpecServerCapabilities } : T; + type FixSpecInitializeRequestParams = T extends { capabilities: infer C } ? Omit & { capabilities: FixSpecClientCapabilities } : T; @@ -558,7 +563,7 @@ const sdkTypeChecks = { sdk = spec; spec = sdk; }, - InitializeResult: (sdk: SDKTypes.InitializeResult, spec: SpecTypes.InitializeResult) => { + InitializeResult: (sdk: SDKTypes.InitializeResult, spec: FixSpecInitializeResult) => { sdk = spec; spec = sdk; }, @@ -566,7 +571,7 @@ const sdkTypeChecks = { sdk = spec; spec = sdk; }, - ServerCapabilities: (sdk: SDKTypes.ServerCapabilities, spec: SpecTypes.ServerCapabilities) => { + ServerCapabilities: (sdk: SDKTypes.ServerCapabilities, spec: FixSpecServerCapabilities) => { sdk = spec; spec = sdk; }, diff --git a/src/types.ts b/src/types.ts index 5f34ed1b1..49b4fe713 100644 --- a/src/types.ts +++ b/src/types.ts @@ -5,6 +5,8 @@ export const LATEST_PROTOCOL_VERSION = '2025-06-18'; export const DEFAULT_NEGOTIATED_PROTOCOL_VERSION = '2025-03-26'; export const SUPPORTED_PROTOCOL_VERSIONS = [LATEST_PROTOCOL_VERSION, '2025-03-26', '2024-11-05', '2024-10-07']; +export const RELATED_TASK_META_KEY = 'io.modelcontextprotocol/related-task'; + /* JSON-RPC types */ export const JSONRPC_VERSION = '2.0'; @@ -28,17 +30,49 @@ export const ProgressTokenSchema = z.union([z.string(), z.number().int()]); */ export const CursorSchema = z.string(); +/** + * Task creation parameters, used to ask that the server create a task to represent a request. + */ +export const TaskCreationParamsSchema = z.looseObject({ + /** + * Time in milliseconds to keep task results available after completion. + * If null, the task has unlimited lifetime until manually cleaned up. + */ + ttl: z.union([z.number(), z.null()]).optional(), + + /** + * Time in milliseconds to wait between task status requests. + */ + pollInterval: z.number().optional() +}); + +/** + * Task association metadata, used to signal which task a message originated from. + */ +export const RelatedTaskMetadataSchema = z.looseObject({ + taskId: z.string() +}); + const RequestMetaSchema = z.looseObject({ /** * If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications. */ - progressToken: ProgressTokenSchema.optional() + progressToken: ProgressTokenSchema.optional(), + /** + * If specified, this request is related to the provided task. + */ + [RELATED_TASK_META_KEY]: RelatedTaskMetadataSchema.optional() }); /** * Common params for any request. */ const BaseRequestParamsSchema = z.looseObject({ + /** + * If specified, the caller is requesting that the receiver create a task to represent the request. + * Task creation parameters are now at the top level instead of in _meta. + */ + task: TaskCreationParamsSchema.optional(), /** * See [General fields: `_meta`](/specification/draft/basic/index#meta) for notes on `_meta` usage. */ @@ -55,7 +89,15 @@ const NotificationsParamsSchema = z.looseObject({ * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) * for notes on _meta usage. */ - _meta: z.record(z.string(), z.unknown()).optional() + _meta: z + .object({ + /** + * If specified, this notification is related to the provided task. + */ + [RELATED_TASK_META_KEY]: z.optional(RelatedTaskMetadataSchema) + }) + .passthrough() + .optional() }); export const NotificationSchema = z.object({ @@ -68,7 +110,14 @@ export const ResultSchema = z.looseObject({ * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) * for notes on _meta usage. */ - _meta: z.record(z.string(), z.unknown()).optional() + _meta: z + .looseObject({ + /** + * If specified, this result is related to the provided task. + */ + [RELATED_TASK_META_KEY]: RelatedTaskMetadataSchema.optional() + }) + .optional() }); /** @@ -291,6 +340,86 @@ const ElicitationCapabilitySchema = z.preprocess( ) ); +/** + * Task capabilities for clients, indicating which request types support task creation. + */ +export const ClientTasksCapabilitySchema = z + .object({ + /** + * Present if the client supports listing tasks. + */ + list: z.optional(z.object({}).passthrough()), + /** + * Present if the client supports cancelling tasks. + */ + cancel: z.optional(z.object({}).passthrough()), + /** + * Capabilities for task creation on specific request types. + */ + requests: z.optional( + z + .object({ + /** + * Task support for sampling requests. + */ + sampling: z.optional( + z + .object({ + createMessage: z.optional(z.object({}).passthrough()) + }) + .passthrough() + ), + /** + * Task support for elicitation requests. + */ + elicitation: z.optional( + z + .object({ + create: z.optional(z.object({}).passthrough()) + }) + .passthrough() + ) + }) + .passthrough() + ) + }) + .passthrough(); + +/** + * Task capabilities for servers, indicating which request types support task creation. + */ +export const ServerTasksCapabilitySchema = z + .object({ + /** + * Present if the server supports listing tasks. + */ + list: z.optional(z.object({}).passthrough()), + /** + * Present if the server supports cancelling tasks. + */ + cancel: z.optional(z.object({}).passthrough()), + /** + * Capabilities for task creation on specific request types. + */ + requests: z.optional( + z + .object({ + /** + * Task support for tool requests. + */ + tools: z.optional( + z + .object({ + call: z.optional(z.object({}).passthrough()) + }) + .passthrough() + ) + }) + .passthrough() + ) + }) + .passthrough(); + /** * Capabilities a client may support. Known capabilities are defined here, in this schema, but this is not a closed set: any client can define its own, additional capabilities. */ @@ -329,7 +458,11 @@ export const ClientCapabilitiesSchema = z.object({ */ listChanged: z.boolean().optional() }) - .optional() + .optional(), + /** + * Present if the client supports task creation. + */ + tasks: z.optional(ClientTasksCapabilitySchema) }); export const InitializeRequestParamsSchema = BaseRequestParamsSchema.extend({ @@ -353,58 +486,64 @@ export const isInitializeRequest = (value: unknown): value is InitializeRequest /** * Capabilities that a server may support. Known capabilities are defined here, in this schema, but this is not a closed set: any server can define its own, additional capabilities. */ -export const ServerCapabilitiesSchema = z.object({ - /** - * Experimental, non-standard capabilities that the server supports. - */ - experimental: z.record(z.string(), AssertObjectSchema).optional(), - /** - * Present if the server supports sending log messages to the client. - */ - logging: AssertObjectSchema.optional(), - /** - * Present if the server supports sending completions to the client. - */ - completions: AssertObjectSchema.optional(), - /** - * Present if the server offers any prompt templates. - */ - prompts: z.optional( - z.object({ - /** - * Whether this server supports issuing notifications for changes to the prompt list. - */ - listChanged: z.optional(z.boolean()) - }) - ), - /** - * Present if the server offers any resources to read. - */ - resources: z - .object({ - /** - * Whether this server supports clients subscribing to resource updates. - */ - subscribe: z.boolean().optional(), - - /** - * Whether this server supports issuing notifications for changes to the resource list. - */ - listChanged: z.boolean().optional() - }) - .optional(), - /** - * Present if the server offers any tools to call. - */ - tools: z - .object({ - /** - * Whether this server supports issuing notifications for changes to the tool list. - */ - listChanged: z.boolean().optional() - }) - .optional() -}); +export const ServerCapabilitiesSchema = z + .object({ + /** + * Experimental, non-standard capabilities that the server supports. + */ + experimental: z.record(z.string(), AssertObjectSchema).optional(), + /** + * Present if the server supports sending log messages to the client. + */ + logging: AssertObjectSchema.optional(), + /** + * Present if the server supports sending completions to the client. + */ + completions: AssertObjectSchema.optional(), + /** + * Present if the server offers any prompt templates. + */ + prompts: z.optional( + z.object({ + /** + * Whether this server supports issuing notifications for changes to the prompt list. + */ + listChanged: z.optional(z.boolean()) + }) + ), + /** + * Present if the server offers any resources to read. + */ + resources: z + .object({ + /** + * Whether this server supports clients subscribing to resource updates. + */ + subscribe: z.boolean().optional(), + + /** + * Whether this server supports issuing notifications for changes to the resource list. + */ + listChanged: z.boolean().optional() + }) + .optional(), + /** + * Present if the server offers any tools to call. + */ + tools: z + .object({ + /** + * Whether this server supports issuing notifications for changes to the tool list. + */ + listChanged: z.boolean().optional() + }) + .optional(), + /** + * Present if the server supports task creation. + */ + tasks: z.optional(ServerTasksCapabilitySchema) + }) + .passthrough(); /** * After receiving an initialize request from the client, the server sends this response. @@ -497,6 +636,108 @@ export const PaginatedResultSchema = ResultSchema.extend({ nextCursor: z.optional(CursorSchema) }); +/* Tasks */ +/** + * A pollable state object associated with a request. + */ +export const TaskSchema = z.object({ + taskId: z.string(), + status: z.enum(['working', 'input_required', 'completed', 'failed', 'cancelled']), + /** + * Time in milliseconds to keep task results available after completion. + * If null, the task has unlimited lifetime until manually cleaned up. + */ + ttl: z.union([z.number(), z.null()]), + /** + * ISO 8601 timestamp when the task was created. + */ + createdAt: z.string(), + /** + * ISO 8601 timestamp when the task was last updated. + */ + lastUpdatedAt: z.string(), + pollInterval: z.optional(z.number()), + /** + * Optional diagnostic message for failed tasks or other status information. + */ + statusMessage: z.optional(z.string()) +}); + +/** + * Result returned when a task is created, containing the task data wrapped in a task field. + */ +export const CreateTaskResultSchema = ResultSchema.extend({ + task: TaskSchema +}); + +/** + * Parameters for task status notification. + * Task fields are spread directly into params per the spec (NotificationParams & Task). + */ +export const TaskStatusNotificationParamsSchema = NotificationsParamsSchema.merge(TaskSchema); + +/** + * A notification sent when a task's status changes. + */ +export const TaskStatusNotificationSchema = NotificationSchema.extend({ + method: z.literal('notifications/tasks/status'), + params: TaskStatusNotificationParamsSchema +}); + +/** + * A request to get the state of a specific task. + */ +export const GetTaskRequestSchema = RequestSchema.extend({ + method: z.literal('tasks/get'), + params: BaseRequestParamsSchema.extend({ + taskId: z.string() + }) +}); + +/** + * The response to a tasks/get request. + */ +export const GetTaskResultSchema = ResultSchema.merge(TaskSchema); + +/** + * A request to get the result of a specific task. + */ +export const GetTaskPayloadRequestSchema = RequestSchema.extend({ + method: z.literal('tasks/result'), + params: BaseRequestParamsSchema.extend({ + taskId: z.string() + }) +}); + +/** + * A request to list tasks. + */ +export const ListTasksRequestSchema = PaginatedRequestSchema.extend({ + method: z.literal('tasks/list') +}); + +/** + * The response to a tasks/list request. + */ +export const ListTasksResultSchema = PaginatedResultSchema.extend({ + tasks: z.array(TaskSchema) +}); + +/** + * A request to cancel a specific task. + */ +export const CancelTaskRequestSchema = RequestSchema.extend({ + method: z.literal('tasks/cancel'), + params: BaseRequestParamsSchema.extend({ + taskId: z.string() + }) +}); + +/** + * The response to a tasks/cancel request. + */ +export const CancelTaskResultSchema = ResultSchema; + /* Resources */ /** * The contents of a specific resource or sub-resource. @@ -988,6 +1229,21 @@ export const ToolAnnotationsSchema = z.object({ openWorldHint: z.boolean().optional() }); +/** + * Execution-related properties for a tool. + */ +export const ToolExecutionSchema = z.object({ + /** + * Indicates the tool's preference for task-augmented execution. + * - "required": Clients MUST invoke the tool as a task + * - "optional": Clients MAY invoke the tool as a task or normal request + * - "forbidden": Clients MUST NOT attempt to invoke the tool as a task + * + * If not present, defaults to "forbidden". + */ + taskSupport: z.enum(['required', 'optional', 'forbidden']).optional() +}); + /** * Definition for a tool the client can call. */ @@ -1026,6 +1282,10 @@ export const ToolSchema = z.object({ * Optional additional tool information. */ annotations: z.optional(ToolAnnotationsSchema), + /** + * Execution-related properties for this tool. + */ + execution: z.optional(ToolExecutionSchema), /** * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) @@ -1729,20 +1989,40 @@ export const ClientRequestSchema = z.union([ SubscribeRequestSchema, UnsubscribeRequestSchema, CallToolRequestSchema, - ListToolsRequestSchema + ListToolsRequestSchema, + GetTaskRequestSchema, + GetTaskPayloadRequestSchema, + ListTasksRequestSchema ]); export const ClientNotificationSchema = z.union([ CancelledNotificationSchema, ProgressNotificationSchema, InitializedNotificationSchema, - RootsListChangedNotificationSchema + RootsListChangedNotificationSchema, + TaskStatusNotificationSchema ]); -export const ClientResultSchema = z.union([EmptyResultSchema, CreateMessageResultSchema, ElicitResultSchema, ListRootsResultSchema]); +export const ClientResultSchema = z.union([ + EmptyResultSchema, + CreateMessageResultSchema, + ElicitResultSchema, + ListRootsResultSchema, + GetTaskResultSchema, + ListTasksResultSchema, + CreateTaskResultSchema +]); /* Server messages */ -export const ServerRequestSchema = z.union([PingRequestSchema, CreateMessageRequestSchema, ElicitRequestSchema, ListRootsRequestSchema]); +export const ServerRequestSchema = z.union([ + PingRequestSchema, + CreateMessageRequestSchema, + ElicitRequestSchema, + ListRootsRequestSchema, + GetTaskRequestSchema, + GetTaskPayloadRequestSchema, + ListTasksRequestSchema +]); export const ServerNotificationSchema = z.union([ CancelledNotificationSchema, @@ -1752,6 +2032,7 @@ export const ServerNotificationSchema = z.union([ ResourceListChangedNotificationSchema, ToolListChangedNotificationSchema, PromptListChangedNotificationSchema, + TaskStatusNotificationSchema, ElicitationCompleteNotificationSchema ]); @@ -1765,7 +2046,10 @@ export const ServerResultSchema = z.union([ ListResourceTemplatesResultSchema, ReadResourceResultSchema, CallToolResultSchema, - ListToolsResultSchema + ListToolsResultSchema, + GetTaskResultSchema, + ListTasksResultSchema, + CreateTaskResultSchema ]); export class McpError extends Error { @@ -1901,6 +2185,21 @@ export type Progress = Infer; export type ProgressNotificationParams = Infer; export type ProgressNotification = Infer; +/* Tasks */ +export type Task = Infer; +export type TaskCreationParams = Infer; +export type RelatedTaskMetadata = Infer; +export type CreateTaskResult = Infer; +export type TaskStatusNotificationParams = Infer; +export type TaskStatusNotification = Infer; +export type GetTaskRequest = Infer; +export type GetTaskResult = Infer; +export type GetTaskPayloadRequest = Infer; +export type ListTasksRequest = Infer; +export type ListTasksResult = Infer; +export type CancelTaskRequest = Infer; +export type CancelTaskResult = Infer; + /* Pagination */ export type PaginatedRequestParams = Infer; export type PaginatedRequest = Infer; @@ -1949,6 +2248,7 @@ export type PromptListChangedNotification = Infer; +export type ToolExecution = Infer; export type Tool = Infer; export type ListToolsRequest = Infer; export type ListToolsResult = Infer;