Skip to content

Commit d94ee86

Browse files
committed
refactor: restructure mcp tools fetching with options object pattern
1 parent d639151 commit d94ee86

File tree

3 files changed

+100
-87
lines changed

3 files changed

+100
-87
lines changed

packages/agents-core/src/agent.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,12 @@ export class Agent<
518518
runContext: RunContext<TContext>,
519519
): Promise<Tool<TContext>[]> {
520520
if (this.mcpServers.length > 0) {
521-
return getAllMcpTools(this.mcpServers, runContext, this, false);
521+
return getAllMcpTools({
522+
mcpServers: this.mcpServers,
523+
runContext,
524+
agent: this,
525+
convertSchemasToStrict: false,
526+
});
522527
}
523528

524529
return [];

packages/agents-core/src/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,12 @@ export { getLogger } from './logger';
7070
export {
7171
getAllMcpTools,
7272
invalidateServerToolsCache,
73+
mcpToFunctionTool,
7374
MCPServer,
7475
MCPServerStdio,
7576
MCPServerStreamableHttp,
7677
MCPServerSSE,
78+
GetAllMcpToolsOptions,
7779
} from './mcp';
7880
export {
7981
MCPToolFilterCallable,

packages/agents-core/src/mcp.ts

Lines changed: 92 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -285,35 +285,6 @@ export class MCPServerSSE extends BaseMCPServerSSE {
285285
* Fetches and flattens all tools from multiple MCP servers.
286286
* Logs and skips any servers that fail to respond.
287287
*/
288-
export async function getAllMcpFunctionTools<TContext = UnknownContext>(
289-
mcpServers: MCPServer[],
290-
runContext: RunContext<TContext>,
291-
agent: Agent<any, any>,
292-
convertSchemasToStrict = false,
293-
): Promise<Tool<TContext>[]> {
294-
const allTools: Tool<TContext>[] = [];
295-
const toolNames = new Set<string>();
296-
for (const server of mcpServers) {
297-
const serverTools = await getFunctionToolsFromServer(
298-
server,
299-
runContext,
300-
agent,
301-
convertSchemasToStrict,
302-
);
303-
const serverToolNames = new Set(serverTools.map((t) => t.name));
304-
const intersection = [...serverToolNames].filter((n) => toolNames.has(n));
305-
if (intersection.length > 0) {
306-
throw new UserError(
307-
`Duplicate tool names found across MCP servers: ${intersection.join(', ')}`,
308-
);
309-
}
310-
for (const t of serverTools) {
311-
toolNames.add(t.name);
312-
allTools.push(t);
313-
}
314-
}
315-
return allTools;
316-
}
317288

318289
const _cachedTools: Record<string, MCPTool[]> = {};
319290
/**
@@ -327,12 +298,17 @@ export async function invalidateServerToolsCache(serverName: string) {
327298
/**
328299
* Fetches all function tools from a single MCP server.
329300
*/
330-
async function getFunctionToolsFromServer<TContext = UnknownContext>(
331-
server: MCPServer,
332-
runContext: RunContext<TContext>,
333-
agent: Agent<any, any>,
334-
convertSchemasToStrict: boolean,
335-
): Promise<FunctionTool<TContext, any, unknown>[]> {
301+
async function getFunctionToolsFromServer<TContext = UnknownContext>({
302+
server,
303+
convertSchemasToStrict,
304+
runContext,
305+
agent,
306+
}: {
307+
server: MCPServer;
308+
convertSchemasToStrict: boolean;
309+
runContext?: RunContext<TContext>;
310+
agent?: Agent<any, any>;
311+
}): Promise<FunctionTool<TContext, any, unknown>[]> {
336312
if (server.cacheToolsList && _cachedTools[server.name]) {
337313
return _cachedTools[server.name].map((t) =>
338314
mcpToFunctionTool(t, server, convertSchemasToStrict),
@@ -341,52 +317,54 @@ async function getFunctionToolsFromServer<TContext = UnknownContext>(
341317
return withMCPListToolsSpan(
342318
async (span) => {
343319
const fetchedMcpTools = await server.listTools();
344-
const mcpTools: MCPTool[] = [];
345-
const context = {
346-
runContext,
347-
agent,
348-
serverName: server.name,
349-
};
350-
for (const tool of fetchedMcpTools) {
351-
const filter = server.toolFilter;
352-
if (filter) {
353-
if (filter && typeof filter === 'function') {
354-
const filtered = await filter(context, tool);
355-
if (!filtered) {
356-
globalLogger.debug(
357-
`MCP Tool (server: ${server.name}, tool: ${tool.name}) is blocked by the callable filter.`,
358-
);
359-
continue; // skip this tool
360-
}
361-
} else {
362-
const allowedToolNames = filter.allowedToolNames ?? [];
363-
const blockedToolNames = filter.blockedToolNames ?? [];
364-
if (allowedToolNames.length > 0 || blockedToolNames.length > 0) {
365-
const allowed =
366-
allowedToolNames.length > 0
367-
? allowedToolNames.includes(tool.name)
368-
: true;
369-
const blocked =
370-
blockedToolNames.length > 0
371-
? blockedToolNames.includes(tool.name)
372-
: false;
373-
if (!allowed || blocked) {
374-
if (blocked) {
375-
globalLogger.debug(
376-
`MCP Tool (server: ${server.name}, tool: ${tool.name}) is blocked by the static filter.`,
377-
);
378-
} else if (!allowed) {
379-
globalLogger.debug(
380-
`MCP Tool (server: ${server.name}, tool: ${tool.name}) is not allowed by the static filter.`,
381-
);
320+
let mcpTools: MCPTool[] = fetchedMcpTools;
321+
322+
if (runContext && agent) {
323+
const context = { runContext, agent, serverName: server.name };
324+
const filteredTools: MCPTool[] = [];
325+
for (const tool of fetchedMcpTools) {
326+
const filter = server.toolFilter;
327+
if (filter) {
328+
if (typeof filter === 'function') {
329+
const filtered = await filter(context, tool);
330+
if (!filtered) {
331+
globalLogger.debug(
332+
`MCP Tool (server: ${server.name}, tool: ${tool.name}) is blocked by the callable filter.`,
333+
);
334+
continue;
335+
}
336+
} else {
337+
const allowedToolNames = filter.allowedToolNames ?? [];
338+
const blockedToolNames = filter.blockedToolNames ?? [];
339+
if (allowedToolNames.length > 0 || blockedToolNames.length > 0) {
340+
const allowed =
341+
allowedToolNames.length > 0
342+
? allowedToolNames.includes(tool.name)
343+
: true;
344+
const blocked =
345+
blockedToolNames.length > 0
346+
? blockedToolNames.includes(tool.name)
347+
: false;
348+
if (!allowed || blocked) {
349+
if (blocked) {
350+
globalLogger.debug(
351+
`MCP Tool (server: ${server.name}, tool: ${tool.name}) is blocked by the static filter.`,
352+
);
353+
} else if (!allowed) {
354+
globalLogger.debug(
355+
`MCP Tool (server: ${server.name}, tool: ${tool.name}) is not allowed by the static filter.`,
356+
);
357+
}
358+
continue;
382359
}
383-
continue; // skip this tool
384360
}
385361
}
386362
}
363+
filteredTools.push(tool);
387364
}
388-
mcpTools.push(tool);
365+
mcpTools = filteredTools;
389366
}
367+
390368
span.spanData.result = mcpTools.map((t) => t.name);
391369
const tools: FunctionTool<TContext, any, string>[] = mcpTools.map((t) =>
392370
mcpToFunctionTool(t, server, convertSchemasToStrict),
@@ -400,21 +378,49 @@ async function getFunctionToolsFromServer<TContext = UnknownContext>(
400378
);
401379
}
402380

381+
/**
382+
* Options for fetching MCP tools.
383+
*/
384+
export type GetAllMcpToolsOptions<TContext> = {
385+
mcpServers: MCPServer[];
386+
convertSchemasToStrict?: boolean;
387+
runContext?: RunContext<TContext>;
388+
agent?: Agent<TContext, any>;
389+
};
390+
403391
/**
404392
* Returns all MCP tools from the provided servers, using the function tool conversion.
393+
* If runContext and agent are provided, callable tool filters will be applied.
405394
*/
406-
export async function getAllMcpTools<TContext = UnknownContext>(
407-
mcpServers: MCPServer[],
408-
runContext: RunContext<TContext>,
409-
agent: Agent<TContext, any>,
395+
export async function getAllMcpTools<TContext = UnknownContext>({
396+
mcpServers,
410397
convertSchemasToStrict = false,
411-
): Promise<Tool<TContext>[]> {
412-
return getAllMcpFunctionTools(
413-
mcpServers,
414-
runContext,
415-
agent,
416-
convertSchemasToStrict,
417-
);
398+
runContext,
399+
agent,
400+
}: GetAllMcpToolsOptions<TContext>): Promise<Tool<TContext>[]> {
401+
const allTools: Tool<TContext>[] = [];
402+
const toolNames = new Set<string>();
403+
404+
for (const server of mcpServers) {
405+
const serverTools = await getFunctionToolsFromServer({
406+
server,
407+
convertSchemasToStrict,
408+
runContext,
409+
agent,
410+
});
411+
const serverToolNames = new Set(serverTools.map((t) => t.name));
412+
const intersection = [...serverToolNames].filter((n) => toolNames.has(n));
413+
if (intersection.length > 0) {
414+
throw new UserError(
415+
`Duplicate tool names found across MCP servers: ${intersection.join(', ')}`,
416+
);
417+
}
418+
for (const t of serverTools) {
419+
toolNames.add(t.name);
420+
allTools.push(t);
421+
}
422+
}
423+
return allTools;
418424
}
419425

420426
/**

0 commit comments

Comments
 (0)