Skip to content

Commit 2a55dfd

Browse files
authored
sampling: validate tools, tool_use, tool_result constraints (#1156)
1 parent 41c6b35 commit 2a55dfd

File tree

2 files changed

+380
-1
lines changed

2 files changed

+380
-1
lines changed

src/server/index.test.ts

Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,6 +1582,341 @@ test('should respect log level for transport without sessionId', async () => {
15821582
expect(clientTransport.onmessage).toHaveBeenCalled();
15831583
});
15841584

1585+
describe('createMessage validation', () => {
1586+
test('should throw when tools are provided without sampling.tools capability', async () => {
1587+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1588+
1589+
const client = new Client(
1590+
{ name: 'test client', version: '1.0' },
1591+
{ capabilities: { sampling: {} } } // No tools capability
1592+
);
1593+
1594+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1595+
model: 'test-model',
1596+
role: 'assistant',
1597+
content: { type: 'text', text: 'Response' }
1598+
}));
1599+
1600+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1601+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1602+
1603+
await expect(
1604+
server.createMessage({
1605+
messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }],
1606+
maxTokens: 100,
1607+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
1608+
})
1609+
).rejects.toThrow('Client does not support sampling tools capability.');
1610+
});
1611+
1612+
test('should throw when toolChoice is provided without sampling.tools capability', async () => {
1613+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1614+
1615+
const client = new Client(
1616+
{ name: 'test client', version: '1.0' },
1617+
{ capabilities: { sampling: {} } } // No tools capability
1618+
);
1619+
1620+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1621+
model: 'test-model',
1622+
role: 'assistant',
1623+
content: { type: 'text', text: 'Response' }
1624+
}));
1625+
1626+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1627+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1628+
1629+
await expect(
1630+
server.createMessage({
1631+
messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }],
1632+
maxTokens: 100,
1633+
toolChoice: { mode: 'auto' }
1634+
})
1635+
).rejects.toThrow('Client does not support sampling tools capability.');
1636+
});
1637+
1638+
test('should throw when tool_result is mixed with other content', async () => {
1639+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1640+
1641+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
1642+
1643+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1644+
model: 'test-model',
1645+
role: 'assistant',
1646+
content: { type: 'text', text: 'Response' }
1647+
}));
1648+
1649+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1650+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1651+
1652+
await expect(
1653+
server.createMessage({
1654+
messages: [
1655+
{ role: 'user', content: { type: 'text', text: 'hello' } },
1656+
{ role: 'assistant', content: { type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} } },
1657+
{
1658+
role: 'user',
1659+
content: [
1660+
{ type: 'tool_result', toolUseId: 'call_1', content: [] },
1661+
{ type: 'text', text: 'mixed content' } // Mixed!
1662+
]
1663+
}
1664+
],
1665+
maxTokens: 100,
1666+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
1667+
})
1668+
).rejects.toThrow('The last message must contain only tool_result content if any is present');
1669+
});
1670+
1671+
test('should throw when tool_result has no matching tool_use in previous message', async () => {
1672+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1673+
1674+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
1675+
1676+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1677+
model: 'test-model',
1678+
role: 'assistant',
1679+
content: { type: 'text', text: 'Response' }
1680+
}));
1681+
1682+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1683+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1684+
1685+
// tool_result without previous tool_use
1686+
await expect(
1687+
server.createMessage({
1688+
messages: [
1689+
{ role: 'user', content: { type: 'text', text: 'hello' } },
1690+
{ role: 'user', content: { type: 'tool_result', toolUseId: 'call_1', content: [] } }
1691+
],
1692+
maxTokens: 100,
1693+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
1694+
})
1695+
).rejects.toThrow('tool_result blocks are not matching any tool_use from the previous message');
1696+
});
1697+
1698+
test('should throw when tool_result IDs do not match tool_use IDs', async () => {
1699+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1700+
1701+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
1702+
1703+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1704+
model: 'test-model',
1705+
role: 'assistant',
1706+
content: { type: 'text', text: 'Response' }
1707+
}));
1708+
1709+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1710+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1711+
1712+
await expect(
1713+
server.createMessage({
1714+
messages: [
1715+
{ role: 'user', content: { type: 'text', text: 'hello' } },
1716+
{ role: 'assistant', content: { type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} } },
1717+
{ role: 'user', content: { type: 'tool_result', toolUseId: 'wrong_id', content: [] } }
1718+
],
1719+
maxTokens: 100,
1720+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
1721+
})
1722+
).rejects.toThrow('ids of tool_result blocks and tool_use blocks from previous message do not match');
1723+
});
1724+
1725+
test('should allow text-only messages with tools (no tool_results)', async () => {
1726+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1727+
1728+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
1729+
1730+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1731+
model: 'test-model',
1732+
role: 'assistant',
1733+
content: { type: 'text', text: 'Response' }
1734+
}));
1735+
1736+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1737+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1738+
1739+
await expect(
1740+
server.createMessage({
1741+
messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }],
1742+
maxTokens: 100,
1743+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
1744+
})
1745+
).resolves.toMatchObject({ model: 'test-model' });
1746+
});
1747+
1748+
test('should allow valid matching tool_result/tool_use IDs', async () => {
1749+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1750+
1751+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
1752+
1753+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1754+
model: 'test-model',
1755+
role: 'assistant',
1756+
content: { type: 'text', text: 'Response' }
1757+
}));
1758+
1759+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1760+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1761+
1762+
await expect(
1763+
server.createMessage({
1764+
messages: [
1765+
{ role: 'user', content: { type: 'text', text: 'hello' } },
1766+
{ role: 'assistant', content: { type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} } },
1767+
{ role: 'user', content: { type: 'tool_result', toolUseId: 'call_1', content: [] } }
1768+
],
1769+
maxTokens: 100,
1770+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
1771+
})
1772+
).resolves.toMatchObject({ model: 'test-model' });
1773+
});
1774+
1775+
test('should throw when user sends text instead of tool_result after tool_use', async () => {
1776+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1777+
1778+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
1779+
1780+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1781+
model: 'test-model',
1782+
role: 'assistant',
1783+
content: { type: 'text', text: 'Response' }
1784+
}));
1785+
1786+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1787+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1788+
1789+
// User ignores tool_use and sends text instead
1790+
await expect(
1791+
server.createMessage({
1792+
messages: [
1793+
{ role: 'user', content: { type: 'text', text: 'hello' } },
1794+
{ role: 'assistant', content: { type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} } },
1795+
{ role: 'user', content: { type: 'text', text: 'actually nevermind' } }
1796+
],
1797+
maxTokens: 100,
1798+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
1799+
})
1800+
).rejects.toThrow('ids of tool_result blocks and tool_use blocks from previous message do not match');
1801+
});
1802+
1803+
test('should throw when only some tool_results are provided for parallel tool_use', async () => {
1804+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1805+
1806+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
1807+
1808+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1809+
model: 'test-model',
1810+
role: 'assistant',
1811+
content: { type: 'text', text: 'Response' }
1812+
}));
1813+
1814+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1815+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1816+
1817+
// Parallel tool_use but only one tool_result provided
1818+
await expect(
1819+
server.createMessage({
1820+
messages: [
1821+
{ role: 'user', content: { type: 'text', text: 'hello' } },
1822+
{
1823+
role: 'assistant',
1824+
content: [
1825+
{ type: 'tool_use', id: 'call_1', name: 'tool_a', input: {} },
1826+
{ type: 'tool_use', id: 'call_2', name: 'tool_b', input: {} }
1827+
]
1828+
},
1829+
{ role: 'user', content: { type: 'tool_result', toolUseId: 'call_1', content: [] } }
1830+
],
1831+
maxTokens: 100,
1832+
tools: [
1833+
{ name: 'tool_a', inputSchema: { type: 'object' } },
1834+
{ name: 'tool_b', inputSchema: { type: 'object' } }
1835+
]
1836+
})
1837+
).rejects.toThrow('ids of tool_result blocks and tool_use blocks from previous message do not match');
1838+
});
1839+
1840+
test('should validate tool_use/tool_result even without tools in current request', async () => {
1841+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1842+
1843+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
1844+
1845+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1846+
model: 'test-model',
1847+
role: 'assistant',
1848+
content: { type: 'text', text: 'Response' }
1849+
}));
1850+
1851+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1852+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1853+
1854+
// Previous request returned tool_use, now sending tool_result without tools param
1855+
await expect(
1856+
server.createMessage({
1857+
messages: [
1858+
{ role: 'user', content: { type: 'text', text: 'hello' } },
1859+
{ role: 'assistant', content: { type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} } },
1860+
{ role: 'user', content: { type: 'tool_result', toolUseId: 'wrong_id', content: [] } }
1861+
],
1862+
maxTokens: 100
1863+
// Note: no tools param - this is a follow-up request after tool execution
1864+
})
1865+
).rejects.toThrow('ids of tool_result blocks and tool_use blocks from previous message do not match');
1866+
});
1867+
1868+
test('should allow valid tool_use/tool_result without tools in current request', async () => {
1869+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1870+
1871+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
1872+
1873+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1874+
model: 'test-model',
1875+
role: 'assistant',
1876+
content: { type: 'text', text: 'Response' }
1877+
}));
1878+
1879+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1880+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1881+
1882+
// Previous request returned tool_use, now sending matching tool_result without tools param
1883+
await expect(
1884+
server.createMessage({
1885+
messages: [
1886+
{ role: 'user', content: { type: 'text', text: 'hello' } },
1887+
{ role: 'assistant', content: { type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} } },
1888+
{ role: 'user', content: { type: 'tool_result', toolUseId: 'call_1', content: [] } }
1889+
],
1890+
maxTokens: 100
1891+
// Note: no tools param - this is a follow-up request after tool execution
1892+
})
1893+
).resolves.toMatchObject({ model: 'test-model' });
1894+
});
1895+
1896+
test('should handle empty messages array', async () => {
1897+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1898+
1899+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } });
1900+
1901+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1902+
model: 'test-model',
1903+
role: 'assistant',
1904+
content: { type: 'text', text: 'Response' }
1905+
}));
1906+
1907+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1908+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1909+
1910+
// Empty messages array should not crash
1911+
await expect(
1912+
server.createMessage({
1913+
messages: [],
1914+
maxTokens: 100
1915+
})
1916+
).resolves.toMatchObject({ model: 'test-model' });
1917+
});
1918+
});
1919+
15851920
test('should respect log level for transport with sessionId', async () => {
15861921
const server = new Server(
15871922
{

0 commit comments

Comments
 (0)