Skip to content

Commit c80b23b

Browse files
committed
add sampling test
1 parent 3cec764 commit c80b23b

File tree

1 file changed

+97
-4
lines changed

1 file changed

+97
-4
lines changed

exercises/06.sampling/02.solution.advanced/src/index.test.ts

Lines changed: 97 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,27 @@
11
import { invariant } from '@epic-web/invariant'
22
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
33
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'
4+
import {
5+
CreateMessageRequestSchema,
6+
type CreateMessageResult,
7+
} from '@modelcontextprotocol/sdk/types.js'
48
import { test, beforeAll, afterAll, expect } from 'vitest'
9+
import { type z } from 'zod'
510

611
let client: Client
712

813
beforeAll(async () => {
9-
client = new Client({
10-
name: 'EpicMeTester',
11-
version: '1.0.0',
12-
})
14+
client = new Client(
15+
{
16+
name: 'EpicMeTester',
17+
version: '1.0.0',
18+
},
19+
{
20+
capabilities: {
21+
sampling: {},
22+
},
23+
},
24+
)
1325
const transport = new StdioClientTransport({
1426
command: 'tsx',
1527
args: ['src/index.ts'],
@@ -69,3 +81,84 @@ test('Tool Call', async () => {
6981
}),
7082
)
7183
})
84+
85+
async function deferred<ResolvedValue>() {
86+
const ref = {} as {
87+
promise: Promise<ResolvedValue>
88+
resolve: (value: ResolvedValue) => void
89+
reject: (reason?: any) => void
90+
value: ResolvedValue | undefined
91+
reason: any | undefined
92+
}
93+
ref.promise = new Promise<ResolvedValue>((resolve, reject) => {
94+
ref.resolve = (value) => {
95+
ref.value = value
96+
resolve(value)
97+
}
98+
ref.reject = (reason) => {
99+
ref.reason = reason
100+
reject(reason)
101+
}
102+
})
103+
104+
return ref
105+
}
106+
107+
test('Sampling', async () => {
108+
const messageResultDeferred = await deferred<CreateMessageResult>()
109+
const messageRequestDeferred =
110+
await deferred<z.infer<typeof CreateMessageRequestSchema>>()
111+
112+
client.setRequestHandler(CreateMessageRequestSchema, (r) => {
113+
messageRequestDeferred.resolve(r)
114+
return messageResultDeferred.promise
115+
})
116+
117+
await client.callTool({
118+
name: 'create_entry',
119+
arguments: {
120+
title: 'Test Entry',
121+
content: 'This is a test entry',
122+
},
123+
})
124+
const request = await messageRequestDeferred.promise
125+
126+
expect(request).toEqual(
127+
expect.objectContaining({
128+
method: 'sampling/createMessage',
129+
params: expect.objectContaining({
130+
maxTokens: expect.any(Number),
131+
systemPrompt: expect.stringMatching(/example/i),
132+
messages: expect.arrayContaining([
133+
expect.objectContaining({
134+
role: 'user',
135+
content: expect.objectContaining({
136+
type: 'text',
137+
text: expect.stringMatching(/entry/i),
138+
mimeType: 'application/json',
139+
}),
140+
}),
141+
]),
142+
}),
143+
}),
144+
)
145+
146+
messageResultDeferred.resolve({
147+
model: 'stub-model',
148+
stopReason: 'endTurn',
149+
role: 'assistant',
150+
content: {
151+
type: 'text',
152+
text: JSON.stringify([
153+
{ id: 1 },
154+
{
155+
name: 'Testing Sampling',
156+
description: 'Used when testing sampling. Hope it works',
157+
},
158+
]),
159+
},
160+
})
161+
162+
// give the client a chance to process the result
163+
await new Promise((resolve) => setTimeout(resolve, 100))
164+
})

0 commit comments

Comments
 (0)