Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 193 additions & 43 deletions src/mcp_agent/agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import json
import uuid
from typing import Callable, Dict, List, Optional, TypeVar, TYPE_CHECKING, Any
from typing import Callable, Dict, List, Optional, Set, TypeVar, TYPE_CHECKING, Any
from contextlib import asynccontextmanager

from opentelemetry import trace
Expand Down Expand Up @@ -459,7 +459,48 @@ async def get_server_session(self, server_name: str):

return result

async def list_tools(self, server_name: str | None = None) -> ListToolsResult:
def _should_include_non_namespaced_tool(
self, tool_name: str, tool_filter: Dict[str, Set[str]] | None
) -> tuple[bool, str | None]:
"""
Determine if a non-namespaced tool (function tool or human input) should be included.

Uses the special reserved key "non_namespaced_tools" to filter function tools and human input.

Returns: (should_include, filter_reason)
- filter_reason is None if tool should be included, otherwise explains why filtered
"""
if tool_filter is None:
return True, None

# Priority 1: Check non_namespaced_tools key (explicitly for non-namespaced tools)
if "non_namespaced_tools" in tool_filter:
if tool_name in tool_filter["non_namespaced_tools"]:
return True, None
else:
return False, f"{tool_name} not in tool_filter[non_namespaced_tools]"

# Priority 2: Check wildcard filter
elif "*" in tool_filter:
if tool_name in tool_filter["*"]:
return True, None
else:
return False, f"{tool_name} not in tool_filter[*]"

# No non_namespaced_tools key and no wildcard - include by default (no filter for non-namespaced)
return True, None

async def list_tools(self, server_name: str | None = None, tool_filter: Dict[str, Set[str]] | None = None) -> ListToolsResult:
"""
List available tools with optional filtering.

Args:
server_name: Optional specific server to list tools from
tool_filter: Optional dict mapping server names to sets of allowed tool names.
Special reserved keys:
- "*": Wildcard filter for servers without explicit filters
- "non_namespaced_tools": Filter for non-namespaced tools (function tools, human input)
"""
if not self.initialized:
await self.initialize()

Expand All @@ -473,38 +514,112 @@ async def list_tools(self, server_name: str | None = None) -> ListToolsResult:
"human_input_callback", self.human_input_callback is not None
)

# Track filtered tools for debugging and telemetry
filtered_out_tools = [] # List of (tool_name, reason) tuples

if server_name:
span.set_attribute("server_name", server_name)
result = ListToolsResult(
tools=[
namespaced_tool.tool.model_copy(
update={"name": namespaced_tool.namespaced_tool_name}
)
for namespaced_tool in self._server_to_tool_map.get(
server_name, []
)
]
)
# Get tools for specific server
server_tools = self._server_to_tool_map.get(server_name, [])

# Check if we should apply filtering for this specific server
if tool_filter is not None and server_name in tool_filter:
# Server is explicitly in filter dict - apply its filter rules
# If tool_filter[server_name] is empty set, no tools will pass
# If tool_filter[server_name] has tools, only those will pass
allowed_tools = tool_filter[server_name]
result_tools = []
for namespaced_tool in server_tools:
if namespaced_tool.tool.name in allowed_tools:
result_tools.append(
namespaced_tool.tool.model_copy(
update={"name": namespaced_tool.namespaced_tool_name}
)
)
else:
filtered_out_tools.append(
(namespaced_tool.namespaced_tool_name, f"Not in tool_filter[{server_name}]")
)
result = ListToolsResult(tools=result_tools)
else:
# Either no filter at all (tool_filter is None) or
# this server is not in the filter dict (no filtering for this server)
# Include all tools from this server
result = ListToolsResult(
tools=[
namespaced_tool.tool.model_copy(
update={"name": namespaced_tool.namespaced_tool_name}
)
for namespaced_tool in server_tools
]
)
else:
result = ListToolsResult(
tools=[
namespaced_tool.tool.model_copy(
update={"name": namespaced_tool_name}
)
for namespaced_tool_name, namespaced_tool in self._namespaced_tool_map.items()
]
)
# No specific server requested - get tools from all servers
if tool_filter is not None:
# Filter is active - check each tool's server against filter rules
filtered_tools = []
for namespaced_tool_name, namespaced_tool in self._namespaced_tool_map.items():
should_include = False

