Skip to content

Commit d04a41b

Browse files
ilayaperumalgSenreySong
authored andcommitted
feat(mcp): Add McpToolFilter interface for selective MCP tool inclusion/exclusion (spring-projects#3901)
- Introduce McpToolFilter interface extending BiPredicate<McpMetadata, McpSchema.Tool> - Add McpMetadata record containing client capabilities, info, and initialization result - Refactor AsyncMcpToolCallbackProvider and SyncMcpToolCallbackProvider to use McpToolFilter - Update auto-configuration to inject optional McpToolFilter bean - Add tests for tool filtering functionality - Update documentation with tool filtering examples and usage Signed-off-by: Ilayaperumal Gopinathan <[email protected]>
1 parent 8cf2ea5 commit d04a41b

File tree

8 files changed

+235
-31
lines changed

8 files changed

+235
-31
lines changed

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import io.modelcontextprotocol.client.McpSyncClient;
2323

2424
import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
25+
import org.springframework.ai.mcp.McpToolFilter;
2526
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
2627
import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties;
2728
import org.springframework.beans.factory.ObjectProvider;
@@ -45,22 +46,28 @@ public class McpToolCallbackAutoConfiguration {
4546
* <p>
4647
* These callbacks enable integration with Spring AI's tool execution framework,
4748
* allowing MCP tools to be used as part of AI interactions.
49+
* @param syncClientsToolFilter list of {@link McpToolFilter}s for the sync client to
50+
* filter the discovered tools
4851
* @param syncMcpClients provider of MCP sync clients
4952
* @return list of tool callbacks for MCP integration
5053
*/
5154
@Bean
5255
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC",
5356
matchIfMissing = true)
54-
public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider<List<McpSyncClient>> syncMcpClients) {
57+
public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider<McpToolFilter> syncClientsToolFilter,
58+
ObjectProvider<List<McpSyncClient>> syncMcpClients) {
5559
List<McpSyncClient> mcpClients = syncMcpClients.stream().flatMap(List::stream).toList();
56-
return new SyncMcpToolCallbackProvider(mcpClients);
60+
return new SyncMcpToolCallbackProvider(syncClientsToolFilter.getIfUnique((() -> (McpSyncClient, tool) -> true)),
61+
mcpClients);
5762
}
5863

5964
@Bean
6065
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC")
61-
public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider<List<McpAsyncClient>> mcpClientsProvider) {
66+
public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider<McpToolFilter> asyncClientsToolFilter,
67+
ObjectProvider<List<McpAsyncClient>> mcpClientsProvider) {
6268
List<McpAsyncClient> mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList();
63-
return new AsyncMcpToolCallbackProvider(mcpClients);
69+
return new AsyncMcpToolCallbackProvider(
70+
asyncClientsToolFilter.getIfUnique(() -> (McpAsyncClient, tool) -> true), mcpClients);
6471
}
6572

