Skip to content

Commit 34303c1

Browse files
authored
Merge pull request #726 from devoxx/issue-725
Fix #725 where langchain4J MCP had to be patched
2 parents 5fc6656 + 72ca48b commit 34303c1

File tree

58 files changed

+3605
-3
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+3605
-3
lines changed

build.gradle.kts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ plugins {
77
}
88

99
group = "com.devoxx.genie"
10-
version = "0.6.7"
10+
version = "0.6.8"
1111

1212
repositories {
1313
mavenCentral()
@@ -67,7 +67,7 @@ dependencies {
6767
implementation("dev.langchain4j:langchain4j-web-search-engine-tavily:$lg4j_version")
6868
implementation("dev.langchain4j:langchain4j-azure-open-ai:$lg4j_version")
6969
implementation("dev.langchain4j:langchain4j-chroma:$lg4j_version")
70-
implementation("dev.langchain4j:langchain4j-mcp:$lg4j_version")
70+
// implementation("dev.langchain4j:langchain4j-mcp:$lg4j_version")
7171
implementation("dev.langchain4j:langchain4j-reactor:$lg4j_version")
7272

7373
// Retrofit dependencies
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
package dev.langchain4j.mcp;
2+
3+
import dev.langchain4j.agent.tool.ToolExecutionRequest;
4+
import dev.langchain4j.agent.tool.ToolSpecification;
5+
import dev.langchain4j.internal.Utils;
6+
import dev.langchain4j.service.IllegalConfigurationException;
7+
import dev.langchain4j.service.tool.ToolExecutor;
8+
import dev.langchain4j.service.tool.ToolProvider;
9+
import dev.langchain4j.service.tool.ToolProviderRequest;
10+
import dev.langchain4j.service.tool.ToolProviderResult;
11+
import dev.langchain4j.mcp.client.McpClient;
12+
import org.slf4j.Logger;
13+
import org.slf4j.LoggerFactory;
14+
15+
import java.util.Arrays;
16+
import java.util.List;
17+
import java.util.Objects;
18+
import java.util.concurrent.CopyOnWriteArrayList;
19+
import java.util.concurrent.atomic.AtomicReference;
20+
import java.util.function.BiPredicate;
21+
import java.util.function.Function;
22+
23+
/**
24+
* A tool provider backed by one or more MCP clients.
25+
*/
26+
public class McpToolProvider implements ToolProvider {
27+
28+
private final CopyOnWriteArrayList<McpClient> mcpClients;
29+
private final boolean failIfOneServerFails;
30+
private final AtomicReference<BiPredicate<McpClient, ToolSpecification>> mcpToolsFilter;
31+
private final Function<ToolExecutor, ToolExecutor> toolWrapper;
32+
private static final Logger log = LoggerFactory.getLogger(McpToolProvider.class);
33+
34+
private McpToolProvider(Builder builder) {
35+
this(builder.mcpClients, Utils.getOrDefault(builder.failIfOneServerFails, false), builder.mcpToolsFilter, builder.toolWrapper);
36+
}
37+
38+
protected McpToolProvider(List<McpClient> mcpClients, boolean failIfOneServerFails, BiPredicate<McpClient, ToolSpecification> mcpToolsFilter) {
39+
this(Objects.requireNonNull(mcpClients), failIfOneServerFails, mcpToolsFilter, Function.identity());
40+
}
41+
42+
protected McpToolProvider(List<McpClient> mcpClients, boolean failIfOneServerFails, BiPredicate<McpClient, ToolSpecification> mcpToolsFilter, Function<ToolExecutor, ToolExecutor> toolWrapper) {
43+
this.mcpClients = new CopyOnWriteArrayList<>(mcpClients);
44+
this.failIfOneServerFails = failIfOneServerFails;
45+
this.mcpToolsFilter = new AtomicReference<>(mcpToolsFilter);
46+
this.toolWrapper = toolWrapper;
47+
}
48+
49+
/**
50+
* Adds a new MCP client to the list of clients.
51+
*
52+
* @param client the MCP client to add
53+
*/
54+
public void addMcpClient(McpClient client) {
55+
Objects.requireNonNull(client);
56+
mcpClients.add(client);
57+
}
58+
59+
/**
60+
* Removes an MCP client from the list of clients.
61+
*
62+
* @param client the MCP client to remove
63+
*/
64+
public void removeMcpClient(McpClient client) {
65+
mcpClients.remove(client);
66+
}
67+
68+
/**
69+
* Adds a tools filter that will act in conjunction (AND) with the eventually existing ones.
70+
*
71+
* @param filter the filter to add
72+
*/
73+
public void addFilter(BiPredicate<McpClient, ToolSpecification> filter) {
74+
Objects.requireNonNull(filter);
75+
BiPredicate<McpClient, ToolSpecification> currentFilter = mcpToolsFilter.get();
76+
while (!mcpToolsFilter.compareAndSet(currentFilter, currentFilter.and(filter))) {
77+
currentFilter = mcpToolsFilter.get();
78+
}
79+
}
80+
81+
/**
82+
* Sets the tools filter overriding the eventually existing ones.
83+
*
84+
* @param filter the filter to add
85+
*/
86+
public void setFilter(BiPredicate<McpClient, ToolSpecification> filter) {
87+
Objects.requireNonNull(filter);
88+
BiPredicate<McpClient, ToolSpecification> currentFilter = mcpToolsFilter.get();
89+
while (!mcpToolsFilter.compareAndSet(currentFilter, filter)) {
90+
currentFilter = mcpToolsFilter.get();
91+
}
92+
}
93+
94+
/**
95+
* Resets the all the eventually existing tools filters.
96+
*/
97+
public void resetFilters() {
98+
setFilter((mcp, tool) -> true);
99+
}
100+
101+
@Override
102+
public ToolProviderResult provideTools(ToolProviderRequest request) {
103+
return provideTools(request, mcpToolsFilter.get());
104+
}
105+
106+
protected ToolProviderResult provideTools(ToolProviderRequest request, BiPredicate<McpClient, ToolSpecification> mcpToolsFilter) {
107+
ToolProviderResult.Builder builder = ToolProviderResult.builder();
108+
for (McpClient mcpClient : mcpClients) {
109+
var defaultToolExecutor = new DefaultToolExecutor(mcpClient);
110+
try {
111+
mcpClient.listTools().stream().filter(tool -> mcpToolsFilter.test(mcpClient, tool))
112+
.forEach(toolSpecification -> {
113+
builder.add(toolSpecification, toolWrapper.apply(defaultToolExecutor));
114+
});
115+
} catch (IllegalConfigurationException e) {
116+
throw e;
117+
} catch (Exception e) {
118+
if (failIfOneServerFails) {
119+
throw new RuntimeException("Failed to retrieve tools from MCP server", e);
120+
} else {
121+
log.warn("Failed to retrieve tools from MCP server", e);
122+
}
123+
}
124+
}
125+
return builder.build();
126+
}
127+
128+
public static Builder builder() {
129+
return new Builder();
130+
}
131+
132+
public static class Builder {
133+
134+
private List<McpClient> mcpClients;
135+
private Boolean failIfOneServerFails;
136+
private BiPredicate<McpClient, ToolSpecification> mcpToolsFilter = (mcp, tool) -> true;
137+
private Function<ToolExecutor, ToolExecutor> toolWrapper = Function.identity();
138+
139+
/**
140+
* The list of MCP clients to use for retrieving tools.
141+
*/
142+
public Builder mcpClients(List<McpClient> mcpClients) {
143+
this.mcpClients = mcpClients;
144+
return this;
145+
}
146+
147+
/**
148+
* The list of MCP clients to use for retrieving tools.
149+
*/
150+
public Builder mcpClients(McpClient... mcpClients) {
151+
return mcpClients(Arrays.asList(mcpClients));
152+
}
153+
154+
/**
155+
* The predicate to filter MCP provided tools.
156+
*/
157+
public Builder filter(BiPredicate<McpClient, ToolSpecification> mcpToolsFilter) {
158+
this.mcpToolsFilter = this.mcpToolsFilter.and(mcpToolsFilter);
159+
return this;
160+
}
161+
162+
/**
163+
* Filter MCP provided tools with a specific name.
164+
*/
165+
public Builder filterToolNames(String... toolNames) {
166+
return filter(new ToolsNameFilter(toolNames));
167+
}
168+
169+
/**
170+
* If this is true, then the tool provider will throw an exception if it fails to list tools from any of the servers.
171+
* If this is false (default), then the tool provider will ignore the error and continue with the next server.
172+
*/
173+
public Builder failIfOneServerFails(boolean failIfOneServerFails) {
174+
this.failIfOneServerFails = failIfOneServerFails;
175+
return this;
176+
}
177+
178+
/**
179+
* Provide a wrapper around the {@link ToolExecutor} that can be used to implement tracing for example.
180+
*/
181+
public Builder toolWrapper(Function<ToolExecutor, ToolExecutor> toolWrapper) {
182+
this.toolWrapper = toolWrapper;
183+
return this;
184+
}
185+
186+
public McpToolProvider build() {
187+
return new McpToolProvider(this);
188+
}
189+
}
190+
191+
private static class ToolsNameFilter implements BiPredicate<McpClient, ToolSpecification> {
192+
private final List<String> toolNames;
193+
194+
private ToolsNameFilter(String... toolNames) {
195+
this(Arrays.asList(toolNames));
196+
}
197+
198+
private ToolsNameFilter(List<String> toolNames) {
199+
this.toolNames = toolNames;
200+
}
201+
202+
@Override
203+
public boolean test(McpClient mcpClient, ToolSpecification tool) {
204+
return toolNames.stream().anyMatch(name -> name.equals(tool.name()));
205+
}
206+
}
207+
208+
private static class DefaultToolExecutor implements ToolExecutor {
209+
private final McpClient mcpClient;
210+
211+
public DefaultToolExecutor(McpClient mcpClient) {
212+
this.mcpClient = mcpClient;
213+
}
214+
215+
@Override
216+
public String execute(ToolExecutionRequest executionRequest, Object memoryId) {
217+
return mcpClient.executeTool(executionRequest);
218+
}
219+
}
220+
}

0 commit comments

Comments
 (0)