Skip to content
55 changes: 55 additions & 0 deletions src/client/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -436,3 +436,58 @@ test("should typecheck", () => {
},
});
});

test("should handle client cancelling a request", async () => {
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {
resources: {},
},
},
);

// Set up server to delay responding to listResources
server.setRequestHandler(
ListResourcesRequestSchema,
async (request, extra) => {
await new Promise((resolve) => setTimeout(resolve, 1000));
return {
resources: [],
};
},
);

const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();

const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {},
},
);

await Promise.all([
client.connect(clientTransport),
server.connect(serverTransport),
]);

// Set up abort controller
const controller = new AbortController();

// Issue request but cancel it immediately
const listResourcesPromise = client.listResources(undefined, {
signal: controller.signal,
});
controller.abort("Cancelled by test");

// Request should be rejected
await expect(listResourcesPromise).rejects.toBe("Cancelled by test");
});
41 changes: 21 additions & 20 deletions src/client/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import {
ProgressCallback,
Protocol,
ProtocolOptions,
RequestOptions,
} from "../shared/protocol.js";
import { Transport } from "../shared/transport.js";
import {
Expand Down Expand Up @@ -244,6 +244,10 @@ export class Client<
// No specific capability required for initialized
break;

case "notifications/cancelled":
// Cancellation notifications are always allowed
break;

case "notifications/progress":
// Progress notifications are always allowed
break;
Expand Down Expand Up @@ -278,14 +282,11 @@ export class Client<
return this.request({ method: "ping" }, EmptyResultSchema);
}

async complete(
params: CompleteRequest["params"],
onprogress?: ProgressCallback,
) {
async complete(params: CompleteRequest["params"], options?: RequestOptions) {
return this.request(
{ method: "completion/complete", params },
CompleteResultSchema,
onprogress,
options,
);
}

Expand All @@ -298,56 +299,56 @@ export class Client<

async getPrompt(
params: GetPromptRequest["params"],
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "prompts/get", params },
GetPromptResultSchema,
onprogress,
options,
);
}

async listPrompts(
params?: ListPromptsRequest["params"],
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "prompts/list", params },
ListPromptsResultSchema,
onprogress,
options,
);
}

async listResources(
params?: ListResourcesRequest["params"],
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "resources/list", params },
ListResourcesResultSchema,
onprogress,
options,
);
}

async listResourceTemplates(
params?: ListResourceTemplatesRequest["params"],
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "resources/templates/list", params },
ListResourceTemplatesResultSchema,
onprogress,
options,
);
}

async readResource(
params: ReadResourceRequest["params"],
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "resources/read", params },
ReadResourceResultSchema,
onprogress,
options,
);
}

Expand All @@ -370,23 +371,23 @@ export class Client<
resultSchema:
| typeof CallToolResultSchema
| typeof CompatibilityCallToolResultSchema = CallToolResultSchema,
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "tools/call", params },
resultSchema,
onprogress,
options,
);
}

async listTools(
params?: ListToolsRequest["params"],
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "tools/list", params },
ListToolsResultSchema,
onprogress,
options,
);
}

Expand Down
68 changes: 68 additions & 0 deletions src/server/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,71 @@ test("should typecheck", () => {
},
);
});

test("should handle server cancelling a request", async () => {
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {
sampling: {},
},
},
);

const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
sampling: {},
},
},
);

// Set up client to delay responding to createMessage
client.setRequestHandler(
CreateMessageRequestSchema,
async (_request, extra) => {
await new Promise((resolve) => setTimeout(resolve, 1000));
return {
model: "test",
role: "assistant",
content: {
type: "text",
text: "Test response",
},
};
},
);

const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();

await Promise.all([
client.connect(clientTransport),
server.connect(serverTransport),
]);

// Set up abort controller
const controller = new AbortController();

// Issue request but cancel it immediately
const createMessagePromise = server.createMessage(
{
messages: [],
maxTokens: 10,
},
{
signal: controller.signal,
},
);
controller.abort("Cancelled by test");

// Request should be rejected
await expect(createMessagePromise).rejects.toBe("Cancelled by test");
});
14 changes: 9 additions & 5 deletions src/server/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import {
ProgressCallback,
Protocol,
ProtocolOptions,
RequestOptions,
} from "../shared/protocol.js";
import {
ClientCapabilities,
Expand Down Expand Up @@ -157,6 +157,10 @@ export class Server<
}
break;

case "notifications/cancelled":
// Cancellation notifications are always allowed
break;

case "notifications/progress":
// Progress notifications are always allowed
break;
Expand Down Expand Up @@ -257,23 +261,23 @@ export class Server<

async createMessage(
params: CreateMessageRequest["params"],
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "sampling/createMessage", params },
CreateMessageResultSchema,
onprogress,
options,
);
}

async listRoots(
params?: ListRootsRequest["params"],
onprogress?: ProgressCallback,
options?: RequestOptions,
) {
return this.request(
{ method: "roots/list", params },
ListRootsResultSchema,
onprogress,
options,
);
}

Expand Down
Loading