diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 10e550df4..b6e78414b 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -3703,7 +3703,7 @@ describe("Tool title precedence", () => { description: "Tool with regular title" }, async () => ({ - content: [{ type: "text", text: "Response" }], + content: [{ type: "text" as const, text: "Response" }], }) ); @@ -3718,7 +3718,7 @@ describe("Tool title precedence", () => { } }, async () => ({ - content: [{ type: "text", text: "Response" }], + content: [{ type: "text" as const, text: "Response" }], }) ); @@ -4291,3 +4291,222 @@ describe("elicitInput()", () => { }]); }); }); + +describe("Tools with union and intersection schemas", () => { + test("should support union schemas", async () => { + const server = new McpServer({ + name: "test", + version: "1.0.0", + }); + + const client = new Client({ + name: "test-client", + version: "1.0.0", + }); + + // Define the union schema for email/phone contact + const unionSchema = z.union([ + z.object({ type: z.literal("email"), email: z.string().email() }), + z.object({ type: z.literal("phone"), phone: z.string() }) + ]); + + // Register tool before connecting + server.registerTool("contact", { inputSchema: unionSchema }, async (args) => { + if (args.type === "email") { + return { + content: [{ type: "text" as const, text: `Email contact: ${args.email}` }] + }; + } else { + return { + content: [{ type: "text" as const, text: `Phone contact: ${args.phone}` }] + }; + } + }); + + // Connect after registering tools + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await server.connect(serverTransport); + await client.connect(clientTransport); + + // Test with email + const emailResult = await client.callTool({ + name: "contact", + arguments: { + type: "email", + email: "test@example.com" + } + }); + + expect(emailResult.content).toEqual([{ + type: "text", + text: "Email contact: test@example.com" + }]); + + // Test with phone + const phoneResult = await client.callTool({ + name: "contact", + arguments: { + type: "phone", + phone: "+1234567890" + } + }); + + expect(phoneResult.content).toEqual([{ + type: "text", + text: "Phone contact: +1234567890" + }]); + }); + + test("should support intersection schemas", async () => { + const server = new McpServer({ + name: "test", + version: "1.0.0", + }); + + const client = new Client({ + name: "test-client", + version: "1.0.0", + }); + + // Define intersection schema + const baseSchema = z.object({ id: z.string() }); + const extendedSchema = z.object({ name: z.string(), age: z.number() }); + const intersectionSchema = z.intersection(baseSchema, extendedSchema); + + // Register tool before connecting + server.registerTool("user", { inputSchema: intersectionSchema }, async (args) => { + return { + content: [{ + type: "text" as const, + text: `User: ${args.id}, ${args.name}, ${args.age} years old` + }] + }; + }); + + // Connect after registering tools + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await server.connect(serverTransport); + await client.connect(clientTransport); + + const result = await client.callTool({ + name: "user", + arguments: { + id: "123", + name: "John Doe", + age: 30 + } + }); + + expect(result.content).toEqual([{ + type: "text", + text: "User: 123, John Doe, 30 years old" + }]); + }); + + test("should support complex nested schemas", async () => { + const server = new McpServer({ + name: "test", + version: "1.0.0", + }); + + const client = new Client({ + name: "test-client", + version: "1.0.0", + }); + + // A more complex schema that wouldn't work with ZodRawShape + const schema = z.object({ + items: z.array( + z.union([ + z.object({ type: z.literal("text"), content: z.string() }), + z.object({ type: z.literal("number"), value: z.number() }) + ]) + ) + }); + + // Register tool before connecting + server.registerTool("process", { inputSchema: schema }, async (args) => { + const processed = args.items.map(item => { + if (item.type === "text") { + return item.content.toUpperCase(); + } else { + return item.value * 2; + } + }); + return { + content: [{ + type: "text" as const, + text: `Processed: ${processed.join(", ")}` + }] + }; + }); + + // Connect after registering tools + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await server.connect(serverTransport); + await client.connect(clientTransport); + + const result = await client.callTool({ + name: "process", + arguments: { + items: [ + { type: "text", content: "hello" }, + { type: "number", value: 5 }, + { type: "text", content: "world" } + ] + } + }); + + expect(result.content).toEqual([{ + type: "text", + text: "Processed: HELLO, 10, WORLD" + }]); + }); + + test("should validate union schema inputs correctly", async () => { + const server = new McpServer({ + name: "test", + version: "1.0.0", + }); + + const client = new Client({ + name: "test-client", + version: "1.0.0", + }); + + const unionSchema = z.union([ + z.object({ type: z.literal("a"), value: z.string() }), + z.object({ type: z.literal("b"), value: z.number() }) + ]); + + // Register tool before connecting + server.registerTool("union-test", { inputSchema: unionSchema }, async () => { + return { + content: [{ type: "text" as const, text: "Success" }] + }; + }); + + // Connect after registering tools + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await server.connect(serverTransport); + await client.connect(clientTransport); + + // Should fail with invalid input - wrong type for value + await expect(client.callTool({ + name: "union-test", + arguments: { + type: "a", + value: 123 // Wrong type - should be string + } + })).rejects.toThrow(); + + // Should fail with invalid discriminator + await expect(client.callTool({ + name: "union-test", + arguments: { + type: "c", // Invalid discriminator + value: "test" + } + })).rejects.toThrow(); + }); +}); diff --git a/src/server/mcp.ts b/src/server/mcp.ts index 791facef1..d412fc193 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -5,7 +5,6 @@ import { ZodRawShape, ZodObject, ZodString, - AnyZodObject, ZodTypeAny, ZodType, ZodTypeDef, @@ -769,18 +768,16 @@ export class McpServer { name: string, title: string | undefined, description: string | undefined, - inputSchema: ZodRawShape | undefined, - outputSchema: ZodRawShape | undefined, + inputSchema: ZodRawShape | ZodType | undefined, + outputSchema: ZodRawShape | ZodType | undefined, annotations: ToolAnnotations | undefined, callback: ToolCallback ): RegisteredTool { const registeredTool: RegisteredTool = { title, description, - inputSchema: - inputSchema === undefined ? undefined : z.object(inputSchema), - outputSchema: - outputSchema === undefined ? undefined : z.object(outputSchema), + inputSchema: getZodSchemaObject(inputSchema), + outputSchema: getZodSchemaObject(outputSchema), annotations, callback, enabled: true, @@ -920,7 +917,7 @@ export class McpServer { /** * Registers a tool with a config object and callback. */ - registerTool( + registerTool, OutputArgs extends ZodRawShape | ZodType>( name: string, config: { title?: string; @@ -1148,19 +1145,24 @@ export class ResourceTemplate { * - `content` if the tool does not have an outputSchema * - Both fields are optional but typically one should be provided */ -export type ToolCallback = +export type ToolCallback = undefined> = Args extends ZodRawShape ? ( args: z.objectOutputType, extra: RequestHandlerExtra, ) => CallToolResult | Promise + : Args extends ZodType + ? ( + args: T, + extra: RequestHandlerExtra, + ) => CallToolResult | Promise : (extra: RequestHandlerExtra) => CallToolResult | Promise; export type RegisteredTool = { title?: string; description?: string; - inputSchema?: AnyZodObject; - outputSchema?: AnyZodObject; + inputSchema?: ZodType; + outputSchema?: ZodType; annotations?: ToolAnnotations; callback: ToolCallback; enabled: boolean; @@ -1203,6 +1205,22 @@ function isZodTypeLike(value: unknown): value is ZodType { 'safeParse' in value && typeof value.safeParse === 'function'; } +/** + * Converts a provided Zod schema to a Zod object if it is a ZodRawShape, + * otherwise returns the schema as is. + */ +function getZodSchemaObject(schema: ZodRawShape | ZodType | undefined): ZodType | undefined { + if (!schema) { + return undefined; + } + + if (isZodRawShape(schema)) { + return z.object(schema); + } + + return schema; +} + /** * Additional, optional information for annotating a resource. */