Skip to content

registerTool: accept ZodType<object> for input and output schema #816

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 221 additions & 2 deletions src/server/mcp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3703,7 +3703,7 @@ describe("Tool title precedence", () => {
description: "Tool with regular title"
},
async () => ({
content: [{ type: "text", text: "Response" }],
content: [{ type: "text" as const, text: "Response" }],
})
);

Expand All @@ -3718,7 +3718,7 @@ describe("Tool title precedence", () => {
}
},
async () => ({
content: [{ type: "text", text: "Response" }],
content: [{ type: "text" as const, text: "Response" }],
})
);

Expand Down Expand Up @@ -4291,3 +4291,222 @@ describe("elicitInput()", () => {
}]);
});
});

describe("Tools with union and intersection schemas", () => {
test("should support union schemas", async () => {
const server = new McpServer({
name: "test",
version: "1.0.0",
});

const client = new Client({
name: "test-client",
version: "1.0.0",
});

// Define the union schema for email/phone contact
const unionSchema = z.union([
z.object({ type: z.literal("email"), email: z.string().email() }),
z.object({ type: z.literal("phone"), phone: z.string() })
]);

// Register tool before connecting
server.registerTool("contact", { inputSchema: unionSchema }, async (args) => {
if (args.type === "email") {
return {
content: [{ type: "text" as const, text: `Email contact: ${args.email}` }]
};
} else {
return {
content: [{ type: "text" as const, text: `Phone contact: ${args.phone}` }]
};
}
});

// Connect after registering tools
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
await server.connect(serverTransport);
await client.connect(clientTransport);

// Test with email
const emailResult = await client.callTool({
name: "contact",
arguments: {
type: "email",
email: "[email protected]"
}
});

expect(emailResult.content).toEqual([{
type: "text",
text: "Email contact: [email protected]"
}]);

// Test with phone
const phoneResult = await client.callTool({
name: "contact",
arguments: {
type: "phone",
phone: "+1234567890"
}
});

expect(phoneResult.content).toEqual([{
type: "text",
text: "Phone contact: +1234567890"
}]);
});

test("should support intersection schemas", async () => {
const server = new McpServer({
name: "test",
version: "1.0.0",
});

const client = new Client({
name: "test-client",
version: "1.0.0",
});

// Define intersection schema
const baseSchema = z.object({ id: z.string() });
const extendedSchema = z.object({ name: z.string(), age: z.number() });
const intersectionSchema = z.intersection(baseSchema, extendedSchema);

// Register tool before connecting
server.registerTool("user", { inputSchema: intersectionSchema }, async (args) => {
return {
content: [{
type: "text" as const,
text: `User: ${args.id}, ${args.name}, ${args.age} years old`
}]
};
});

// Connect after registering tools
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
await server.connect(serverTransport);
await client.connect(clientTransport);

const result = await client.callTool({
name: "user",
arguments: {
id: "123",
name: "John Doe",
age: 30
}
});

expect(result.content).toEqual([{
type: "text",
text: "User: 123, John Doe, 30 years old"
}]);
});

test("should support complex nested schemas", async () => {
const server = new McpServer({
name: "test",
version: "1.0.0",
});

const client = new Client({
name: "test-client",
version: "1.0.0",
});

// A more complex schema that wouldn't work with ZodRawShape
const schema = z.object({
items: z.array(
z.union([
z.object({ type: z.literal("text"), content: z.string() }),
z.object({ type: z.literal("number"), value: z.number() })
])
)
});

// Register tool before connecting
server.registerTool("process", { inputSchema: schema }, async (args) => {
const processed = args.items.map(item => {
if (item.type === "text") {
return item.content.toUpperCase();
} else {
return item.value * 2;
}
});
return {
content: [{
type: "text" as const,
text: `Processed: ${processed.join(", ")}`
}]
};
});

// Connect after registering tools
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
await server.connect(serverTransport);
await client.connect(clientTransport);

const result = await client.callTool({
name: "process",
arguments: {
items: [
{ type: "text", content: "hello" },
{ type: "number", value: 5 },
{ type: "text", content: "world" }
]
}
});

expect(result.content).toEqual([{
type: "text",
text: "Processed: HELLO, 10, WORLD"
}]);
});

test("should validate union schema inputs correctly", async () => {
const server = new McpServer({
name: "test",
version: "1.0.0",
});

const client = new Client({
name: "test-client",
version: "1.0.0",
});

const unionSchema = z.union([
z.object({ type: z.literal("a"), value: z.string() }),
z.object({ type: z.literal("b"), value: z.number() })
]);

// Register tool before connecting
server.registerTool("union-test", { inputSchema: unionSchema }, async () => {
return {
content: [{ type: "text" as const, text: "Success" }]
};
});

// Connect after registering tools
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
await server.connect(serverTransport);
await client.connect(clientTransport);

// Should fail with invalid input - wrong type for value
await expect(client.callTool({
name: "union-test",
arguments: {
type: "a",
value: 123 // Wrong type - should be string
}
})).rejects.toThrow();

// Should fail with invalid discriminator
await expect(client.callTool({
name: "union-test",
arguments: {
type: "c", // Invalid discriminator
value: "test"
}
})).rejects.toThrow();
});
});
40 changes: 29 additions & 11 deletions src/server/mcp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import {
ZodRawShape,
ZodObject,
ZodString,
AnyZodObject,
ZodTypeAny,
ZodType,
ZodTypeDef,
Expand Down Expand Up @@ -769,18 +768,16 @@ export class McpServer {
name: string,
title: string | undefined,
description: string | undefined,
inputSchema: ZodRawShape | undefined,
outputSchema: ZodRawShape | undefined,
inputSchema: ZodRawShape | ZodType<object> | undefined,
outputSchema: ZodRawShape | ZodType<object> | undefined,
annotations: ToolAnnotations | undefined,
callback: ToolCallback<ZodRawShape | undefined>
): RegisteredTool {
const registeredTool: RegisteredTool = {
title,
description,
inputSchema:
inputSchema === undefined ? undefined : z.object(inputSchema),
outputSchema:
outputSchema === undefined ? undefined : z.object(outputSchema),
inputSchema: getZodSchemaObject(inputSchema),
outputSchema: getZodSchemaObject(outputSchema),
annotations,
callback,
enabled: true,
Expand Down Expand Up @@ -920,7 +917,7 @@ export class McpServer {
/**
* Registers a tool with a config object and callback.
*/
registerTool<InputArgs extends ZodRawShape, OutputArgs extends ZodRawShape>(
registerTool<InputArgs extends ZodRawShape | ZodType<object>, OutputArgs extends ZodRawShape | ZodType<object>>(
name: string,
config: {
title?: string;
Expand Down Expand Up @@ -1148,19 +1145,24 @@ export class ResourceTemplate {
* - `content` if the tool does not have an outputSchema
* - Both fields are optional but typically one should be provided
*/
export type ToolCallback<Args extends undefined | ZodRawShape = undefined> =
export type ToolCallback<Args extends undefined | ZodRawShape | ZodType<object> = undefined> =
Args extends ZodRawShape
? (
args: z.objectOutputType<Args, ZodTypeAny>,
extra: RequestHandlerExtra<ServerRequest, ServerNotification>,
) => CallToolResult | Promise<CallToolResult>
: Args extends ZodType<infer T>
? (
args: T,
extra: RequestHandlerExtra<ServerRequest, ServerNotification>,
) => CallToolResult | Promise<CallToolResult>
: (extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => CallToolResult | Promise<CallToolResult>;

export type RegisteredTool = {
title?: string;
description?: string;
inputSchema?: AnyZodObject;
outputSchema?: AnyZodObject;
inputSchema?: ZodType<object>;
outputSchema?: ZodType<object>;
annotations?: ToolAnnotations;
callback: ToolCallback<undefined | ZodRawShape>;
enabled: boolean;
Expand Down Expand Up @@ -1203,6 +1205,22 @@ function isZodTypeLike(value: unknown): value is ZodType {
'safeParse' in value && typeof value.safeParse === 'function';
}

/**
* Converts a provided Zod schema to a Zod object if it is a ZodRawShape,
* otherwise returns the schema as is.
*/
function getZodSchemaObject(schema: ZodRawShape | ZodType<object> | undefined): ZodType<object> | undefined {
if (!schema) {
return undefined;
}

if (isZodRawShape(schema)) {
return z.object(schema);
}

return schema;
}

/**
* Additional, optional information for annotating a resource.
*/
Expand Down