Skip to content

Commit 48ebb1c

Browse files
committed
docs(mcp): document tool filtering
1 parent 7d25221 commit 48ebb1c

File tree

11 files changed

+456
-34
lines changed

11 files changed

+456
-34
lines changed

docs/src/content/docs/guides/mcp.mdx

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,23 @@ For **Streamable HTTP** and **Stdio** servers, each time an `Agent` runs it may
9797

9898
Only enable this if you're confident the tool list won't change. To invalidate the cache later, call `invalidateToolsCache()` on the server instance.
9999

100+
### Tool filtering
101+
102+
You can restrict which tools are exposed from each server. Pass either a static filter
103+
using `createStaticToolFilter` or a custom function:
104+
105+
```ts
106+
const server = new MCPServerStdio({
107+
fullCommand: 'my-server',
108+
toolFilter: createStaticToolFilter(['safe_tool'], ['danger_tool']),
109+
});
110+
111+
const dynamicServer = new MCPServerStreamableHttp({
112+
url: 'http://localhost:3000',
113+
toolFilter: ({ runContext }, tool) => runContext.context.allowAll || tool.name !== 'admin',
114+
});
115+
```
116+
100117
## Further reading
101118

102119
- [Model Context Protocol](https://modelcontextprotocol.io/) – official specification.

packages/agents-core/src/agent.ts

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -514,9 +514,11 @@ export class Agent<
514514
* Fetches the available tools from the MCP servers.
515515
* @returns the MCP powered tools
516516
*/
517-
async getMcpTools(): Promise<Tool<TContext>[]> {
517+
async getMcpTools(
518+
runContext: RunContext<TContext>,
519+
): Promise<Tool<TContext>[]> {
518520
if (this.mcpServers.length > 0) {
519-
return getAllMcpTools(this.mcpServers);
521+
return getAllMcpTools(this.mcpServers, false, runContext, this);
520522
}
521523

522524
return [];
@@ -527,8 +529,10 @@ export class Agent<
527529
*
528530
* @returns all configured tools
529531
*/
530-
async getAllTools(): Promise<Tool<TContext>[]> {
531-
return [...(await this.getMcpTools()), ...this.tools];
532+
async getAllTools(
533+
runContext: RunContext<TContext>,
534+
): Promise<Tool<TContext>[]> {
535+
return [...(await this.getMcpTools(runContext)), ...this.tools];
532536
}
533537

534538
/**

packages/agents-core/src/index.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ export {
7474
MCPServerStdio,
7575
MCPServerStreamableHttp,
7676
} from './mcp';
77+
export {
78+
ToolFilterCallable,
79+
ToolFilterContext,
80+
ToolFilterStatic,
81+
createStaticToolFilter,
82+
} from './mcpUtil';
7783
export {
7884
Model,
7985
ModelProvider,

packages/agents-core/src/mcp.ts

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ import {
1414
JsonObjectSchemaStrict,
1515
UnknownContext,
1616
} from './types';
17+
import type { ToolFilterCallable, ToolFilterStatic } from './mcpUtil';
18+
import type { RunContext } from './runContext';
19+
import type { Agent } from './agent';
1720

1821
export const DEFAULT_STDIO_MCP_CLIENT_LOGGER_NAME =
1922
'openai-agents:stdio-mcp-client';
@@ -30,7 +33,10 @@ export interface MCPServer {
3033
connect(): Promise<void>;
3134
readonly name: string;
3235
close(): Promise<void>;
33-
listTools(): Promise<MCPTool[]>;
36+
listTools(
37+
runContext?: RunContext<any>,
38+
agent?: Agent<any, any>,
39+
): Promise<MCPTool[]>;
3440
callTool(
3541
toolName: string,
3642
args: Record<string, unknown> | null,
@@ -41,18 +47,23 @@ export interface MCPServer {
4147
export abstract class BaseMCPServerStdio implements MCPServer {
4248
public cacheToolsList: boolean;
4349
protected _cachedTools: any[] | undefined = undefined;
50+
protected toolFilter?: ToolFilterCallable | ToolFilterStatic;
4451

4552
protected logger: Logger;
4653
constructor(options: MCPServerStdioOptions) {
4754
this.logger =
4855
options.logger ?? getLogger(DEFAULT_STDIO_MCP_CLIENT_LOGGER_NAME);
4956
this.cacheToolsList = options.cacheToolsList ?? false;
57+
this.toolFilter = options.toolFilter;
5058
}
5159

5260
abstract get name(): string;
5361
abstract connect(): Promise<void>;
5462
abstract close(): Promise<void>;
55-
abstract listTools(): Promise<any[]>;
63+
abstract listTools(
64+
runContext?: RunContext<any>,
65+
agent?: Agent<any, any>,
66+
): Promise<any[]>;
5667
abstract callTool(
5768
_toolName: string,
5869
_args: Record<string, unknown> | null,
@@ -74,19 +85,24 @@ export abstract class BaseMCPServerStdio implements MCPServer {
7485
export abstract class BaseMCPServerStreamableHttp implements MCPServer {
7586
public cacheToolsList: boolean;
7687
protected _cachedTools: any[] | undefined = undefined;
88+
protected toolFilter?: ToolFilterCallable | ToolFilterStatic;
7789

7890
protected logger: Logger;
7991
constructor(options: MCPServerStreamableHttpOptions) {
8092
this.logger =
8193
options.logger ??
8294
getLogger(DEFAULT_STREAMABLE_HTTP_MCP_CLIENT_LOGGER_NAME);
8395
this.cacheToolsList = options.cacheToolsList ?? false;
96+
this.toolFilter = options.toolFilter;
8497
}
8598

8699
abstract get name(): string;
87100
abstract connect(): Promise<void>;
88101
abstract close(): Promise<void>;
89-
abstract listTools(): Promise<any[]>;
102+
abstract listTools(
103+
runContext?: RunContext<any>,
104+
agent?: Agent<any, any>,
105+
): Promise<any[]>;
90106
abstract callTool(
91107
_toolName: string,
92108
_args: Record<string, unknown> | null,
@@ -141,11 +157,14 @@ export class MCPServerStdio extends BaseMCPServerStdio {
141157
close(): Promise<void> {
142158
return this.underlying.close();
143159
}
144-
async listTools(): Promise<MCPTool[]> {
160+
async listTools(
161+
runContext?: RunContext<any>,
162+
agent?: Agent<any, any>,
163+
): Promise<MCPTool[]> {
145164
if (this.cacheToolsList && this._cachedTools) {
146165
return this._cachedTools;
147166
}
148-
const tools = await this.underlying.listTools();
167+
const tools = await this.underlying.listTools(runContext, agent);
149168
if (this.cacheToolsList) {
150169
this._cachedTools = tools;
151170
}
@@ -177,11 +196,14 @@ export class MCPServerStreamableHttp extends BaseMCPServerStreamableHttp {
177196
close(): Promise<void> {
178197
return this.underlying.close();
179198
}
180-
async listTools(): Promise<MCPTool[]> {
199+
async listTools(
200+
runContext?: RunContext<any>,
201+
agent?: Agent<any, any>,
202+
): Promise<MCPTool[]> {
181203
if (this.cacheToolsList && this._cachedTools) {
182204
return this._cachedTools;
183205
}
184-
const tools = await this.underlying.listTools();
206+
const tools = await this.underlying.listTools(runContext, agent);
185207
if (this.cacheToolsList) {
186208
this._cachedTools = tools;
187209
}
@@ -205,13 +227,17 @@ export class MCPServerStreamableHttp extends BaseMCPServerStreamableHttp {
205227
export async function getAllMcpFunctionTools<TContext = UnknownContext>(
206228
mcpServers: MCPServer[],
207229
convertSchemasToStrict = false,
230+
runContext?: RunContext<TContext>,
231+
agent?: Agent<TContext, any>,
208232
): Promise<Tool<TContext>[]> {
209233
const allTools: Tool<TContext>[] = [];
210234
const toolNames = new Set<string>();
211235
for (const server of mcpServers) {
212236
const serverTools = await getFunctionToolsFromServer(
213237
server,
214238
convertSchemasToStrict,
239+
runContext,
240+
agent,
215241
);
216242
const serverToolNames = new Set(serverTools.map((t) => t.name));
217243
const intersection = [...serverToolNames].filter((n) => toolNames.has(n));
@@ -243,6 +269,8 @@ export async function invalidateServerToolsCache(serverName: string) {
243269
async function getFunctionToolsFromServer<TContext = UnknownContext>(
244270
server: MCPServer,
245271
convertSchemasToStrict: boolean,
272+
runContext?: RunContext<TContext>,
273+
agent?: Agent<TContext, any>,
246274
): Promise<FunctionTool<TContext, any, unknown>[]> {
247275
if (server.cacheToolsList && _cachedTools[server.name]) {
248276
return _cachedTools[server.name].map((t) =>
@@ -251,7 +279,7 @@ async function getFunctionToolsFromServer<TContext = UnknownContext>(
251279
}
252280
return withMCPListToolsSpan(
253281
async (span) => {
254-
const mcpTools = await server.listTools();
282+
const mcpTools = await server.listTools(runContext, agent);
255283
span.spanData.result = mcpTools.map((t) => t.name);
256284
const tools: FunctionTool<TContext, any, string>[] = mcpTools.map((t) =>
257285
mcpToFunctionTool(t, server, convertSchemasToStrict),
@@ -271,8 +299,15 @@ async function getFunctionToolsFromServer<TContext = UnknownContext>(
271299
export async function getAllMcpTools<TContext = UnknownContext>(
272300
mcpServers: MCPServer[],
273301
convertSchemasToStrict = false,
302+
runContext?: RunContext<TContext>,
303+
agent?: Agent<TContext, any>,
274304
): Promise<Tool<TContext>[]> {
275-
return getAllMcpFunctionTools(mcpServers, convertSchemasToStrict);
305+
return getAllMcpFunctionTools(
306+
mcpServers,
307+
convertSchemasToStrict,
308+
runContext,
309+
agent,
310+
);
276311
}
277312

278313
/**
@@ -363,6 +398,7 @@ export interface BaseMCPServerStdioOptions {
363398
encoding?: string;
364399
encodingErrorHandler?: 'strict' | 'ignore' | 'replace';
365400
logger?: Logger;
401+
toolFilter?: ToolFilterCallable | ToolFilterStatic;
366402
}
367403
export interface DefaultMCPServerStdioOptions
368404
extends BaseMCPServerStdioOptions {
@@ -383,6 +419,7 @@ export interface MCPServerStreamableHttpOptions {
383419
clientSessionTimeoutSeconds?: number;
384420
name?: string;
385421
logger?: Logger;
422+
toolFilter?: ToolFilterCallable | ToolFilterStatic;
386423

387424
// ----------------------------------------------------
388425
// OAuth
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import type { Agent } from './agent';
2+
import type { RunContext } from './runContext';
3+
import type { MCPTool } from './mcp';
4+
import type { UnknownContext } from './types';
5+
6+
/** Context information available to tool filter functions. */
7+
export interface ToolFilterContext<TContext = UnknownContext> {
8+
/** The current run context. */
9+
runContext: RunContext<TContext>;
10+
/** The agent requesting the tools. */
11+
agent: Agent<TContext, any>;
12+
/** Name of the MCP server providing the tools. */
13+
serverName: string;
14+
}
15+
16+
/** A function that determines whether a tool should be available. */
17+
export type ToolFilterCallable<TContext = UnknownContext> = (
18+
context: ToolFilterContext<TContext>,
19+
tool: MCPTool,
20+
) => boolean | Promise<boolean>;
21+
22+
/** Static tool filter configuration using allow and block lists. */
23+
export interface ToolFilterStatic {
24+
/** Optional list of tool names to allow. */
25+
allowedToolNames?: string[];
26+
/** Optional list of tool names to block. */
27+
blockedToolNames?: string[];
28+
}
29+
30+
/** Convenience helper to create a static tool filter. */
31+
export function createStaticToolFilter(
32+
allowedToolNames?: string[],
33+
blockedToolNames?: string[],
34+
): ToolFilterStatic | undefined {
35+
if (!allowedToolNames && !blockedToolNames) {
36+
return undefined;
37+
}
38+
const filter: ToolFilterStatic = {};
39+
if (allowedToolNames) {
40+
filter.allowedToolNames = allowedToolNames;
41+
}
42+
if (blockedToolNames) {
43+
filter.blockedToolNames = blockedToolNames;
44+
}
45+
return filter;
46+
}

packages/agents-core/src/run.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
322322
setCurrentSpan(state._currentAgentSpan);
323323
}
324324

325-
const tools = await state._currentAgent.getAllTools();
325+
const tools = await state._currentAgent.getAllTools(state._context);
326326
const serializedTools = tools.map((t) => serializeTool(t));
327327
const serializedHandoffs = handoffs.map((h) => serializeHandoff(h));
328328
if (state._currentAgentSpan) {
@@ -615,7 +615,7 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
615615
while (true) {
616616
const currentAgent = result.state._currentAgent;
617617
const handoffs = currentAgent.handoffs.map(getHandoff);
618-
const tools = await currentAgent.getAllTools();
618+
const tools = await currentAgent.getAllTools(result.state._context);
619619
const serializedTools = tools.map((t) => serializeTool(t));
620620
const serializedHandoffs = handoffs.map((h) => serializeHandoff(h));
621621

packages/agents-core/src/runState.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,7 @@ export class RunState<TContext, TAgent extends Agent<any, any>> {
558558
agentMap,
559559
state._currentAgent,
560560
stateJson.lastProcessedResponse,
561+
state._context,
561562
)
562563
: undefined;
563564

@@ -710,8 +711,9 @@ async function deserializeProcessedResponse<TContext = UnknownContext>(
710711
serializedProcessedResponse: z.infer<
711712
typeof serializedProcessedResponseSchema
712713
>,
714+
runContext: RunContext<TContext>,
713715
): Promise<ProcessedResponse<TContext>> {
714-
const allTools = await currentAgent.getAllTools();
716+
const allTools = await currentAgent.getAllTools(runContext);
715717
const tools = new Map(
716718
allTools
717719
.filter((tool) => tool.type === 'function')

packages/agents-core/src/shims/mcp-server/browser.ts

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import {
66
MCPServerStreamableHttpOptions,
77
MCPTool,
88
} from '../../mcp';
9+
import type { RunContext } from '../../runContext';
10+
import type { Agent } from '../../agent';
911

1012
export class MCPServerStdio extends BaseMCPServerStdio {
1113
constructor(params: MCPServerStdioOptions) {
@@ -20,7 +22,10 @@ export class MCPServerStdio extends BaseMCPServerStdio {
2022
close(): Promise<void> {
2123
throw new Error('Method not implemented.');
2224
}
23-
listTools(): Promise<MCPTool[]> {
25+
listTools(
26+
_runContext?: RunContext<any>,
27+
_agent?: Agent<any, any>,
28+
): Promise<MCPTool[]> {
2429
throw new Error('Method not implemented.');
2530
}
2631
callTool(
@@ -47,7 +52,10 @@ export class MCPServerStreamableHttp extends BaseMCPServerStreamableHttp {
4752
close(): Promise<void> {
4853
throw new Error('Method not implemented.');
4954
}
50-
listTools(): Promise<MCPTool[]> {
55+
listTools(
56+
_runContext?: RunContext<any>,
57+
_agent?: Agent<any, any>,
58+
): Promise<MCPTool[]> {
5159
throw new Error('Method not implemented.');
5260
}
5361
callTool(

0 commit comments

Comments
 (0)