|
1 | 1 | import { invariant } from '@epic-web/invariant' |
2 | 2 | import { Client } from '@modelcontextprotocol/sdk/client/index.js' |
3 | 3 | import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js' |
| 4 | +import { |
| 5 | + CreateMessageRequestSchema, |
| 6 | + type CreateMessageResult, |
| 7 | +} from '@modelcontextprotocol/sdk/types.js' |
4 | 8 | import { test, beforeAll, afterAll, expect } from 'vitest' |
| 9 | +import { type z } from 'zod' |
5 | 10 |
|
6 | 11 | let client: Client |
7 | 12 |
|
8 | 13 | beforeAll(async () => { |
9 | | - client = new Client({ |
10 | | - name: 'EpicMeTester', |
11 | | - version: '1.0.0', |
12 | | - }) |
| 14 | + client = new Client( |
| 15 | + { |
| 16 | + name: 'EpicMeTester', |
| 17 | + version: '1.0.0', |
| 18 | + }, |
| 19 | + { |
| 20 | + capabilities: { |
| 21 | + sampling: {}, |
| 22 | + }, |
| 23 | + }, |
| 24 | + ) |
13 | 25 | const transport = new StdioClientTransport({ |
14 | 26 | command: 'tsx', |
15 | 27 | args: ['src/index.ts'], |
@@ -69,3 +81,84 @@ test('Tool Call', async () => { |
69 | 81 | }), |
70 | 82 | ) |
71 | 83 | }) |
| 84 | + |
| 85 | +async function deferred<ResolvedValue>() { |
| 86 | + const ref = {} as { |
| 87 | + promise: Promise<ResolvedValue> |
| 88 | + resolve: (value: ResolvedValue) => void |
| 89 | + reject: (reason?: any) => void |
| 90 | + value: ResolvedValue | undefined |
| 91 | + reason: any | undefined |
| 92 | + } |
| 93 | + ref.promise = new Promise<ResolvedValue>((resolve, reject) => { |
| 94 | + ref.resolve = (value) => { |
| 95 | + ref.value = value |
| 96 | + resolve(value) |
| 97 | + } |
| 98 | + ref.reject = (reason) => { |
| 99 | + ref.reason = reason |
| 100 | + reject(reason) |
| 101 | + } |
| 102 | + }) |
| 103 | + |
| 104 | + return ref |
| 105 | +} |
| 106 | + |
| 107 | +test('Sampling', async () => { |
| 108 | + const messageResultDeferred = await deferred<CreateMessageResult>() |
| 109 | + const messageRequestDeferred = |
| 110 | + await deferred<z.infer<typeof CreateMessageRequestSchema>>() |
| 111 | + |
| 112 | + client.setRequestHandler(CreateMessageRequestSchema, (r) => { |
| 113 | + messageRequestDeferred.resolve(r) |
| 114 | + return messageResultDeferred.promise |
| 115 | + }) |
| 116 | + |
| 117 | + await client.callTool({ |
| 118 | + name: 'create_entry', |
| 119 | + arguments: { |
| 120 | + title: 'Test Entry', |
| 121 | + content: 'This is a test entry', |
| 122 | + }, |
| 123 | + }) |
| 124 | + const request = await messageRequestDeferred.promise |
| 125 | + |
| 126 | + expect(request).toEqual( |
| 127 | + expect.objectContaining({ |
| 128 | + method: 'sampling/createMessage', |
| 129 | + params: expect.objectContaining({ |
| 130 | + maxTokens: expect.any(Number), |
| 131 | + systemPrompt: expect.stringMatching(/example/i), |
| 132 | + messages: expect.arrayContaining([ |
| 133 | + expect.objectContaining({ |
| 134 | + role: 'user', |
| 135 | + content: expect.objectContaining({ |
| 136 | + type: 'text', |
| 137 | + text: expect.stringMatching(/entry/i), |
| 138 | + mimeType: 'application/json', |
| 139 | + }), |
| 140 | + }), |
| 141 | + ]), |
| 142 | + }), |
| 143 | + }), |
| 144 | + ) |
| 145 | + |
| 146 | + messageResultDeferred.resolve({ |
| 147 | + model: 'stub-model', |
| 148 | + stopReason: 'endTurn', |
| 149 | + role: 'assistant', |
| 150 | + content: { |
| 151 | + type: 'text', |
| 152 | + text: JSON.stringify([ |
| 153 | + { id: 1 }, |
| 154 | + { |
| 155 | + name: 'Testing Sampling', |
| 156 | + description: 'Used when testing sampling. Hope it works', |
| 157 | + }, |
| 158 | + ]), |
| 159 | + }, |
| 160 | + }) |
| 161 | + |
| 162 | + // give the client a chance to process the result |
| 163 | + await new Promise((resolve) => setTimeout(resolve, 100)) |
| 164 | +}) |
0 commit comments