Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/some-buses-run.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"agents": minor
---

Add generic type parameters to MCP handler functions for better type safety
22 changes: 14 additions & 8 deletions packages/agents/src/mcp/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,23 @@ export interface CreateMcpHandlerOptions extends WorkerTransportOptions {
transport?: WorkerTransport;
}

export function createMcpHandler(
export function createMcpHandler<
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems like a good improvement independently of the types below.

Env extends Cloudflare.Env = Cloudflare.Env,
Props extends Record<string, unknown> = Record<string, unknown>
>(
server: McpServer | Server,
options: CreateMcpHandlerOptions = {}
): (
request: Request,
env: unknown,
ctx: ExecutionContext
env: Env,
ctx: ExecutionContext<Props>
) => Promise<Response> {
const route = options.route ?? "/mcp";

return async (
request: Request,
_env: unknown,
ctx: ExecutionContext
_env: Env,
ctx: ExecutionContext<Props>
): Promise<Response> => {
const url = new URL(request.url);
if (route && url.pathname !== route) {
Expand Down Expand Up @@ -109,13 +112,16 @@ let didWarnAboutExperimentalCreateMcpHandler = false;
/**
* @deprecated This has been renamed to createMcpHandler, and experimental_createMcpHandler will be removed in the next major version
*/
export function experimental_createMcpHandler(
export function experimental_createMcpHandler<
Env extends Cloudflare.Env = Cloudflare.Env,
Props extends Record<string, unknown> = Record<string, unknown>
>(
server: McpServer | Server,
options: CreateMcpHandlerOptions = {}
): (
request: Request,
env: unknown,
ctx: ExecutionContext
env: Env,
ctx: ExecutionContext<Props>
) => Promise<Response> {
if (!didWarnAboutExperimentalCreateMcpHandler) {
didWarnAboutExperimentalCreateMcpHandler = true;
Expand Down
9 changes: 6 additions & 3 deletions packages/agents/src/mcp/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,10 @@ export abstract class McpAgent<
/** Return a handler for the given path for this MCP.
* Defaults to Streamable HTTP transport.
*/
static serve(
static serve<
Env extends Cloudflare.Env = Cloudflare.Env,
Props extends Record<string, unknown> = Record<string, unknown>
>(
path: string,
{
binding = "MCP_OBJECT",
Expand All @@ -371,11 +374,11 @@ export abstract class McpAgent<
}: ServeOptions = {}
) {
return {
async fetch<Env>(
async fetch(
this: void,
request: Request,
env: Env,
ctx: ExecutionContext
ctx: ExecutionContext<Props>
): Promise<Response> {
// Handle CORS preflight
const corsResponse = handleCORS(request, corsOptions);
Expand Down
29 changes: 16 additions & 13 deletions packages/agents/src/tests/mcp/handler.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ declare module "cloudflare:test" {
interface ProvidedEnv {}
}

const createTestExecutionContext = () =>
createExecutionContext() as ExecutionContext<Record<string, unknown>>;

/**
* Tests for createMcpHandler
* The handler primarily passes options to WorkerTransport and handles routing
Expand Down Expand Up @@ -42,7 +45,7 @@ describe("createMcpHandler", () => {
route: "/custom-mcp"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();

// Request to non-matching route
const wrongRequest = new Request("http://example.com/mcp", {
Expand All @@ -63,7 +66,7 @@ describe("createMcpHandler", () => {
const server = createTestServer();
const handler = createMcpHandler(server);

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "OPTIONS"
});
Expand All @@ -85,7 +88,7 @@ describe("createMcpHandler", () => {
}
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "OPTIONS"
});
Expand All @@ -109,7 +112,7 @@ describe("createMcpHandler", () => {
route: "/mcp"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -149,7 +152,7 @@ describe("createMcpHandler", () => {
transport: customTransport
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "OPTIONS"
});
Expand All @@ -176,7 +179,7 @@ describe("createMcpHandler", () => {
transport: customTransport
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "OPTIONS"
});
Expand All @@ -203,7 +206,7 @@ describe("createMcpHandler", () => {
sessionIdGenerator: customSessionIdGenerator
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -241,7 +244,7 @@ describe("createMcpHandler", () => {
}
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -274,7 +277,7 @@ describe("createMcpHandler", () => {
enableJsonResponse: true
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -311,7 +314,7 @@ describe("createMcpHandler", () => {
storage: mockStorage
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -343,7 +346,7 @@ describe("createMcpHandler", () => {
corsOptions: { origin: "https://example.com" }
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/custom-route", {
method: "OPTIONS"
});
Expand Down Expand Up @@ -373,7 +376,7 @@ describe("createMcpHandler", () => {
transport: errorTransport
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -418,7 +421,7 @@ describe("createMcpHandler", () => {
transport: errorTransport
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down
17 changes: 10 additions & 7 deletions packages/agents/src/tests/mcp/jurisdiction.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ declare module "cloudflare:test" {
}
}

const createTestExecutionContext = () =>
createExecutionContext() as ExecutionContext<Record<string, unknown>>;

/**
* Tests for jurisdiction option in McpAgent.serve()
*
Expand Down Expand Up @@ -61,7 +64,7 @@ describe("McpAgent jurisdiction option", () => {
transport: "streamable-http"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -95,7 +98,7 @@ describe("McpAgent jurisdiction option", () => {
transport: "sse"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "GET"
});
Expand All @@ -115,7 +118,7 @@ describe("McpAgent jurisdiction option", () => {
transport: "streamable-http"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -147,7 +150,7 @@ describe("McpAgent jurisdiction option", () => {
transport: "streamable-http"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down Expand Up @@ -183,7 +186,7 @@ describe("McpAgent jurisdiction option", () => {
transport: "streamable-http"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();

// First request (initialization)
const initRequest = new Request("http://example.com/mcp", {
Expand Down Expand Up @@ -241,7 +244,7 @@ describe("McpAgent jurisdiction option", () => {
transport: "sse"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();

// First, establish SSE connection
const sseRequest = new Request(
Expand Down Expand Up @@ -305,7 +308,7 @@ describe("McpAgent jurisdiction option", () => {
transport: "streamable-http"
});

const ctx = createExecutionContext();
const ctx = createTestExecutionContext();
const request = new Request("http://example.com/mcp", {
method: "POST",
headers: {
Expand Down
6 changes: 4 additions & 2 deletions packages/agents/src/tests/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -591,12 +591,14 @@ export default {
testValue: "123"
};

const typedCtx = ctx as ExecutionContext<Record<string, unknown>>;

if (url.pathname === "/sse" || url.pathname === "/sse/message") {
return TestMcpAgent.serveSSE("/sse").fetch(request, env, ctx);
return TestMcpAgent.serveSSE("/sse").fetch(request, env, typedCtx);
}

if (url.pathname === "/mcp") {
return TestMcpAgent.serve("/mcp").fetch(request, env, ctx);
return TestMcpAgent.serve("/mcp").fetch(request, env, typedCtx);
}

if (url.pathname === "/500") {
Expand Down