Skip to content

Commit 8fd03d5

Browse files
committed
[app-builder] support MCP tool invocation in LLM nodes
1 parent 523d902 commit 8fd03d5

File tree

4 files changed

+244
-18
lines changed

4 files changed

+244
-18
lines changed

app-builder/jane/plugins/aipp-plugin/src/main/java/modelengine/fit/jober/aipp/fel/FelComponentConfig.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import modelengine.fel.core.chat.ChatModel;
1010
import modelengine.fel.core.chat.Prompt;
1111
import modelengine.fel.engine.operators.patterns.AbstractAgent;
12+
import modelengine.fel.tool.mcp.client.McpClientFactory;
1213
import modelengine.fit.jade.tool.SyncToolCall;
1314
import modelengine.fit.jober.aipp.constants.AippConst;
1415
import modelengine.fitframework.annotation.Bean;
@@ -28,11 +29,12 @@ public class FelComponentConfig {
2829
*
2930
* @param syncToolCall 表示同步工具调用服务的 {@link SyncToolCall}。
3031
* @param chatModel 表示模型流式服务的 {@link ChatModel}。
32+
* @param mcpClientFactory 表示大模型上下文客户端工厂的 {@link McpClientFactory}。
3133
* @return 返回 WaterFlow 场景的 Agent 服务的 {@link AbstractAgent}{@code <}{@link Prompt}{@code ,
3234
* }{@link Prompt}{@code >}。
3335
*/
3436
@Bean(AippConst.WATER_FLOW_AGENT_BEAN)
35-
public AbstractAgent getWaterFlowAgent(@Fit SyncToolCall syncToolCall, ChatModel chatModel) {
36-
return new WaterFlowAgent(syncToolCall, chatModel);
37+
public AbstractAgent getWaterFlowAgent(@Fit SyncToolCall syncToolCall, ChatModel chatModel, McpClientFactory mcpClientFactory) {
38+
return new WaterFlowAgent(syncToolCall, chatModel, mcpClientFactory);
3739
}
3840
}

app-builder/jane/plugins/aipp-plugin/src/main/java/modelengine/fit/jober/aipp/fel/WaterFlowAgent.java

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,40 @@
66

77
package modelengine.fit.jober.aipp.fel;
88

9+
import com.alibaba.fastjson.JSON;
10+
import com.alibaba.fastjson.JSONObject;
11+
912
import modelengine.fel.core.chat.ChatMessage;
1013
import modelengine.fel.core.chat.ChatModel;
1114
import modelengine.fel.core.chat.Prompt;
1215
import modelengine.fel.core.chat.support.ChatMessages;
1316
import modelengine.fel.core.chat.support.FlatChatMessage;
1417
import modelengine.fel.core.chat.support.ToolMessage;
1518
import modelengine.fel.core.tool.ToolCall;
19+
import modelengine.fel.core.tool.ToolInfo;
1620
import modelengine.fel.engine.flows.AiFlows;
1721
import modelengine.fel.engine.flows.AiProcessFlow;
1822
import modelengine.fel.engine.operators.models.ChatChunk;
1923
import modelengine.fel.engine.operators.models.ChatFlowModel;
2024
import modelengine.fel.engine.operators.patterns.AbstractAgent;
25+
import modelengine.fel.tool.mcp.client.McpClient;
26+
import modelengine.fel.tool.mcp.client.McpClientFactory;
2127
import modelengine.fit.jade.tool.SyncToolCall;
28+
import modelengine.fit.jober.aipp.common.exception.AippErrCode;
29+
import modelengine.fit.jober.aipp.common.exception.AippException;
2230
import modelengine.fit.jober.aipp.constants.AippConst;
31+
import modelengine.fit.jober.aipp.util.McpUtils;
2332
import modelengine.fit.waterflow.domain.context.StateContext;
2433
import modelengine.fitframework.annotation.Fit;
2534
import modelengine.fitframework.inspection.Validation;
35+
import modelengine.fitframework.util.CollectionUtils;
2636
import modelengine.fitframework.util.ObjectUtils;
2737

38+
import java.io.IOException;
2839
import java.util.Collections;
2940
import java.util.List;
3041
import java.util.Map;
42+
import java.util.function.Function;
3143
import java.util.stream.Collectors;
3244

3345
/**
@@ -42,28 +54,30 @@ public class WaterFlowAgent extends AbstractAgent {
4254

4355
private final String agentMsgKey;
4456
private final SyncToolCall syncToolCall;
57+
private final McpClientFactory mcpClientFactory;
4558

4659
/**
4760
* {@link WaterFlowAgent} 的构造方法。
4861
*
4962
* @param syncToolCall 表示工具调用服务的 {@link SyncToolCall}。
5063
* @param chatStreamModel 表示流式对话大模型的 {@link ChatModel}。
64+
* @param mcpClientFactory 表示大模型上下文客户端工厂的 {@link McpClientFactory}。
5165
*/
52-
public WaterFlowAgent(@Fit SyncToolCall syncToolCall, ChatModel chatStreamModel) {
66+
public WaterFlowAgent(@Fit SyncToolCall syncToolCall, ChatModel chatStreamModel,
67+
McpClientFactory mcpClientFactory) {
5368
super(new ChatFlowModel(chatStreamModel, null));
54-
this.syncToolCall = Validation.notNull(syncToolCall, "The tool sync tool call cannot be null.");
69+
this.syncToolCall = Validation.notNull(syncToolCall, "The tool sync tool call cannot be null.");
70+
this.mcpClientFactory = Validation.notNull(mcpClientFactory, "The mcp client factory cannot be null.");
5571
this.agentMsgKey = AGENT_MSG_KEY;
5672
}
5773

5874
@Override
5975
protected Prompt doToolCall(List<ToolCall> toolCalls, StateContext ctx) {
6076
Validation.notNull(ctx, "The state context cannot be null.");
61-
Map<String, Object> toolContext = ObjectUtils.getIfNull(ctx.getState(AippConst.TOOL_CONTEXT_KEY),
62-
Collections::emptyMap);
63-
return toolCalls.stream()
64-
.map(toolCall -> (ChatMessage) new ToolMessage(toolCall.id(),
65-
this.syncToolCall.call(toolCall.name(), toolCall.arguments(), toolContext)))
66-
.collect(Collectors.collectingAndThen(Collectors.toList(), ChatMessages::from));
77+
return ChatMessages.from(this.callTools(toolCalls, ctx)
78+
.stream()
79+
.map(message -> (ChatMessage) FlatChatMessage.from(message))
80+
.collect(Collectors.toList()));
6781
}
6882

6983
@Override
@@ -87,18 +101,53 @@ public AiProcessFlow<Prompt, ChatMessage> buildFlow() {
87101
private ChatMessage handleTool(ChatMessage input, StateContext ctx) {
88102
Validation.notNull(ctx, "The state context cannot be null.");
89103
Validation.notNull(input, "The input message cannot be null.");
90-
91-
Map<String, Object> toolContext = ObjectUtils.getIfNull(ctx.getState(AippConst.TOOL_CONTEXT_KEY),
92-
Collections::emptyMap);
93104
ChatMessages lastRequest = ctx.getState(this.agentMsgKey);
94105
lastRequest.add(input);
95-
input.toolCalls().forEach(toolCall -> {
96-
lastRequest.add(FlatChatMessage.from(new ToolMessage(toolCall.id(),
97-
this.syncToolCall.call(toolCall.name(), toolCall.arguments(), toolContext))));
98-
});
106+
lastRequest.addAll(this.callTools(input.toolCalls(), ctx));
99107
return input;
100108
}
101109

110+
private List<ChatMessage> callTools(List<ToolCall> toolCalls, StateContext ctx) {
111+
if (CollectionUtils.isEmpty(toolCalls)) {
112+
return Collections.emptyList();
113+
}
114+
List<ToolInfo> tools = ctx.getState(AippConst.TOOLS_KEY);
115+
Validation.notEmpty(tools, "Missing tool detected during call.");
116+
Map<String, ToolInfo> toolsMap = tools.stream().collect(Collectors.toMap(ToolInfo::name, Function.identity()));
117+
Map<String, Object> toolContext =
118+
ObjectUtils.getIfNull(ctx.getState(AippConst.TOOL_CONTEXT_KEY), Collections::emptyMap);
119+
return toolCalls.stream()
120+
.map(toolCall -> this.callTool(toolCall, toolsMap, toolContext))
121+
.collect(Collectors.toList());
122+
}
123+
124+
private ChatMessage callTool(ToolCall toolCall, Map<String, ToolInfo> toolsMap, Map<String, Object> toolContext) {
125+
ToolInfo toolInfo = toolsMap.get(toolCall.name());
126+
if (toolInfo == null) {
127+
throw new IllegalStateException(String.format("The tool call's tool is not exist. [toolName=%s]",
128+
toolCall.name()));
129+
}
130+
Map<String, Object> extensions = Validation.notNull(toolInfo.extensions(),
131+
"The tool call's extension is not exist. [toolName={0}]", toolCall.name());
132+
String toolRealName = Validation.notBlank(ObjectUtils.cast(extensions.get(AippConst.TOOL_REAL_NAME)),
133+
"Can not find the tool real name. [toolName={0}]",
134+
toolCall.name());
135+
Map<String, Object> mcpServerConfig = ObjectUtils.cast(extensions.get(AippConst.MCP_SERVER_KEY));
136+
if (mcpServerConfig != null) {
137+
String url = Validation.notBlank(ObjectUtils.cast(mcpServerConfig.get(AippConst.MCP_SERVER_URL_KEY)),
138+
"The mcp url should not be empty.");
139+
try (McpClient mcpClient = this.mcpClientFactory.create(McpUtils.getBaseUrl(url),
140+
McpUtils.getSseEndpoint(url))) {
141+
mcpClient.initialize();
142+
Object result = mcpClient.callTool(toolRealName, JSONObject.parseObject(toolCall.arguments()));
143+
return new ToolMessage(toolCall.id(), JSON.toJSONString(result));
144+
} catch (IOException exception) {
145+
throw new AippException(AippErrCode.CALL_MCP_SERVER_FAILED, exception.getMessage());
146+
}
147+
}
148+
return new ToolMessage(toolCall.id(), this.syncToolCall.call(toolRealName, toolCall.arguments(), toolContext));
149+
}
150+
102151
private ChatMessages getAgentMsg(ChatMessage input, StateContext ctx) {
103152
Validation.notNull(ctx, "The state context cannot be null.");
104153
return ctx.getState(this.agentMsgKey);
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
/*---------------------------------------------------------------------------------------------
2+
* Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3+
* This file is a part of the ModelEngine Project.
4+
* Licensed under the MIT License. See License.txt in the project root for license information.
5+
*--------------------------------------------------------------------------------------------*/
6+
7+
package modelengine.fit.jober.aipp.fel;
8+
9+
import modelengine.fel.core.chat.ChatMessage;
10+
import modelengine.fel.core.chat.ChatModel;
11+
import modelengine.fel.core.chat.ChatOption;
12+
import modelengine.fel.core.chat.Prompt;
13+
import modelengine.fel.core.chat.support.AiMessage;
14+
import modelengine.fel.core.chat.support.ChatMessages;
15+
import modelengine.fel.core.chat.support.HumanMessage;
16+
import modelengine.fel.core.tool.ToolCall;
17+
import modelengine.fel.core.tool.ToolInfo;
18+
import modelengine.fel.engine.flows.AiProcessFlow;
19+
import modelengine.fel.tool.mcp.client.McpClient;
20+
import modelengine.fel.tool.mcp.client.McpClientFactory;
21+
import modelengine.fit.jade.tool.SyncToolCall;
22+
import modelengine.fit.jober.aipp.constants.AippConst;
23+
import modelengine.fitframework.flowable.Choir;
24+
import modelengine.fitframework.util.MapBuilder;
25+
26+
import org.apache.commons.collections.CollectionUtils;
27+
import org.junit.jupiter.api.Test;
28+
import org.junit.jupiter.api.extension.ExtendWith;
29+
import org.mockito.Mock;
30+
import org.mockito.junit.jupiter.MockitoExtension;
31+
32+
import java.util.Collections;
33+
import java.util.HashMap;
34+
import java.util.List;
35+
import java.util.Map;
36+
import java.util.concurrent.atomic.AtomicInteger;
37+
38+
import static org.junit.jupiter.api.Assertions.*;
39+
import static org.mockito.ArgumentMatchers.any;
40+
import static org.mockito.Mockito.doAnswer;
41+
import static org.mockito.Mockito.mock;
42+
import static org.mockito.Mockito.times;
43+
import static org.mockito.Mockito.verify;
44+
import static org.mockito.Mockito.when;
45+
46+
/**
47+
* {@link WaterFlowAgent} 的测试。
48+
*/
49+
@ExtendWith(MockitoExtension.class)
50+
class WaterFlowAgentTest {
51+
@Mock
52+
private SyncToolCall syncToolCall;
53+
@Mock
54+
private ChatModel chatModel;
55+
@Mock
56+
private McpClientFactory mcpClientFactory;
57+
58+
@Test
59+
void shouldGetResultWhenRunFlowGivenNoToolCall() {
60+
WaterFlowAgent waterFlowAgent = new WaterFlowAgent(this.syncToolCall, this.chatModel, this.mcpClientFactory);
61+
62+
String expectResult = "0123";
63+
doAnswer(invocation -> Choir.create(emitter -> {
64+
for (int i = 0; i < 4; i++) {
65+
emitter.emit(new AiMessage(String.valueOf(i)));
66+
}
67+
emitter.complete();
68+
})).when(chatModel).generate(any(), any());
69+
70+
AiProcessFlow<Prompt, ChatMessage> flow = waterFlowAgent.buildFlow();
71+
ChatMessage result = flow.converse()
72+
.bind(ChatOption.custom().build())
73+
.offer(ChatMessages.from(new HumanMessage("hi"))).await();
74+
75+
assertEquals(expectResult, result.text());
76+
}
77+
78+
@Test
79+
void shouldGetResultWhenRunFlowGivenStoreToolCall() {
80+
WaterFlowAgent waterFlowAgent = new WaterFlowAgent(this.syncToolCall, this.chatModel, this.mcpClientFactory);
81+
82+
String expectResult = "tool result:0123";
83+
String realName = "realName";
84+
ToolInfo toolInfo = buildToolInfo(realName);
85+
ToolCall toolCall = ToolCall.custom().id("id").name(toolInfo.name()).arguments("{}").build();
86+
List<ToolCall> toolCalls = Collections.singletonList(toolCall);
87+
doAnswer(invocation -> {
88+
Prompt prompt = invocation.getArgument(0);
89+
return mockGenerateResult(toolCalls, prompt);
90+
}).when(chatModel).generate(any(), any());
91+
Map<String, Object> toolContext = MapBuilder.<String, Object>get().put("key", "value").build();
92+
when(this.syncToolCall.call(realName, toolCall.arguments(), toolContext)).thenReturn("tool result:");
93+
94+
AiProcessFlow<Prompt, ChatMessage> flow = waterFlowAgent.buildFlow();
95+
ChatMessage result = flow.converse()
96+
.bind(ChatOption.custom().build())
97+
.bind(AippConst.TOOL_CONTEXT_KEY, toolContext)
98+
.bind(AippConst.TOOLS_KEY, Collections.singletonList(toolInfo))
99+
.offer(ChatMessages.from(new HumanMessage("hi"))).await();
100+
101+
verify(this.mcpClientFactory, times(0)).create(any(), any());
102+
assertEquals(expectResult, result.text());
103+
}
104+
105+
@Test
106+
void shouldGetResultWhenRunFlowGivenMcpToolCall() {
107+
WaterFlowAgent waterFlowAgent = new WaterFlowAgent(this.syncToolCall, this.chatModel, this.mcpClientFactory);
108+
109+
String expectResult = "\"tool result:\"0123";
110+
String realName = "realName";
111+
String baseUrl = "http://localhost";
112+
String sseEndpoint = "/sse";
113+
ToolInfo toolInfo = buildMcpToolInfo(realName, baseUrl, sseEndpoint);
114+
ToolCall toolCall = ToolCall.custom().id("id").name(toolInfo.name()).arguments("{}").build();
115+
List<ToolCall> toolCalls = Collections.singletonList(toolCall);
116+
doAnswer(invocation -> {
117+
Prompt prompt = invocation.getArgument(0);
118+
return mockGenerateResult(toolCalls, prompt);
119+
}).when(chatModel).generate(any(), any());
120+
Map<String, Object> toolContext = MapBuilder.<String, Object>get().put("key", "value").build();
121+
McpClient mcpClient = mock(McpClient.class);
122+
when(this.mcpClientFactory.create(baseUrl, sseEndpoint)).thenReturn(mcpClient);
123+
when(mcpClient.callTool(realName, new HashMap<>())).thenReturn("tool result:");
124+
125+
AiProcessFlow<Prompt, ChatMessage> flow = waterFlowAgent.buildFlow();
126+
ChatMessage result = flow.converse()
127+
.bind(ChatOption.custom().build())
128+
.bind(AippConst.TOOL_CONTEXT_KEY, toolContext)
129+
.bind(AippConst.TOOLS_KEY, Collections.singletonList(toolInfo))
130+
.offer(ChatMessages.from(new HumanMessage("hi"))).await();
131+
132+
verify(this.syncToolCall, times(0)).call(any(), any(), any());
133+
assertEquals(expectResult, result.text());
134+
}
135+
136+
private static Choir<Object> mockGenerateResult(List<ToolCall> toolCalls, Prompt prompt) {
137+
AtomicInteger step = new AtomicInteger();
138+
return Choir.create(emitter -> {
139+
if (step.getAndIncrement() == 0) {
140+
emitter.emit(new AiMessage("tool_data", toolCalls));
141+
emitter.complete();
142+
return;
143+
}
144+
if (CollectionUtils.isNotEmpty(prompt.messages())) {
145+
emitter.emit(new AiMessage(prompt.messages().get(prompt.messages().size() - 1).text()));
146+
}
147+
for (int i = 0; i < 4; i++) {
148+
emitter.emit(new AiMessage(String.valueOf(i)));
149+
}
150+
emitter.complete();
151+
});
152+
}
153+
154+
private static ToolInfo buildToolInfo(String realName) {
155+
return ToolInfo.custom()
156+
.name("tool1")
157+
.description("desc")
158+
.parameters(new HashMap<>())
159+
.extensions(MapBuilder.<String, Object>get().put(AippConst.TOOL_REAL_NAME, realName).build())
160+
.build();
161+
}
162+
163+
private static ToolInfo buildMcpToolInfo(String realName, String baseUrl, String sseEndpoint) {
164+
return ToolInfo.custom()
165+
.name("tool1")
166+
.description("desc")
167+
.parameters(new HashMap<>())
168+
.extensions(MapBuilder.<String, Object>get()
169+
.put(AippConst.TOOL_REAL_NAME, realName)
170+
.put(AippConst.MCP_SERVER_KEY,
171+
MapBuilder.get().put(AippConst.MCP_SERVER_URL_KEY, baseUrl + sseEndpoint).build())
172+
.build())
173+
.build();
174+
}
175+
}

app-builder/jane/plugins/aipp-plugin/src/test/java/modelengine/fit/jober/aipp/fitable/LlmComponentTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ protected AiProcessFlow<Prompt, ChatMessage> buildFlow() {
166166
}
167167

168168
private AbstractAgent getWaterFlowAgent(ChatModel model) {
169-
return new WaterFlowAgent(this.syncToolCall, model);
169+
return new WaterFlowAgent(this.syncToolCall, model, this.mcpClientFactory);
170170
}
171171

172172
private ChatModel buildChatStreamModel(String exceptionMsg) {

0 commit comments

Comments
 (0)