# Priority 1: Check if tool's server has explicit filter rules
if namespaced_tool.server_name in tool_filter:
# Server has explicit filter - tool must be in the allowed set
if namespaced_tool.tool.name in tool_filter[namespaced_tool.server_name]:
should_include = True
else:
filtered_out_tools.append(
(namespaced_tool_name, f"Not in tool_filter[{namespaced_tool.server_name}]")
)
# Priority 2: If no server-specific filter, check wildcard
elif "*" in tool_filter:
# Wildcard filter applies to servers without explicit filters
if namespaced_tool.tool.name in tool_filter["*"]:
should_include = True
else:
filtered_out_tools.append(
(namespaced_tool_name, "Not in tool_filter[*]")
)
else:
# No explicit filter for this server and no wildcard
# Default behavior: include the tool (no filtering)
should_include = True

if should_include:
filtered_tools.append(
namespaced_tool.tool.model_copy(
update={"name": namespaced_tool_name}
)
)
result = ListToolsResult(tools=filtered_tools)
else:
# No filter at all - include everything
result = ListToolsResult(
tools=[
namespaced_tool.tool.model_copy(
update={"name": namespaced_tool_name}
)
for namespaced_tool_name, namespaced_tool in self._namespaced_tool_map.items()
]
)

# Add function tools
# Add function tools (non-namespaced) with filtering
# These use the special "non_namespaced_tools" key in tool_filter
for tool in self._function_tool_map.values():
result.tools.append(
Tool(
name=tool.name,
description=tool.description,
inputSchema=tool.parameters,
)
should_include, filter_reason = self._should_include_non_namespaced_tool(
tool.name, tool_filter
)

if should_include:
result.tools.append(
Tool(
name=tool.name,
description=tool.description,
inputSchema=tool.parameters,
)
)
elif filter_reason:
filtered_out_tools.append((tool.name, filter_reason))

def _annotate_span_for_tools_result(result: ListToolsResult):
if not self.context.tracing_enabled:
return
Expand All @@ -529,24 +644,59 @@ def _annotate_span_for_tools_result(result: ListToolsResult):
f"tool.{tool.name}.annotations.{attr}", value
)

# Add a human_input_callback as a tool
if not self.human_input_callback:
logger.debug("Human input callback not set")
_annotate_span_for_tools_result(result)
# Add human_input_callback tool (non-namespaced) with filtering
# This uses the special "non_namespaced_tools" key in tool_filter
if self.human_input_callback:
should_include, filter_reason = self._should_include_non_namespaced_tool(
HUMAN_INPUT_TOOL_NAME, tool_filter
)

return result
if should_include:
human_input_tool: FastTool = FastTool.from_function(
self.request_human_input
)
result.tools.append(
Tool(
name=HUMAN_INPUT_TOOL_NAME,
description=human_input_tool.description,
inputSchema=human_input_tool.parameters,
)
)
elif filter_reason:
filtered_out_tools.append((HUMAN_INPUT_TOOL_NAME, filter_reason))
else:
logger.debug("Human input callback not set")

# Add a human_input_callback as a tool
human_input_tool: FastTool = FastTool.from_function(
self.request_human_input
)
result.tools.append(
Tool(
name=HUMAN_INPUT_TOOL_NAME,
description=human_input_tool.description,
inputSchema=human_input_tool.parameters,
)
)
# Log and track filtering metrics if filter was applied
if tool_filter is not None:
span.set_attribute("tool_filter_applied", True)
span.set_attribute("tools_included_count", len(result.tools))
span.set_attribute("tools_filtered_out_count", len(filtered_out_tools))

# Add telemetry for filtered tools (limit to first 20 to avoid span bloat)
if self.context.tracing_enabled:
for i, (tool_name, reason) in enumerate(filtered_out_tools[:20]):
span.set_attribute(f"filtered_tool.{i}.name", tool_name)
span.set_attribute(f"filtered_tool.{i}.reason", reason)
if len(filtered_out_tools) > 20:
span.set_attribute("filtered_tools_truncated", True)

# Log filtered tools for debugging
if filtered_out_tools:
logger.debug(
f"Tool filter applied: {len(filtered_out_tools)} tools filtered out, "
f"{len(result.tools)} tools remaining. "
f"Filtered tools: {[name for name, _ in filtered_out_tools[:10]]}" +
("..." if len(filtered_out_tools) > 10 else "")
)
# Log detailed reasons at trace level (if trace logging is available)
if logger.isEnabledFor(10): # TRACE level is usually 10
for tool_name, reason in filtered_out_tools:
logger.log(10, f"Filtered out '{tool_name}': {reason}")
else:
logger.debug(
f"Tool filter applied: All {len(result.tools)} tools passed the filter"
)

