Skip to content

Commit bdc5aa4

Browse files
committed
Validate against CreateTaskResult in low-level client/server
1 parent db3280d commit bdc5aa4

File tree

4 files changed

+324
-184
lines changed

4 files changed

+324
-184
lines changed

src/client/index.test.ts

Lines changed: 69 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ import {
1919
ElicitResultSchema,
2020
ListRootsRequestSchema,
2121
ErrorCode,
22-
McpError
22+
McpError,
23+
CreateTaskResultSchema
2324
} from '../types.js';
2425
import { Transport } from '../shared/transport.js';
2526
import { Server } from '../server/index.js';
@@ -2150,22 +2151,22 @@ describe('Task-based execution', () => {
21502151
);
21512152

21522153
client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
2153-
let taskId: string | undefined;
2154+
const result = {
2155+
action: 'accept',
2156+
content: { username: 'list-user' }
2157+
};
21542158

21552159
// Check if task creation is requested
21562160
if (request.params.task && extra.taskStore) {
2157-
const createdTask = await extra.taskStore.createTask({
2161+
const task = await extra.taskStore.createTask({
21582162
ttl: extra.taskRequestedTtl
21592163
});
2160-
taskId = createdTask.taskId;
2161-
}
2162-
const result = {
2163-
action: 'accept',
2164-
content: { username: 'list-user' }
2165-
};
2166-
if (taskId && extra.taskStore) {
2167-
await extra.taskStore.storeTaskResult(taskId, 'completed', result);
2164+
await extra.taskStore.storeTaskResult(task.taskId, 'completed', result);
2165+
// Return CreateTaskResult when task creation is requested
2166+
return { task };
21682167
}
2168+
2169+
// Return ElicitResult for non-task requests
21692170
return result;
21702171
});
21712172

@@ -2192,7 +2193,7 @@ describe('Task-based execution', () => {
21922193
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
21932194

21942195
// Server creates task on client via elicitation
2195-
await server.request(
2196+
const createTaskResult = await server.request(
21962197
{
21972198
method: 'elicitation/create',
21982199
params: {
@@ -2207,14 +2208,14 @@ describe('Task-based execution', () => {
22072208
}
22082209
}
22092210
},
2210-
ElicitResultSchema,
2211+
CreateTaskResultSchema,
22112212
{ task: { ttl: 60000 } }
22122213
);
22132214

2214-
// Get the task ID from the task list since it's generated automatically
2215-
const taskList = await server.listTasks();
2216-
expect(taskList.tasks.length).toBeGreaterThan(0);
2217-
const taskId = taskList.tasks[0].taskId;
2215+
// Verify CreateTaskResult structure
2216+
expect(createTaskResult.task).toBeDefined();
2217+
expect(createTaskResult.task.taskId).toBeDefined();
2218+
const taskId = createTaskResult.task.taskId;
22182219

22192220
// Verify task was created
22202221
const task = await server.getTask({ taskId });
@@ -2243,22 +2244,22 @@ describe('Task-based execution', () => {
22432244
);
22442245

22452246
client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
2246-
let taskId: string | undefined;
2247+
const result = {
2248+
action: 'accept',
2249+
content: { username: 'list-user' }
2250+
};
22472251

22482252
// Check if task creation is requested
22492253
if (request.params.task && extra.taskStore) {
2250-
const createdTask = await extra.taskStore.createTask({
2254+
const task = await extra.taskStore.createTask({
22512255
ttl: extra.taskRequestedTtl
22522256
});
2253-
taskId = createdTask.taskId;
2254-
}
2255-
const result = {
2256-
action: 'accept',
2257-
content: { username: 'list-user' }
2258-
};
2259-
if (taskId && extra.taskStore) {
2260-
await extra.taskStore.storeTaskResult(taskId, 'completed', result);
2257+
await extra.taskStore.storeTaskResult(task.taskId, 'completed', result);
2258+
// Return CreateTaskResult when task creation is requested
2259+
return { task };
22612260
}
2261+
2262+
// Return ElicitResult for non-task requests
22622263
return result;
22632264
});
22642265

@@ -2284,8 +2285,8 @@ describe('Task-based execution', () => {
22842285

22852286
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
22862287

2287-
// Create a task on client and wait for completion
2288-
const result = await server.request(
2288+
// Create a task on client and wait for CreateTaskResult
2289+
const createTaskResult = await server.request(
22892290
{
22902291
method: 'elicitation/create',
22912292
params: {
@@ -2297,18 +2298,14 @@ describe('Task-based execution', () => {
22972298
}
22982299
}
22992300
},
2300-
ElicitResultSchema,
2301+
CreateTaskResultSchema,
23012302
{ task: { ttl: 60000 } }
23022303
);
23032304

2304-
// Verify the result was returned correctly
2305-
expect(result.action).toBe('accept');
2306-
expect(result.content).toEqual({ username: 'list-user' });
2307-
2308-
// Get the task ID from the task list since it's generated automatically
2309-
const taskList = await server.listTasks();
2310-
expect(taskList.tasks.length).toBeGreaterThan(0);
2311-
const taskId = taskList.tasks[0].taskId;
2305+
// Verify CreateTaskResult structure
2306+
expect(createTaskResult.task).toBeDefined();
2307+
expect(createTaskResult.task.taskId).toBeDefined();
2308+
const taskId = createTaskResult.task.taskId;
23122309

23132310
// Query task status
23142311
const task = await server.getTask({ taskId });
@@ -2339,22 +2336,22 @@ describe('Task-based execution', () => {
23392336
);
23402337

23412338
client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
2342-
let taskId: string | undefined;
2339+
const result = {
2340+
action: 'accept',
2341+
content: { username: 'result-user' }
2342+
};
23432343

23442344
// Check if task creation is requested
23452345
if (request.params.task && extra.taskStore) {
2346-
const createdTask = await extra.taskStore.createTask({
2346+
const task = await extra.taskStore.createTask({
23472347
ttl: extra.taskRequestedTtl
23482348
});
2349-
taskId = createdTask.taskId;
2350-
}
2351-
const result = {
2352-
action: 'accept',
2353-
content: { username: 'result-user' }
2354-
};
2355-
if (taskId && extra.taskStore) {
2356-
await extra.taskStore.storeTaskResult(taskId, 'completed', result);
2349+
await extra.taskStore.storeTaskResult(task.taskId, 'completed', result);
2350+
// Return CreateTaskResult when task creation is requested
2351+
return { task };
23572352
}
2353+
2354+
// Return ElicitResult for non-task requests
23582355
return result;
23592356
});
23602357

@@ -2380,8 +2377,8 @@ describe('Task-based execution', () => {
23802377

23812378
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
23822379

2383-
// Create a task on client and wait for completion
2384-
const result = await server.request(
2380+
// Create a task on client and wait for CreateTaskResult
2381+
const createTaskResult = await server.request(
23852382
{
23862383
method: 'elicitation/create',
23872384
params: {
@@ -2393,20 +2390,16 @@ describe('Task-based execution', () => {
23932390
}
23942391
}
23952392
},
2396-
ElicitResultSchema,
2393+
CreateTaskResultSchema,
23972394
{ task: { ttl: 60000 } }
23982395
);
23992396

2400-
// Verify the result was returned correctly
2401-
expect(result.action).toBe('accept');
2402-
expect(result.content).toEqual({ username: 'result-user' });
2403-
2404-
// Get the task ID from the task list since it's generated automatically
2405-
const taskList = await server.listTasks();
2406-
expect(taskList.tasks.length).toBeGreaterThan(0);
2407-
const taskId = taskList.tasks[0].taskId;
2397+
// Verify CreateTaskResult structure
2398+
expect(createTaskResult.task).toBeDefined();
2399+
expect(createTaskResult.task.taskId).toBeDefined();
2400+
const taskId = createTaskResult.task.taskId;
24082401

2409-
// Query task result using getTaskResult as well
2402+
// Query task result using getTaskResult
24102403
const taskResult = await server.getTaskResult({ taskId }, ElicitResultSchema);
24112404
expect(taskResult.action).toBe('accept');
24122405
expect(taskResult.content).toEqual({ username: 'result-user' });
@@ -2434,22 +2427,22 @@ describe('Task-based execution', () => {
24342427
);
24352428

24362429
client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
2437-
let taskId: string | undefined;
2430+
const result = {
2431+
action: 'accept',
2432+
content: { username: 'list-user' }
2433+
};
24382434

24392435
// Check if task creation is requested
24402436
if (request.params.task && extra.taskStore) {
2441-
const createdTask = await extra.taskStore.createTask({
2437+
const task = await extra.taskStore.createTask({
24422438
ttl: extra.taskRequestedTtl
24432439
});
2444-
taskId = createdTask.taskId;
2445-
}
2446-
const result = {
2447-
action: 'accept',
2448-
content: { username: 'list-user' }
2449-
};
2450-
if (taskId && extra.taskStore) {
2451-
await extra.taskStore.storeTaskResult(taskId, 'completed', result);
2440+
await extra.taskStore.storeTaskResult(task.taskId, 'completed', result);
2441+
// Return CreateTaskResult when task creation is requested
2442+
return { task };
24522443
}
2444+
2445+
// Return ElicitResult for non-task requests
24532446
return result;
24542447
});
24552448

@@ -2478,7 +2471,7 @@ describe('Task-based execution', () => {
24782471
// Create multiple tasks on client
24792472
const createdTaskIds: string[] = [];
24802473
for (let i = 0; i < 2; i++) {
2481-
const result = await server.request(
2474+
const createTaskResult = await server.request(
24822475
{
24832476
method: 'elicitation/create',
24842477
params: {
@@ -2490,20 +2483,14 @@ describe('Task-based execution', () => {
24902483
}
24912484
}
24922485
},
2493-
ElicitResultSchema,
2486+
CreateTaskResultSchema,
24942487
{ task: { ttl: 60000 } }
24952488
);
24962489

2497-
// Verify the result was returned correctly
2498-
expect(result.action).toBe('accept');
2499-
expect(result.content).toEqual({ username: 'list-user' });
2500-
2501-
// Get the task ID from the task list
2502-
const taskList = await server.listTasks();
2503-
const newTask = taskList.tasks.find(t => !createdTaskIds.includes(t.taskId));
2504-
if (newTask) {
2505-
createdTaskIds.push(newTask.taskId);
2506-
}
2490+
// Verify CreateTaskResult structure and capture taskId
2491+
expect(createTaskResult.task).toBeDefined();
2492+
expect(createTaskResult.task.taskId).toBeDefined();
2493+
createdTaskIds.push(createTaskResult.task.taskId);
25072494
}
25082495

25092496
// Query task list

src/client/index.ts

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ import {
4040
type Tool,
4141
type UnsubscribeRequest,
4242
ElicitResultSchema,
43-
ElicitRequestSchema
43+
ElicitRequestSchema,
44+
CreateTaskResultSchema,
45+
CreateMessageRequestSchema,
46+
CreateMessageResultSchema
4447
} from '../types.js';
4548
import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js';
4649
import type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator } from '../validation/types.js';
@@ -283,6 +286,20 @@ export class Client<
283286

284287
const result = await Promise.resolve(handler(request, extra));
285288

289+
// When task creation is requested, validate and return CreateTaskResult
290+
if (params.task) {
291+
const taskValidationResult = safeParse(CreateTaskResultSchema, result);
292+
if (!taskValidationResult.success) {
293+
const errorMessage =
294+
taskValidationResult.error instanceof Error
295+
? taskValidationResult.error.message
296+
: String(taskValidationResult.error);
297+
throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`);
298+
}
299+
return taskValidationResult.data;
300+
}
301+
302+
// For non-task requests, validate against ElicitResultSchema
286303
const validationResult = safeParse(ElicitResultSchema, result);
287304
if (!validationResult.success) {
288305
// Type guard: if success is false, error is guaranteed to exist
@@ -311,7 +328,51 @@ export class Client<
311328
return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler);
312329
}
313330

314-
// Non-elicitation handlers use default behavior
331+
if (method === 'sampling/createMessage') {
332+
const wrappedHandler = async (
333+
request: SchemaOutput<T>,
334+
extra: RequestHandlerExtra<ClientRequest | RequestT, ClientNotification | NotificationT>
335+
): Promise<ClientResult | ResultT> => {
336+
const validatedRequest = safeParse(CreateMessageRequestSchema, request);
337+
if (!validatedRequest.success) {
338+
const errorMessage =
339+
validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error);
340+
throw new McpError(ErrorCode.InvalidParams, `Invalid sampling request: ${errorMessage}`);
341+
}
342+
343+
const { params } = validatedRequest.data;
344+
345+
const result = await Promise.resolve(handler(request, extra));
346+
347+
// When task creation is requested, validate and return CreateTaskResult
348+
if (params.task) {
349+
const taskValidationResult = safeParse(CreateTaskResultSchema, result);
350+
if (!taskValidationResult.success) {
351+
const errorMessage =
352+
taskValidationResult.error instanceof Error
353+
? taskValidationResult.error.message
354+
: String(taskValidationResult.error);
355+
throw new McpError(ErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`);
356+
}
357+
return taskValidationResult.data;
358+
}
359+
360+
// For non-task requests, validate against CreateMessageResultSchema
361+
const validationResult = safeParse(CreateMessageResultSchema, result);
362+
if (!validationResult.success) {
363+
const errorMessage =
364+
validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error);
365+
throw new McpError(ErrorCode.InvalidParams, `Invalid sampling result: ${errorMessage}`);
366+
}
367+
368+
return validationResult.data;
369+
};
370+
371+
// Install the wrapped handler
372+
return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler);
373+
}
374+
375+
// Other handlers use default behavior
315376
return super.setRequestHandler(requestSchema, handler);
316377
}
317378

0 commit comments

Comments
 (0)