Skip to content

Commit 3e2c34b

Browse files
committed
Merge remote-tracking branch 'seratch/vrtnis-feat/mcp-tool-filtering-js' into feat/mcp-tool-filtering-js
2 parents ec4800a + 73eb37a commit 3e2c34b

File tree

9 files changed

+99
-108
lines changed

9 files changed

+99
-108
lines changed

examples/mcp/tool-filter-example.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ async function main() {
2424
await withTrace('MCP Tool Filter Example', async () => {
2525
const agent = new Agent({
2626
name: 'MCP Assistant',
27-
instructions:
28-
'Use the filesystem tools to answer questions. The write_file tool is blocked via toolFilter.',
27+
instructions: 'Use the filesystem tools to answer questions.',
2928
mcpServers: [mcpServer],
3029
});
3130

@@ -36,7 +35,7 @@ async function main() {
3635
console.log('\nAttempting to write a file (should be blocked):');
3736
result = await run(
3837
agent,
39-
'Create a file named test.txt with the text "hello"',
38+
'Create a file named sample_files/test.txt with the text "hello"',
4039
);
4140
console.log(result.finalOutput);
4241
});

packages/agents-core/src/agent.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ export class Agent<
518518
runContext: RunContext<TContext>,
519519
): Promise<Tool<TContext>[]> {
520520
if (this.mcpServers.length > 0) {
521-
return getAllMcpTools(this.mcpServers, false, runContext, this);
521+
return getAllMcpTools(this.mcpServers, runContext, this, false);
522522
}
523523

524524
return [];

packages/agents-core/src/mcp.ts

Lines changed: 65 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,11 @@ export const DEFAULT_STREAMABLE_HTTP_MCP_CLIENT_LOGGER_NAME =
3030
*/
3131
export interface MCPServer {
3232
cacheToolsList: boolean;
33+
toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic;
3334
connect(): Promise<void>;
3435
readonly name: string;
3536
close(): Promise<void>;
36-
listTools(
37-
runContext?: RunContext<any>,
38-
agent?: Agent<any, any>,
39-
): Promise<MCPTool[]>;
37+
listTools(): Promise<MCPTool[]>;
4038
callTool(
4139
toolName: string,
4240
args: Record<string, unknown> | null,
@@ -46,7 +44,7 @@ export interface MCPServer {
4644
export abstract class BaseMCPServerStdio implements MCPServer {
4745
public cacheToolsList: boolean;
4846
protected _cachedTools: any[] | undefined = undefined;
49-
protected toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic;
47+
public toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic;
5048

5149
protected logger: Logger;
5250
constructor(options: MCPServerStdioOptions) {
@@ -59,10 +57,7 @@ export abstract class BaseMCPServerStdio implements MCPServer {
5957
abstract get name(): string;
6058
abstract connect(): Promise<void>;
6159
abstract close(): Promise<void>;
62-
abstract listTools(
63-
runContext?: RunContext<any>,
64-
agent?: Agent<any, any>,
65-
): Promise<any[]>;
60+
abstract listTools(): Promise<any[]>;
6661
abstract callTool(
6762
_toolName: string,
6863
_args: Record<string, unknown> | null,
@@ -83,7 +78,7 @@ export abstract class BaseMCPServerStdio implements MCPServer {
8378
export abstract class BaseMCPServerStreamableHttp implements MCPServer {
8479
public cacheToolsList: boolean;
8580
protected _cachedTools: any[] | undefined = undefined;
86-
protected toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic;
81+
public toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic;
8782

8883
protected logger: Logger;
8984
constructor(options: MCPServerStreamableHttpOptions) {
@@ -97,10 +92,7 @@ export abstract class BaseMCPServerStreamableHttp implements MCPServer {
9792
abstract get name(): string;
9893
abstract connect(): Promise<void>;
9994
abstract close(): Promise<void>;
100-
abstract listTools(
101-
runContext?: RunContext<any>,
102-
agent?: Agent<any, any>,
103-
): Promise<any[]>;
95+
abstract listTools(): Promise<any[]>;
10496
abstract callTool(
10597
_toolName: string,
10698
_args: Record<string, unknown> | null,
@@ -154,14 +146,11 @@ export class MCPServerStdio extends BaseMCPServerStdio {
154146
close(): Promise<void> {
155147
return this.underlying.close();
156148
}
157-
async listTools(
158-
runContext?: RunContext<any>,
159-
agent?: Agent<any, any>,
160-
): Promise<MCPTool[]> {
149+
async listTools(): Promise<MCPTool[]> {
161150
if (this.cacheToolsList && this._cachedTools) {
162151
return this._cachedTools;
163152
}
164-
const tools = await this.underlying.listTools(runContext, agent);
153+
const tools = await this.underlying.listTools();
165154
if (this.cacheToolsList) {
166155
this._cachedTools = tools;
167156
}
@@ -190,14 +179,11 @@ export class MCPServerStreamableHttp extends BaseMCPServerStreamableHttp {
190179
close(): Promise<void> {
191180
return this.underlying.close();
192181
}
193-
async listTools(
194-
runContext?: RunContext<any>,
195-
agent?: Agent<any, any>,
196-
): Promise<MCPTool[]> {
182+
async listTools(): Promise<MCPTool[]> {
197183
if (this.cacheToolsList && this._cachedTools) {
198184
return this._cachedTools;
199185
}
200-
const tools = await this.underlying.listTools(runContext, agent);
186+
const tools = await this.underlying.listTools();
201187
if (this.cacheToolsList) {
202188
this._cachedTools = tools;
203189
}
@@ -217,18 +203,18 @@ export class MCPServerStreamableHttp extends BaseMCPServerStreamableHttp {
217203
*/
218204
export async function getAllMcpFunctionTools<TContext = UnknownContext>(
219205
mcpServers: MCPServer[],
206+
runContext: RunContext<TContext>,
207+
agent: Agent<any, any>,
220208
convertSchemasToStrict = false,
221-
runContext?: RunContext<TContext>,
222-
agent?: Agent<TContext, any>,
223209
): Promise<Tool<TContext>[]> {
224210
const allTools: Tool<TContext>[] = [];
225211
const toolNames = new Set<string>();
226212
for (const server of mcpServers) {
227213
const serverTools = await getFunctionToolsFromServer(
228214
server,
229-
convertSchemasToStrict,
230215
runContext,
231216
agent,
217+
convertSchemasToStrict,
232218
);
233219
const serverToolNames = new Set(serverTools.map((t) => t.name));
234220
const intersection = [...serverToolNames].filter((n) => toolNames.has(n));
@@ -259,16 +245,62 @@ export function invalidateServerToolsCache(serverName: string) {
259245
*/
260246
async function getFunctionToolsFromServer<TContext = UnknownContext>(
261247
server: MCPServer,
248+
runContext: RunContext<TContext>,
249+
agent: Agent<any, any>,
262250
convertSchemasToStrict: boolean,
263-
runContext?: RunContext<TContext>,
264-
agent?: Agent<TContext, any>,
265251
): Promise<FunctionTool<TContext, any, unknown>[]> {
266252
if (server.cacheToolsList && _cachedTools[server.name]) {
267253
return _cachedTools[server.name];
268254
}
269255
return withMCPListToolsSpan(
270256
async (span) => {
271-
const mcpTools = await server.listTools(runContext, agent);
257+
const fetchedMcpTools = await server.listTools();
258+
const mcpTools: MCPTool[] = [];
259+
const context = {
260+
runContext,
261+
agent,
262+
serverName: server.name,
263+
};
264+
for (const tool of fetchedMcpTools) {
265+
const filter = server.toolFilter;
266+
if (filter) {
267+
if (filter && typeof filter === 'function') {
268+
const filtered = await filter(context, tool);
269+
if (!filtered) {
270+
globalLogger.debug(
271+
`MCP Tool (server: ${server.name}, tool: ${tool.name}) is blocked by the callable filter.`,
272+
);
273+
continue; // skip this tool
274+
}
275+
} else {
276+
const allowedToolNames = filter.allowedToolNames ?? [];
277+
const blockedToolNames = filter.blockedToolNames ?? [];
278+
if (allowedToolNames.length > 0 || blockedToolNames.length > 0) {
279+
const allowed =
280+
allowedToolNames.length > 0
281+
? allowedToolNames.includes(tool.name)
282+
: true;
283+
const blocked =
284+
blockedToolNames.length > 0
285+
? blockedToolNames.includes(tool.name)
286+
: false;
287+
if (!allowed || blocked) {
288+
if (blocked) {
289+
globalLogger.debug(
290+
`MCP Tool (server: ${server.name}, tool: ${tool.name}) is blocked by the static filter.`,
291+
);
292+
} else if (!allowed) {
293+
globalLogger.debug(
294+
`MCP Tool (server: ${server.name}, tool: ${tool.name}) is not allowed by the static filter.`,
295+
);
296+
}
297+
continue; // skip this tool
298+
}
299+
}
300+
}
301+
}
302+
mcpTools.push(tool);
303+
}
272304
span.spanData.result = mcpTools.map((t) => t.name);
273305
const tools: FunctionTool<TContext, any, string>[] = mcpTools.map((t) =>
274306
mcpToFunctionTool(t, server, convertSchemasToStrict),
@@ -287,15 +319,15 @@ async function getFunctionToolsFromServer<TContext = UnknownContext>(
287319
*/
288320
export async function getAllMcpTools<TContext = UnknownContext>(
289321
mcpServers: MCPServer[],
322+
runContext: RunContext<TContext>,
323+
agent: Agent<TContext, any>,
290324
convertSchemasToStrict = false,
291-
runContext?: RunContext<TContext>,
292-
agent?: Agent<TContext, any>,
293325
): Promise<Tool<TContext>[]> {
294326
return getAllMcpFunctionTools(
295327
mcpServers,
296-
convertSchemasToStrict,
297328
runContext,
298329
agent,
330+
convertSchemasToStrict,
299331
);
300332
}
301333

packages/agents-core/src/mcpUtil.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ export interface MCPToolFilterContext<TContext = UnknownContext> {
1717
export type MCPToolFilterCallable<TContext = UnknownContext> = (
1818
context: MCPToolFilterContext<TContext>,
1919
tool: MCPTool,
20-
) => boolean | Promise<boolean>;
20+
) => Promise<boolean>;
2121

2222
/** Static tool filter configuration using allow and block lists. */
2323
export interface MCPToolFilterStatic {

packages/agents-core/src/runState.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -557,8 +557,8 @@ export class RunState<TContext, TAgent extends Agent<any, any>> {
557557
? await deserializeProcessedResponse(
558558
agentMap,
559559
state._currentAgent,
560-
stateJson.lastProcessedResponse,
561560
state._context,
561+
stateJson.lastProcessedResponse,
562562
)
563563
: undefined;
564564

@@ -708,12 +708,12 @@ export function deserializeItem(
708708
async function deserializeProcessedResponse<TContext = UnknownContext>(
709709
agentMap: Map<string, Agent<any, any>>,
710710
currentAgent: Agent<TContext, any>,
711+
context: RunContext<TContext>,
711712
serializedProcessedResponse: z.infer<
712713
typeof serializedProcessedResponseSchema
713714
>,
714-
runContext: RunContext<TContext>,
715715
): Promise<ProcessedResponse<TContext>> {
716-
const allTools = await currentAgent.getAllTools(runContext);
716+
const allTools = await currentAgent.getAllTools(context);
717717
const tools = new Map(
718718
allTools
719719
.filter((tool) => tool.type === 'function')

packages/agents-core/src/shims/mcp-server/browser.ts

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ import {
66
MCPServerStreamableHttpOptions,
77
MCPTool,
88
} from '../../mcp';
9-
import type { RunContext } from '../../runContext';
10-
import type { Agent } from '../../agent';
119

1210
export class MCPServerStdio extends BaseMCPServerStdio {
1311
constructor(params: MCPServerStdioOptions) {
@@ -22,10 +20,7 @@ export class MCPServerStdio extends BaseMCPServerStdio {
2220
close(): Promise<void> {
2321
throw new Error('Method not implemented.');
2422
}
25-
listTools(
26-
_runContext?: RunContext<any>,
27-
_agent?: Agent<any, any>,
28-
): Promise<MCPTool[]> {
23+
listTools(): Promise<MCPTool[]> {
2924
throw new Error('Method not implemented.');
3025
}
3126
callTool(
@@ -49,10 +44,7 @@ export class MCPServerStreamableHttp extends BaseMCPServerStreamableHttp {
4944
close(): Promise<void> {
5045
throw new Error('Method not implemented.');
5146
}
52-
listTools(
53-
_runContext?: RunContext<any>,
54-
_agent?: Agent<any, any>,
55-
): Promise<MCPTool[]> {
47+
listTools(): Promise<MCPTool[]> {
5648
throw new Error('Method not implemented.');
5749
}
5850
callTool(

packages/agents-core/src/shims/mcp-server/node.ts

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ import {
1212
invalidateServerToolsCache,
1313
} from '../../mcp';
1414
import logger from '../../logger';
15-
import type { RunContext } from '../../runContext';
16-
import type { Agent } from '../../agent';
1715

1816
export interface SessionMessage {
1917
message: any;
@@ -98,10 +96,7 @@ export class NodeMCPServerStdio extends BaseMCPServerStdio {
9896
this._cacheDirty = true;
9997
}
10098

101-
async listTools(
102-
_runContext?: RunContext<any>,
103-
_agent?: Agent<any, any>,
104-
): Promise<MCPTool[]> {
99+
async listTools(): Promise<MCPTool[]> {
105100
const { ListToolsResultSchema } = await import(
106101
'@modelcontextprotocol/sdk/types.js'
107102
).catch(failedToImport);
@@ -218,10 +213,7 @@ export class NodeMCPServerStreamableHttp extends BaseMCPServerStreamableHttp {
218213
this._cacheDirty = true;
219214
}
220215

221-
async listTools(
222-
_runContext?: RunContext<any>,
223-
_agent?: Agent<any, any>,
224-
): Promise<MCPTool[]> {
216+
async listTools(): Promise<MCPTool[]> {
225217
const { ListToolsResultSchema } = await import(
226218
'@modelcontextprotocol/sdk/types.js'
227219
).catch(failedToImport);

packages/agents-core/test/mcpCache.test.ts

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ import { getAllMcpTools } from '../src/mcp';
33
import { withTrace } from '../src/tracing';
44
import { NodeMCPServerStdio } from '../src/shims/mcp-server/node';
55
import type { CallToolResultContent } from '../src/mcp';
6+
import { RunContext } from '../src/runContext';
7+
import { Agent } from '../src/agent';
68

79
class StubServer extends NodeMCPServerStdio {
810
public toolList: any[];
@@ -48,15 +50,27 @@ describe('MCP tools cache invalidation', () => {
4850
];
4951
const server = new StubServer('server', toolsA);
5052

51-
let tools = await getAllMcpTools([server]);
53+
let tools = await getAllMcpTools(
54+
[server],
55+
new RunContext({}),
56+
new Agent({ name: 'test' }),
57+
);
5258
expect(tools.map((t) => t.name)).toEqual(['a']);
5359

5460
server.toolList = toolsB;
55-
tools = await getAllMcpTools([server]);
61+
tools = await getAllMcpTools(
62+
[server],
63+
new RunContext({}),
64+
new Agent({ name: 'test' }),
65+
);
5666
expect(tools.map((t) => t.name)).toEqual(['a']);
5767

5868
server.invalidateToolsCache();
59-
tools = await getAllMcpTools([server]);
69+
tools = await getAllMcpTools(
70+
[server],
71+
new RunContext({}),
72+
new Agent({ name: 'test' }),
73+
);
6074
expect(tools.map((t) => t.name)).toEqual(['b']);
6175
});
6276
});

0 commit comments

Comments
 (0)