6673
public static class McpToolCallbackAutoConfigurationCondition extends AllNestedConditions {

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,28 @@
1616

1717
package org.springframework.ai.mcp.client.common.autoconfigure;
1818

19+
import java.lang.reflect.Field;
20+
import java.util.List;
21+
22+
import io.modelcontextprotocol.client.McpAsyncClient;
23+
import io.modelcontextprotocol.client.McpSyncClient;
24+
import io.modelcontextprotocol.spec.McpSchema;
1925
import org.junit.jupiter.api.Test;
26+
import reactor.core.publisher.Mono;
2027

28+
import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
29+
import org.springframework.ai.mcp.McpMetadata;
30+
import org.springframework.ai.mcp.McpToolFilter;
31+
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
2132
import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration.McpToolCallbackAutoConfigurationCondition;
2233
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
2334
import org.springframework.context.annotation.Bean;
2435
import org.springframework.context.annotation.Conditional;
2536
import org.springframework.context.annotation.Configuration;
2637

2738
import static org.assertj.core.api.Assertions.assertThat;
39+
import static org.mockito.Mockito.mock;
40+
import static org.mockito.Mockito.when;
2841

2942
/**
3043
* Tests for {@link McpToolCallbackAutoConfigurationCondition}.
@@ -73,6 +86,58 @@ void doesMatchWhenBothPropertiesAreMissing() {
7386
this.contextRunner.run(context -> assertThat(context).hasBean("testBean"));
7487
}
7588

89+
@Test
90+
void verifySyncToolCallbackFilterConfiguration() {
91+
this.contextRunner
92+
.withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpClientFilterConfiguration.class)
93+
.withPropertyValues("spring.ai.mcp.client.type=SYNC")
94+
.run(context -> {
95+
assertThat(context).hasBean("mcpClientFilter");
96+
SyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(SyncMcpToolCallbackProvider.class);
97+
Field field = SyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter");
98+
field.setAccessible(true);
99+
McpToolFilter toolFilter = (McpToolFilter) field.get(toolCallbackProvider);
100+
McpSyncClient syncClient1 = mock(McpSyncClient.class);
101+
var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0");
102+
when(syncClient1.getClientInfo()).thenReturn(clientInfo1);
103+
McpSchema.Tool tool1 = mock(McpSchema.Tool.class);
104+
when(tool1.name()).thenReturn("tool1");
105+
McpSchema.Tool tool2 = mock(McpSchema.Tool.class);
106+
when(tool2.name()).thenReturn("tool2");
107+
McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class);
108+
when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2));
109+
when(syncClient1.listTools()).thenReturn(listToolsResult1);
110+
assertThat(toolFilter.test(new McpMetadata(null, syncClient1.getClientInfo(), null), tool1)).isFalse();
111+
assertThat(toolFilter.test(new McpMetadata(null, syncClient1.getClientInfo(), null), tool2)).isTrue();
112+
});
113+
}
114+
115+
@Test
116+
void verifyASyncToolCallbackFilterConfiguration() {
117+
this.contextRunner
118+
.withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpClientFilterConfiguration.class)
119+
.withPropertyValues("spring.ai.mcp.client.type=ASYNC")
120+
.run(context -> {
121+
assertThat(context).hasBean("mcpClientFilter");
122+
AsyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(AsyncMcpToolCallbackProvider.class);
123+
Field field = AsyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter");
124+
field.setAccessible(true);
125+
McpToolFilter toolFilter = (McpToolFilter) field.get(toolCallbackProvider);
126+
McpAsyncClient asyncClient1 = mock(McpAsyncClient.class);
127+
var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0");
128+
when(asyncClient1.getClientInfo()).thenReturn(clientInfo1);
129+
McpSchema.Tool tool1 = mock(McpSchema.Tool.class);
130+
when(tool1.name()).thenReturn("tool1");
131+
McpSchema.Tool tool2 = mock(McpSchema.Tool.class);
132+
when(tool2.name()).thenReturn("tool2");
133+
McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class);
134+
when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2));
135+
when(asyncClient1.listTools()).thenReturn(Mono.just(listToolsResult1));
136+
assertThat(toolFilter.test(new McpMetadata(null, asyncClient1.getClientInfo(), null), tool1)).isFalse();
137+
assertThat(toolFilter.test(new McpMetadata(null, asyncClient1.getClientInfo(), null), tool2)).isTrue();
138+
});
139+
}
140+
76141
@Configuration
77142
@Conditional(McpToolCallbackAutoConfigurationCondition.class)
78143
static class TestConfiguration {
@@ -84,4 +149,22 @@ String testBean() {
84149

85150
}
86151

152+
@Configuration
153+
static class McpClientFilterConfiguration {
154+
155+
@Bean
156+
McpToolFilter mcpClientFilter() {
157+
return new McpToolFilter() {
158+
@Override
159+
public boolean test(McpMetadata metadata, McpSchema.Tool tool) {
160+
if (metadata.clientInfo().name().equals("client1") && tool.name().contains("tool1")) {
161+
return false;
162+
}
163+
return true;
164+
}
165+
};
166+
}
167+
168+
}
169+
87170
}

mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallbackProvider.java

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,8 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.List;
21-
import java.util.function.BiPredicate;
2221

2322
import io.modelcontextprotocol.client.McpAsyncClient;
24-
import io.modelcontextprotocol.spec.McpSchema.Tool;
2523
import io.modelcontextprotocol.util.Assert;
2624
import reactor.core.publisher.Flux;
2725

@@ -74,21 +72,21 @@
7472
*/
7573
public class AsyncMcpToolCallbackProvider implements ToolCallbackProvider {
7674

77-
private final List<McpAsyncClient> mcpClients;
75+
private final McpToolFilter toolFilter;
7876

79-
private final BiPredicate<McpAsyncClient, Tool> toolFilter;
77+
private final List<McpAsyncClient> mcpClients;
8078

8179
/**
8280
* Creates a new {@code AsyncMcpToolCallbackProvider} instance with a list of MCP
8381
* clients.
84-
* @param mcpClients the list of MCP clients to use for discovering tools
8582
* @param toolFilter a filter to apply to each discovered tool
83+
* @param mcpClients the list of MCP clients to use for discovering tools
8684
*/
87-
public AsyncMcpToolCallbackProvider(BiPredicate<McpAsyncClient, Tool> toolFilter, List<McpAsyncClient> mcpClients) {
85+
public AsyncMcpToolCallbackProvider(McpToolFilter toolFilter, List<McpAsyncClient> mcpClients) {
8886
Assert.notNull(mcpClients, "MCP clients must not be null");
8987
Assert.notNull(toolFilter, "Tool filter must not be null");
90-
this.mcpClients = mcpClients;
9188
this.toolFilter = toolFilter;
89+
this.mcpClients = mcpClients;
9290
}
9391

9492
/**
@@ -106,10 +104,10 @@ public AsyncMcpToolCallbackProvider(List<McpAsyncClient> mcpClients) {
106104
/**
107105
* Creates a new {@code AsyncMcpToolCallbackProvider} instance with one or more MCP
108106
* clients.
109-
* @param mcpClients the MCP clients to use for discovering tools
110107
* @param toolFilter a filter to apply to each discovered tool
108+
* @param mcpClients the MCP clients to use for discovering tools
111109
*/
112-
public AsyncMcpToolCallbackProvider(BiPredicate<McpAsyncClient, Tool> toolFilter, McpAsyncClient... mcpClients) {
110+
public AsyncMcpToolCallbackProvider(McpToolFilter toolFilter, McpAsyncClient... mcpClients) {
113111
this(toolFilter, List.of(mcpClients));
114112
}
115113

@@ -147,7 +145,8 @@ public ToolCallback[] getToolCallbacks() {
147145
ToolCallback[] toolCallbacks = mcpClient.listTools()
148146
.map(response -> response.tools()
149147
.stream()
150-
.filter(tool -> this.toolFilter.test(mcpClient, tool))
148+
.filter(tool -> this.toolFilter.test(new McpMetadata(mcpClient.getClientCapabilities(),
149+
mcpClient.getClientInfo(), mcpClient.getCurrentInitializationResult()), tool))
151150
.map(tool -> new AsyncMcpToolCallback(mcpClient, tool))
152151
.toArray(ToolCallback[]::new))
153152
.block();
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.mcp;
18+
19+
import io.modelcontextprotocol.spec.McpSchema;
20+
21+
/**
22+
* MCP metadata record containing the client/server specific meta data.
23+
*
24+
* @param clientCapabilities the MCP client capabilities
25+
* @param clientInfo the MCP client information
26+
* @param initializeResult the MCP server initialization result
27+
* @author Ilayaperumal Gopinathan
28+
* @author Christian Tzolov
29+
*/
30+
public record McpMetadata(// @formatter:off
31+
McpSchema.ClientCapabilities clientCapabilities,
32+
McpSchema.Implementation clientInfo,
33+
McpSchema.InitializeResult initializeResult) { // @formatter:on
34+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.mcp;
18+
19+
import java.util.function.BiPredicate;
20+
21+
import io.modelcontextprotocol.spec.McpSchema;
22+
23+
/**
24+
* A {@link BiPredicate} for {@link SyncMcpToolCallbackProvider} and the
25+
* {@link AsyncMcpToolCallbackProvider} to filter the discovered tool for the given
26+
* {@link McpMetadata}.
27+
*
28+
* @author Ilayaperumal Gopinathan
29+
*/
30+
public interface McpToolFilter extends BiPredicate<McpMetadata, McpSchema.Tool> {
31+
32+
}

mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallbackProvider.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@
1717
package org.springframework.ai.mcp;
1818

1919
import java.util.List;
20-
import java.util.function.BiPredicate;
2120

2221
import io.modelcontextprotocol.client.McpSyncClient;
23-
import io.modelcontextprotocol.spec.McpSchema.Tool;
2422

2523
import org.springframework.ai.tool.ToolCallback;
2624
import org.springframework.ai.tool.ToolCallbackProvider;
@@ -72,15 +70,15 @@ public class SyncMcpToolCallbackProvider implements ToolCallbackProvider {
7270

7371
private final List<McpSyncClient> mcpClients;
7472

75-
private final BiPredicate<McpSyncClient, Tool> toolFilter;
73+
private final McpToolFilter toolFilter;
7674

7775
/**
7876
* Creates a new {@code SyncMcpToolCallbackProvider} instance with a list of MCP
7977
* clients.
8078
* @param mcpClients the list of MCP clients to use for discovering tools
8179
* @param toolFilter a filter to apply to each discovered tool
8280
*/
83-
public SyncMcpToolCallbackProvider(BiPredicate<McpSyncClient, Tool> toolFilter, List<McpSyncClient> mcpClients) {
81+
public SyncMcpToolCallbackProvider(McpToolFilter toolFilter, List<McpSyncClient> mcpClients) {
8482
Assert.notNull(mcpClients, "MCP clients must not be null");
8583
Assert.notNull(toolFilter, "Tool filter must not be null");
8684
this.mcpClients = mcpClients;
@@ -102,7 +100,7 @@ public SyncMcpToolCallbackProvider(List<McpSyncClient> mcpClients) {
102100
* @param mcpClients the MCP clients to use for discovering tools
103101
* @param toolFilter a filter to apply to each discovered tool
104102
*/
105-
public SyncMcpToolCallbackProvider(BiPredicate<McpSyncClient, Tool> toolFilter, McpSyncClient... mcpClients) {
103+
public SyncMcpToolCallbackProvider(McpToolFilter toolFilter, McpSyncClient... mcpClients) {
106104
this(toolFilter, List.of(mcpClients));
107105
}
108106

@@ -133,7 +131,8 @@ public ToolCallback[] getToolCallbacks() {
133131
.flatMap(mcpClient -> mcpClient.listTools()
134132
.tools()
135133
.stream()
136-
.filter(tool -> this.toolFilter.test(mcpClient, tool))
134+
.filter(tool -> this.toolFilter.test(new McpMetadata(mcpClient.getClientCapabilities(),
135+
mcpClient.getClientInfo(), mcpClient.getCurrentInitializationResult()), tool))
137136
.map(tool -> new SyncMcpToolCallback(mcpClient, tool)))
138137
.toArray(ToolCallback[]::new);
139138
validateToolCallbacks(array);

mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackProviderTests.java

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package org.springframework.ai.mcp;
1818

1919
import java.util.List;
20-
import java.util.function.BiPredicate;
2120

2221
import io.modelcontextprotocol.client.McpSyncClient;
2322
import io.modelcontextprotocol.spec.McpSchema.Implementation;
@@ -164,7 +163,7 @@ void toolFilterShouldRejectAllToolsWhenConfigured() {
164163
when(this.mcpClient.listTools()).thenReturn(listToolsResult);
165164

166165
// Create a filter that rejects all tools
167-
BiPredicate<McpSyncClient, Tool> rejectAllFilter = (client, tool) -> false;
166+
McpToolFilter rejectAllFilter = (client, tool) -> false;
168167

169168
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(rejectAllFilter, this.mcpClient);
170169

@@ -192,8 +191,7 @@ void toolFilterShouldFilterToolsByNameWhenConfigured() {
192191
when(this.mcpClient.listTools()).thenReturn(listToolsResult);
193192

194193
// Create a filter that only accepts tools with names containing "2" or "3"
195-
BiPredicate<McpSyncClient, Tool> nameFilter = (client, tool) -> tool.name().contains("2")
196-
|| tool.name().contains("3");
194+
McpToolFilter nameFilter = (client, tool) -> tool.name().contains("2") || tool.name().contains("3");
197195

198196
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(nameFilter, this.mcpClient);
199197

@@ -228,8 +226,7 @@ void toolFilterShouldFilterToolsByClientWhenConfigured() {
228226
when(mcpClient2.getClientInfo()).thenReturn(clientInfo2);
229227

230228
// Create a filter that only accepts tools from client1
231-
BiPredicate<McpSyncClient, Tool> clientFilter = (client,
232-
tool) -> client.getClientInfo().name().equals("testClient1");
229+
McpToolFilter clientFilter = (mcpMetadata, tool) -> mcpMetadata.clientInfo().name().equals("testClient1");
233230

234231
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(clientFilter, mcpClient1, mcpClient2);
235232

@@ -256,8 +253,8 @@ void toolFilterShouldCombineClientAndToolCriteriaWhenConfigured() {
256253
when(weatherClient.getClientInfo()).thenReturn(weatherClientInfo);
257254

258255
// Create a filter that only accepts weather tools from the weather service
259-
BiPredicate<McpSyncClient, Tool> complexFilter = (client,
260-
tool) -> client.getClientInfo().name().equals("weather-service") && tool.name().equals("weather");
256+
McpToolFilter complexFilter = (mcpMetadata, tool) -> mcpMetadata.clientInfo().name().equals("weather-service")
257+
&& tool.name().equals("weather");
261258

262259
SyncMcpToolCallbackProvider provider = new SyncMcpToolCallbackProvider(complexFilter, weatherClient);
263260

0 commit comments

Comments
 (0)