Skip to content

Commit 04297fa

Browse files
seratchvrtnis
authored andcommitted
wip
1 parent eeaa067 commit 04297fa

File tree

9 files changed

+99
-108
lines changed

9 files changed

+99
-108
lines changed

examples/mcp/tool-filter-example.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ async function main() {
2424
await withTrace('MCP Tool Filter Example', async () => {
2525
const agent = new Agent({
2626
name: 'MCP Assistant',
27-
instructions:
28-
'Use the filesystem tools to answer questions. The write_file tool is blocked via toolFilter.',
27+
instructions: 'Use the filesystem tools to answer questions.',
2928
mcpServers: [mcpServer],
3029
});
3130

@@ -36,7 +35,7 @@ async function main() {
3635
console.log('\nAttempting to write a file (should be blocked):');
3736
result = await run(
3837
agent,
39-
'Create a file named test.txt with the text "hello"',
38+
'Create a file named sample_files/test.txt with the text "hello"',
4039
);
4140
console.log(result.finalOutput);
4241
});

packages/agents-core/src/agent.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ export class Agent<
518518
runContext: RunContext<TContext>,
519519
): Promise<Tool<TContext>[]> {
520520
if (this.mcpServers.length > 0) {
521-
return getAllMcpTools(this.mcpServers, false, runContext, this);
521+
return getAllMcpTools(this.mcpServers, runContext, this, false);
522522
}
523523

524524
return [];

packages/agents-core/src/mcp.ts

Lines changed: 65 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,11 @@ export const DEFAULT_STREAMABLE_HTTP_MCP_CLIENT_LOGGER_NAME =
3030
*/
3131
export interface MCPServer {
3232
cacheToolsList: boolean;
33+
toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic;
3334
connect(): Promise<void>;
3435
readonly name: string;
3536
close(): Promise<void>;
36-
listTools(
37-
runContext?: RunContext<any>,
38-
agent?: Agent<any, any>,
39-
): Promise<MCPTool[]>;
37+
listTools(): Promise<MCPTool[]>;
4038
callTool(
4139
toolName: string,
4240
args: Record<string, unknown> | null,
@@ -47,7 +45,7 @@ export interface MCPServer {
4745
export abstract class BaseMCPServerStdio implements MCPServer {
4846
public cacheToolsList: boolean;
4947
protected _cachedTools: any[] | undefined = undefined;
50-
protected toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic;
48+
public toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic;
5149

5250
protected logger: Logger;
5351
constructor(options: MCPServerStdioOptions) {
@@ -60,10 +58,7 @@ export abstract class BaseMCPServerStdio implements MCPServer {
6058
abstract get name(): string;
6159
abstract connect(): Promise<void>;
6260
abstract close(): Promise<void>;
63-
abstract listTools(
64-
runContext?: RunContext<any>,
65-
agent?: Agent<any, any>,
66-
): Promise<any[]>;
61+
abstract listTools(): Promise<any[]>;
6762
abstract callTool(
6863
_toolName: string,
6964
_args: Record<string, unknown> | null,
@@ -85,7 +80,7 @@ export abstract class BaseMCPServerStdio implements MCPServer {
8580
export abstract class BaseMCPServerStreamableHttp implements MCPServer {
8681
public cacheToolsList: boolean;
8782
protected _cachedTools: any[] | undefined = undefined;
88-
protected toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic;
83+
public toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic;
8984

9085
protected logger: Logger;
9186
constructor(options: MCPServerStreamableHttpOptions) {
@@ -99,10 +94,7 @@ export abstract class BaseMCPServerStreamableHttp implements MCPServer {
9994
abstract get name(): string;
10095
abstract connect(): Promise<void>;
10196
abstract close(): Promise<void>;
102-
abstract listTools(
103-
runContext?: RunContext<any>,
104-
agent?: Agent<any, any>,
105-
): Promise<any[]>;
97+
abstract listTools(): Promise<any[]>;
10698
abstract callTool(
10799
_toolName: string,
108100
_args: Record<string, unknown> | null,
@@ -157,14 +149,11 @@ export class MCPServerStdio extends BaseMCPServerStdio {
157149
close(): Promise<void> {
158150
return this.underlying.close();
159151
}
160-
async listTools(
161-
runContext?: RunContext<any>,
162-
agent?: Agent<any, any>,
163-
): Promise<MCPTool[]> {
152+
async listTools(): Promise<MCPTool[]> {
164153
if (this.cacheToolsList && this._cachedTools) {
165154
return this._cachedTools;
166155
}
167-
const tools = await this.underlying.listTools(runContext, agent);
156+
const tools = await this.underlying.listTools();
168157
if (this.cacheToolsList) {
169158
this._cachedTools = tools;
170159
}
@@ -196,14 +185,11 @@ export class MCPServerStreamableHttp extends BaseMCPServerStreamableHttp {
196185
close(): Promise<void> {
197186
return this.underlying.close();
198187
}
199-
async listTools(
200-
runContext?: RunContext<any>,
201-
agent?: Agent<any, any>,
202-
): Promise<MCPTool[]> {
188+
async listTools(): Promise<MCPTool[]> {
203189
if (this.cacheToolsList && this._cachedTools) {
204190
return this._cachedTools;
205191
}
206-
const tools = await this.underlying.listTools(runContext, agent);
192+
const tools = await this.underlying.listTools();
207193
if (this.cacheToolsList) {
208194
this._cachedTools = tools;
209195
}
@@ -226,18 +212,18 @@ export class MCPServerStreamableHttp extends BaseMCPServerStreamableHttp {
226212
*/
227213
export async function getAllMcpFunctionTools<TContext = UnknownContext>(
228214
mcpServers: MCPServer[],
215+
runContext: RunContext<TContext>,
216+
agent: Agent<any, any>,
229217
convertSchemasToStrict = false,
230-
runContext?: RunContext<TContext>,
231-
agent?: Agent<TContext, any>,
232218
): Promise<Tool<TContext>[]> {
233219
const allTools: Tool<TContext>[] = [];
234220
const toolNames = new Set<string>();
235221
for (const server of mcpServers) {
236222
const serverTools = await getFunctionToolsFromServer(
237223
server,
238-
convertSchemasToStrict,
239224
runContext,
240225
agent,
226+
convertSchemasToStrict,
241227
);
242228
const serverToolNames = new Set(serverTools.map((t) => t.name));
243229
const intersection = [...serverToolNames].filter((n) => toolNames.has(n));
@@ -268,9 +254,9 @@ export async function invalidateServerToolsCache(serverName: string) {
268254
*/
269255
async function getFunctionToolsFromServer<TContext = UnknownContext>(
270256
server: MCPServer,
257+
runContext: RunContext<TContext>,
258+
agent: Agent<any, any>,
271259
convertSchemasToStrict: boolean,
272-
runContext?: RunContext<TContext>,
273-
agent?: Agent<TContext, any>,
274260
): Promise<FunctionTool<TContext, any, unknown>[]> {
275261
if (server.cacheToolsList && _cachedTools[server.name]) {
276262
return _cachedTools[server.name].map((t) =>
@@ -279,7 +265,53 @@ async function getFunctionToolsFromServer<TContext = UnknownContext>(
279265
}
280266
return withMCPListToolsSpan(
281267
async (span) => {
282-
const mcpTools = await server.listTools(runContext, agent);
268+
const fetchedMcpTools = await server.listTools();
269+
const mcpTools: MCPTool[] = [];
270+
const context = {
271+
runContext,
272+
agent,
273+
serverName: server.name,
274+
};
275+
for (const tool of fetchedMcpTools) {
276+
const filter = server.toolFilter;
277+
if (filter) {
278+
if (filter && typeof filter === 'function') {
279+
const filtered = await filter(context, tool);
280+
if (!filtered) {
281+
globalLogger.debug(
282+
`MCP Tool (server: ${server.name}, tool: ${tool.name}) is blocked by the callable filter.`,
283+
);
284+
continue; // skip this tool
285+
}
286+
} else {
287+
const allowedToolNames = filter.allowedToolNames ?? [];
288+
const blockedToolNames = filter.blockedToolNames ?? [];
289+
if (allowedToolNames.length > 0 || blockedToolNames.length > 0) {
290+
const allowed =
291+
allowedToolNames.length > 0
292+
? allowedToolNames.includes(tool.name)
293+
: true;
294+
const blocked =
295+
blockedToolNames.length > 0
296+
? blockedToolNames.includes(tool.name)
297+
: false;
298+
if (!allowed || blocked) {
299+
if (blocked) {
300+
globalLogger.debug(
301+
`MCP Tool (server: ${server.name}, tool: ${tool.name}) is blocked by the static filter.`,
302+
);
303+
} else if (!allowed) {
304+
globalLogger.debug(
305+
`MCP Tool (server: ${server.name}, tool: ${tool.name}) is not allowed by the static filter.`,
306+
);
307+
}
308+
continue; // skip this tool
309+
}
310+
}
311+
}
312+
}
313+
mcpTools.push(tool);
314+
}
283315
span.spanData.result = mcpTools.map((t) => t.name);
284316
const tools: FunctionTool<TContext, any, string>[] = mcpTools.map((t) =>
285317
mcpToFunctionTool(t, server, convertSchemasToStrict),
@@ -298,15 +330,15 @@ async function getFunctionToolsFromServer<TContext = UnknownContext>(
298330
*/
299331
export async function getAllMcpTools<TContext = UnknownContext>(
300332
mcpServers: MCPServer[],
333+
runContext: RunContext<TContext>,
334+
agent: Agent<TContext, any>,
301335
convertSchemasToStrict = false,
302-
runContext?: RunContext<TContext>,
303-
agent?: Agent<TContext, any>,
304336
): Promise<Tool<TContext>[]> {
305337
return getAllMcpFunctionTools(
306338
mcpServers,
307-
convertSchemasToStrict,
308339
runContext,
309340
agent,
341+
convertSchemasToStrict,
310342
);
311343
}
312344

packages/agents-core/src/mcpUtil.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ export interface MCPToolFilterContext<TContext = UnknownContext> {
1717
export type MCPToolFilterCallable<TContext = UnknownContext> = (
1818
context: MCPToolFilterContext<TContext>,
1919
tool: MCPTool,
20-
) => boolean | Promise<boolean>;
20+
) => Promise<boolean>;
2121

2222
/** Static tool filter configuration using allow and block lists. */
2323
export interface MCPToolFilterStatic {

packages/agents-core/src/runState.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -557,8 +557,8 @@ export class RunState<TContext, TAgent extends Agent<any, any>> {
557557
? await deserializeProcessedResponse(
558558
agentMap,
559559
state._currentAgent,
560-
stateJson.lastProcessedResponse,
561560
state._context,
561+
stateJson.lastProcessedResponse,
562562
)
563563
: undefined;
564564

@@ -708,12 +708,12 @@ export function deserializeItem(
708708
async function deserializeProcessedResponse<TContext = UnknownContext>(
709709
agentMap: Map<string, Agent<any, any>>,
710710
currentAgent: Agent<TContext, any>,
711+
context: RunContext<TContext>,
711712
serializedProcessedResponse: z.infer<
712713
typeof serializedProcessedResponseSchema
713714
>,
714-
runContext: RunContext<TContext>,
715715
): Promise<ProcessedResponse<TContext>> {
716-
const allTools = await currentAgent.getAllTools(runContext);
716+
const allTools = await currentAgent.getAllTools(context);
717717
const tools = new Map(
718718
allTools
719719
.filter((tool) => tool.type === 'function')

packages/agents-core/src/shims/mcp-server/browser.ts

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ import {
66
MCPServerStreamableHttpOptions,
77
MCPTool,
88
} from '../../mcp';
9-
import type { RunContext } from '../../runContext';
10-
import type { Agent } from '../../agent';
119

1210
export class MCPServerStdio extends BaseMCPServerStdio {
1311
constructor(params: MCPServerStdioOptions) {
@@ -22,10 +20,7 @@ export class MCPServerStdio extends BaseMCPServerStdio {
2220
close(): Promise<void> {
2321
throw new Error('Method not implemented.');
2422
}
25-
listTools(
26-
_runContext?: RunContext<any>,
27-
_agent?: Agent<any, any>,
28-
): Promise<MCPTool[]> {
23+
listTools(): Promise<MCPTool[]> {
2924
throw new Error('Method not implemented.');
3025
}
3126
callTool(
@@ -52,10 +47,7 @@ export class MCPServerStreamableHttp extends BaseMCPServerStreamableHttp {
5247
close(): Promise<void> {
5348
throw new Error('Method not implemented.');
5449
}
55-
listTools(
56-
_runContext?: RunContext<any>,
57-
_agent?: Agent<any, any>,
58-
): Promise<MCPTool[]> {
50+
listTools(): Promise<MCPTool[]> {
5951
throw new Error('Method not implemented.');
6052
}
6153
callTool(

packages/agents-core/src/shims/mcp-server/node.ts

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ import {
1212
invalidateServerToolsCache,
1313
} from '../../mcp';
1414
import logger from '../../logger';
15-
import type { RunContext } from '../../runContext';
16-
import type { Agent } from '../../agent';
1715

1816
export interface SessionMessage {
1917
message: any;
@@ -98,10 +96,7 @@ export class NodeMCPServerStdio extends BaseMCPServerStdio {
9896
this._cacheDirty = true;
9997
}
10098

101-
async listTools(
102-
_runContext?: RunContext<any>,
103-
_agent?: Agent<any, any>,
104-
): Promise<MCPTool[]> {
99+
async listTools(): Promise<MCPTool[]> {
105100
const { ListToolsResultSchema } = await import(
106101
'@modelcontextprotocol/sdk/types.js'
107102
).catch(failedToImport);
@@ -218,10 +213,7 @@ export class NodeMCPServerStreamableHttp extends BaseMCPServerStreamableHttp {
218213
this._cacheDirty = true;
219214
}
220215

221-
async listTools(
222-
_runContext?: RunContext<any>,
223-
_agent?: Agent<any, any>,
224-
): Promise<MCPTool[]> {
216+
async listTools(): Promise<MCPTool[]> {
225217
const { ListToolsResultSchema } = await import(
226218
'@modelcontextprotocol/sdk/types.js'
227219
).catch(failedToImport);

packages/agents-core/test/mcpCache.test.ts

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import type { FunctionTool } from '../src/tool';
44
import { withTrace } from '../src/tracing';
55
import { NodeMCPServerStdio } from '../src/shims/mcp-server/node';
66
import type { CallToolResultContent } from '../src/mcp';
7+
import { RunContext } from '../src/runContext';
8+
import { Agent } from '../src/agent';
79

810
class StubServer extends NodeMCPServerStdio {
911
public toolList: any[];
@@ -49,15 +51,27 @@ describe('MCP tools cache invalidation', () => {
4951
];
5052
const server = new StubServer('server', toolsA);
5153

52-
let tools = await getAllMcpTools([server]);
54+
let tools = await getAllMcpTools(
55+
[server],
56+
new RunContext({}),
57+
new Agent({ name: 'test' }),
58+
);
5359
expect(tools.map((t) => t.name)).toEqual(['a']);
5460

5561
server.toolList = toolsB;
56-
tools = await getAllMcpTools([server]);
62+
tools = await getAllMcpTools(
63+
[server],
64+
new RunContext({}),
65+
new Agent({ name: 'test' }),
66+
);
5767
expect(tools.map((t) => t.name)).toEqual(['a']);
5868

5969
await server.invalidateToolsCache();
60-
tools = await getAllMcpTools([server]);
70+
tools = await getAllMcpTools(
71+
[server],
72+
new RunContext({}),
73+
new Agent({ name: 'test' }),
74+
);
6175
expect(tools.map((t) => t.name)).toEqual(['b']);
6276
});
6377
});

0 commit comments

Comments
 (0)