1+ /*---------------------------------------------------------------------------------------------
2+ * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3+ * This file is a part of the ModelEngine Project.
4+ * Licensed under the MIT License. See License.txt in the project root for license information.
5+ *--------------------------------------------------------------------------------------------*/
6+
7+ package modelengine .fit .jober .aipp .fel ;
8+
9+ import modelengine .fel .core .chat .ChatMessage ;
10+ import modelengine .fel .core .chat .ChatModel ;
11+ import modelengine .fel .core .chat .ChatOption ;
12+ import modelengine .fel .core .chat .Prompt ;
13+ import modelengine .fel .core .chat .support .AiMessage ;
14+ import modelengine .fel .core .chat .support .ChatMessages ;
15+ import modelengine .fel .core .chat .support .HumanMessage ;
16+ import modelengine .fel .core .tool .ToolCall ;
17+ import modelengine .fel .core .tool .ToolInfo ;
18+ import modelengine .fel .engine .flows .AiProcessFlow ;
19+ import modelengine .fel .tool .mcp .client .McpClient ;
20+ import modelengine .fel .tool .mcp .client .McpClientFactory ;
21+ import modelengine .fit .jade .tool .SyncToolCall ;
22+ import modelengine .fit .jober .aipp .constants .AippConst ;
23+ import modelengine .fitframework .flowable .Choir ;
24+ import modelengine .fitframework .util .MapBuilder ;
25+
26+ import org .apache .commons .collections .CollectionUtils ;
27+ import org .junit .jupiter .api .Test ;
28+ import org .junit .jupiter .api .extension .ExtendWith ;
29+ import org .mockito .Mock ;
30+ import org .mockito .junit .jupiter .MockitoExtension ;
31+
32+ import java .util .Collections ;
33+ import java .util .HashMap ;
34+ import java .util .List ;
35+ import java .util .Map ;
36+ import java .util .concurrent .atomic .AtomicInteger ;
37+
38+ import static org .junit .jupiter .api .Assertions .*;
39+ import static org .mockito .ArgumentMatchers .any ;
40+ import static org .mockito .Mockito .doAnswer ;
41+ import static org .mockito .Mockito .mock ;
42+ import static org .mockito .Mockito .times ;
43+ import static org .mockito .Mockito .verify ;
44+ import static org .mockito .Mockito .when ;
45+
46+ /**
47+ * {@link WaterFlowAgent} 的测试。
48+ */
49+ @ ExtendWith (MockitoExtension .class )
50+ class WaterFlowAgentTest {
51+ @ Mock
52+ private SyncToolCall syncToolCall ;
53+ @ Mock
54+ private ChatModel chatModel ;
55+ @ Mock
56+ private McpClientFactory mcpClientFactory ;
57+
58+ @ Test
59+ void shouldGetResultWhenRunFlowGivenNoToolCall () {
60+ WaterFlowAgent waterFlowAgent = new WaterFlowAgent (this .syncToolCall , this .chatModel , this .mcpClientFactory );
61+
62+ String expectResult = "0123" ;
63+ doAnswer (invocation -> Choir .create (emitter -> {
64+ for (int i = 0 ; i < 4 ; i ++) {
65+ emitter .emit (new AiMessage (String .valueOf (i )));
66+ }
67+ emitter .complete ();
68+ })).when (chatModel ).generate (any (), any ());
69+
70+ AiProcessFlow <Prompt , ChatMessage > flow = waterFlowAgent .buildFlow ();
71+ ChatMessage result = flow .converse ()
72+ .bind (ChatOption .custom ().build ())
73+ .offer (ChatMessages .from (new HumanMessage ("hi" ))).await ();
74+
75+ assertEquals (expectResult , result .text ());
76+ }
77+
78+ @ Test
79+ void shouldGetResultWhenRunFlowGivenStoreToolCall () {
80+ WaterFlowAgent waterFlowAgent = new WaterFlowAgent (this .syncToolCall , this .chatModel , this .mcpClientFactory );
81+
82+ String expectResult = "tool result:0123" ;
83+ String realName = "realName" ;
84+ ToolInfo toolInfo = buildToolInfo (realName );
85+ ToolCall toolCall = ToolCall .custom ().id ("id" ).name (toolInfo .name ()).arguments ("{}" ).build ();
86+ List <ToolCall > toolCalls = Collections .singletonList (toolCall );
87+ doAnswer (invocation -> {
88+ Prompt prompt = invocation .getArgument (0 );
89+ return mockGenerateResult (toolCalls , prompt );
90+ }).when (chatModel ).generate (any (), any ());
91+ Map <String , Object > toolContext = MapBuilder .<String , Object >get ().put ("key" , "value" ).build ();
92+ when (this .syncToolCall .call (realName , toolCall .arguments (), toolContext )).thenReturn ("tool result:" );
93+
94+ AiProcessFlow <Prompt , ChatMessage > flow = waterFlowAgent .buildFlow ();
95+ ChatMessage result = flow .converse ()
96+ .bind (ChatOption .custom ().build ())
97+ .bind (AippConst .TOOL_CONTEXT_KEY , toolContext )
98+ .bind (AippConst .TOOLS_KEY , Collections .singletonList (toolInfo ))
99+ .offer (ChatMessages .from (new HumanMessage ("hi" ))).await ();
100+
101+ verify (this .mcpClientFactory , times (0 )).create (any (), any ());
102+ assertEquals (expectResult , result .text ());
103+ }
104+
105+ @ Test
106+ void shouldGetResultWhenRunFlowGivenMcpToolCall () {
107+ WaterFlowAgent waterFlowAgent = new WaterFlowAgent (this .syncToolCall , this .chatModel , this .mcpClientFactory );
108+
109+ String expectResult = "\" tool result:\" 0123" ;
110+ String realName = "realName" ;
111+ String baseUrl = "http://localhost" ;
112+ String sseEndpoint = "/sse" ;
113+ ToolInfo toolInfo = buildMcpToolInfo (realName , baseUrl , sseEndpoint );
114+ ToolCall toolCall = ToolCall .custom ().id ("id" ).name (toolInfo .name ()).arguments ("{}" ).build ();
115+ List <ToolCall > toolCalls = Collections .singletonList (toolCall );
116+ doAnswer (invocation -> {
117+ Prompt prompt = invocation .getArgument (0 );
118+ return mockGenerateResult (toolCalls , prompt );
119+ }).when (chatModel ).generate (any (), any ());
120+ Map <String , Object > toolContext = MapBuilder .<String , Object >get ().put ("key" , "value" ).build ();
121+ McpClient mcpClient = mock (McpClient .class );
122+ when (this .mcpClientFactory .create (baseUrl , sseEndpoint )).thenReturn (mcpClient );
123+ when (mcpClient .callTool (realName , new HashMap <>())).thenReturn ("tool result:" );
124+
125+ AiProcessFlow <Prompt , ChatMessage > flow = waterFlowAgent .buildFlow ();
126+ ChatMessage result = flow .converse ()
127+ .bind (ChatOption .custom ().build ())
128+ .bind (AippConst .TOOL_CONTEXT_KEY , toolContext )
129+ .bind (AippConst .TOOLS_KEY , Collections .singletonList (toolInfo ))
130+ .offer (ChatMessages .from (new HumanMessage ("hi" ))).await ();
131+
132+ verify (this .syncToolCall , times (0 )).call (any (), any (), any ());
133+ assertEquals (expectResult , result .text ());
134+ }
135+
136+ private static Choir <Object > mockGenerateResult (List <ToolCall > toolCalls , Prompt prompt ) {
137+ AtomicInteger step = new AtomicInteger ();
138+ return Choir .create (emitter -> {
139+ if (step .getAndIncrement () == 0 ) {
140+ emitter .emit (new AiMessage ("tool_data" , toolCalls ));
141+ emitter .complete ();
142+ return ;
143+ }
144+ if (CollectionUtils .isNotEmpty (prompt .messages ())) {
145+ emitter .emit (new AiMessage (prompt .messages ().get (prompt .messages ().size () - 1 ).text ()));
146+ }
147+ for (int i = 0 ; i < 4 ; i ++) {
148+ emitter .emit (new AiMessage (String .valueOf (i )));
149+ }
150+ emitter .complete ();
151+ });
152+ }
153+
154+ private static ToolInfo buildToolInfo (String realName ) {
155+ return ToolInfo .custom ()
156+ .name ("tool1" )
157+ .description ("desc" )
158+ .parameters (new HashMap <>())
159+ .extensions (MapBuilder .<String , Object >get ().put (AippConst .TOOL_REAL_NAME , realName ).build ())
160+ .build ();
161+ }
162+
163+ private static ToolInfo buildMcpToolInfo (String realName , String baseUrl , String sseEndpoint ) {
164+ return ToolInfo .custom ()
165+ .name ("tool1" )
166+ .description ("desc" )
167+ .parameters (new HashMap <>())
168+ .extensions (MapBuilder .<String , Object >get ()
169+ .put (AippConst .TOOL_REAL_NAME , realName )
170+ .put (AippConst .MCP_SERVER_KEY ,
171+ MapBuilder .get ().put (AippConst .MCP_SERVER_URL_KEY , baseUrl + sseEndpoint ).build ())
172+ .build ())
173+ .build ();
174+ }
175+ }
0 commit comments