_annotate_span_for_tools_result(result)

Expand Down
26 changes: 24 additions & 2 deletions src/mcp_agent/workflows/llm/augmented_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from typing import (
Any,
Dict,
Generic,
List,
Optional,
Protocol,
Set,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -173,6 +175,26 @@ class RequestParams(CreateMessageRequestParams):
Whether models that support strict mode should strictly enforce the response schema.
"""

tool_filter: Dict[str, Set[str]] | None = None
"""
Mapping of server names to sets of allowed tool names for this request.
If specified, only these tools will be exposed to the LLM for each server.
This overrides the server-level allowed_tools configuration.

Special reserved keys:
- "*": Wildcard filter for servers without explicit filters
- "non_namespaced_tools": Filter for non-namespaced tools (function tools, human input)

Examples:
- {"server1": {"tool1", "tool2"}} - Allow specific tools from server1
- {"*": {"tool1"}} - Allow tool1 from all servers without explicit filters
- {"non_namespaced_tools": {"human_input", "func1"}} - Allow specific non-namespaced tools
- {} - No tools allowed from any server
- None - No filtering applied (default behavior)

Tool names should match exactly as they appear in the server's tool list.
"""


class AugmentedLLMProtocol(Protocol, Generic[MessageParamT, MessageT]):
"""Protocol defining the interface for augmented LLMs"""
Expand Down Expand Up @@ -538,9 +560,9 @@ async def call_tool(
],
)

async def list_tools(self, server_name: str | None = None) -> ListToolsResult:
async def list_tools(self, server_name: str | None = None, tool_filter: Dict[str, Set[str]] | None = None) -> ListToolsResult:
"""Call the underlying agent's list_tools method for a given server."""
return await self.agent.list_tools(server_name=server_name)
return await self.agent.list_tools(server_name=server_name, tool_filter=tool_filter)

async def list_resources(
self, server_name: str | None = None
Expand Down
2 changes: 1 addition & 1 deletion src/mcp_agent/workflows/llm/augmented_llm_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ async def generate(
AnthropicConverter.convert_mixed_messages_to_anthropic(message)
)

list_tools_result = await self.agent.list_tools()
list_tools_result = await self.agent.list_tools(tool_filter=params.tool_filter)
available_tools: List[ToolParam] = [
{
"name": tool.name,
Expand Down
2 changes: 1 addition & 1 deletion src/mcp_agent/workflows/llm/augmented_llm_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ async def generate(self, message, request_params: RequestParams | None = None):

messages.extend(AzureConverter.convert_mixed_messages_to_azure(message))

response = await self.agent.list_tools()
response = await self.agent.list_tools(tool_filter=params.tool_filter)

tools: list[ChatCompletionsToolDefinition] = [
ChatCompletionsToolDefinition(
Expand Down
2 changes: 1 addition & 1 deletion src/mcp_agent/workflows/llm/augmented_llm_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ async def generate(self, message, request_params: RequestParams | None = None):

messages.extend(BedrockConverter.convert_mixed_messages_to_bedrock(message))

response = await self.agent.list_tools()
response = await self.agent.list_tools(tool_filter=params.tool_filter)

tool_config: ToolConfigurationTypeDef = {
"tools": [
Expand Down
2 changes: 1 addition & 1 deletion src/mcp_agent/workflows/llm/augmented_llm_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def generate(self, message, request_params: RequestParams | None = None):

messages.extend(GoogleConverter.convert_mixed_messages_to_google(message))

response = await self.agent.list_tools()
response = await self.agent.list_tools(tool_filter=params.tool_filter)

tools = [
types.Tool(
Expand Down
2 changes: 1 addition & 1 deletion src/mcp_agent/workflows/llm/augmented_llm_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ async def generate(
)
messages.extend((OpenAIConverter.convert_mixed_messages_to_openai(message)))

response: ListToolsResult = await self.agent.list_tools()
response: ListToolsResult = await self.agent.list_tools(tool_filter=params.tool_filter)
available_tools: List[ChatCompletionToolParam] = [
ChatCompletionToolParam(
type="function",
Expand Down
Loading
Loading