Skip to content

Commit 02a23ab

Browse files
committed
feat: support MCP tool calls
1 parent 47953df commit 02a23ab

File tree

6 files changed

+90
-0
lines changed

6 files changed

+90
-0
lines changed

src/main/java/com/epam/aidial/deployment/manager/service/McpService.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
import io.modelcontextprotocol.spec.McpSchema;
1111
import lombok.RequiredArgsConstructor;
1212
import org.apache.commons.lang3.StringUtils;
13+
import org.apache.commons.lang3.exception.ExceptionUtils;
1314
import org.springframework.stereotype.Service;
15+
import org.springframework.transaction.annotation.Transactional;
1416

1517
import java.util.function.BiFunction;
1618

@@ -48,6 +50,21 @@ private <T> T get(String deploymentId, String nextCursor, BiFunction<McpSyncClie
4850
}
4951
}
5052

53+
@Transactional(readOnly = true)
54+
public McpSchema.CallToolResult callTool(String deploymentId, McpSchema.CallToolRequest callToolRequest) {
55+
var deployment = getDeployment(deploymentId);
56+
var endpointPath = mcpEndpointPathResolver.resolveEndpointPath(deployment);
57+
58+
try (var mcpClient = mcpClientFactory.create(deployment.getUrl(), endpointPath, deployment.getTransport())) {
59+
mcpClient.initialize();
60+
return mcpClient.callTool(callToolRequest);
61+
} catch (Exception e) {
62+
String reason = ExceptionUtils.getRootCause(e).getMessage();
63+
throw new McpClientException(("Failed to call a tool via MCP server. Deployment id: %s. Transport: '%s'. Path: '%s'. Reason: %s")
64+
.formatted(deploymentId, deployment.getTransport(), deployment.getMcpEndpointPath(), reason), e);
65+
}
66+
}
67+
5168
private McpDeployment getDeployment(String deploymentId) {
5269
var deployment = deploymentService.getDeployment(deploymentId)
5370
.orElseThrow(() -> new EntityNotFoundException("Deployment not found by id: %s"

src/main/java/com/epam/aidial/deployment/manager/web/controller/McpController.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
import com.epam.aidial.deployment.manager.service.McpService;
44
import io.modelcontextprotocol.spec.McpSchema;
55
import lombok.RequiredArgsConstructor;
6+
import org.springframework.http.MediaType;
67
import org.springframework.util.MimeTypeUtils;
78
import org.springframework.web.bind.annotation.GetMapping;
89
import org.springframework.web.bind.annotation.PathVariable;
10+
import org.springframework.web.bind.annotation.PostMapping;
11+
import org.springframework.web.bind.annotation.RequestBody;
912
import org.springframework.web.bind.annotation.RequestMapping;
1013
import org.springframework.web.bind.annotation.RequestParam;
1114
import org.springframework.web.bind.annotation.RestController;
@@ -24,6 +27,12 @@ public McpSchema.ListToolsResult getTools(@PathVariable String deploymentId,
2427
return mcpService.getTools(deploymentId, nextCursor);
2528
}
2629

30+
@PostMapping(path = "/{deploymentId}/call-tool", produces = MediaType.APPLICATION_JSON_VALUE)
31+
public McpSchema.CallToolResult callTool(@PathVariable String deploymentId,
32+
@RequestBody McpSchema.CallToolRequest callToolRequest) {
33+
return mcpService.callTool(deploymentId, callToolRequest);
34+
}
35+
2736
@GetMapping(path = "/{deploymentId}/resources",
2837
produces = MimeTypeUtils.APPLICATION_JSON_VALUE)
2938
public McpSchema.ListResourcesResult getResources(@PathVariable String deploymentId,

src/test/java/com/epam/aidial/deployment/manager/service/McpServiceTest.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,29 @@ void testGetTools_withNullCursor() {
205205
verify(mcpSyncClient).listTools(null);
206206
}
207207

208+
@Test
209+
void testCallTool() {
210+
// Given
211+
var deployment = createDeployment();
212+
String endpointPath = "/sseCustom";
213+
deployment.setMcpEndpointPath(endpointPath);
214+
deployment.setTransport(McpTransport.SSE);
215+
216+
when(deploymentService.getDeployment(DEPLOYMENT_ID)).thenReturn(Optional.of(deployment));
217+
when(mcpClientFactory.create(DEPLOYMENT_URL, endpointPath, McpTransport.SSE)).thenReturn(mcpSyncClient);
218+
219+
var callToolRequest = Mockito.mock(McpSchema.CallToolRequest.class);
220+
var expectedCallToolResult = Mockito.mock(McpSchema.CallToolResult.class);
221+
when(mcpSyncClient.callTool(callToolRequest)).thenReturn(expectedCallToolResult);
222+
223+
// When
224+
var result = mcpService.callTool(DEPLOYMENT_ID, callToolRequest);
225+
226+
// Then
227+
assertThat(result).isEqualTo(expectedCallToolResult);
228+
verify(mcpSyncClient).callTool(callToolRequest);
229+
}
230+
208231
private McpDeployment createDeployment() {
209232
var deployment = new McpDeployment();
210233
deployment.setId(DEPLOYMENT_ID);

src/test/java/com/epam/aidial/deployment/manager/web/controller/none/McpControllerTest.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
import com.epam.aidial.deployment.manager.service.McpService;
55
import com.epam.aidial.deployment.manager.utils.ResourceUtils;
66
import com.epam.aidial.deployment.manager.web.controller.McpController;
7+
import com.fasterxml.jackson.core.type.TypeReference;
78
import com.fasterxml.jackson.databind.ObjectMapper;
89
import io.modelcontextprotocol.spec.McpSchema;
910
import org.junit.jupiter.api.Test;
1011
import org.springframework.beans.factory.annotation.Autowired;
1112
import org.springframework.boot.test.autoconfigure.web.servlet.WebMvcTest;
1213
import org.springframework.context.annotation.Import;
14+
import org.springframework.http.HttpHeaders;
15+
import org.springframework.http.MediaType;
1316
import org.springframework.test.context.bean.override.mockito.MockitoBean;
1417
import org.springframework.test.json.JsonCompareMode;
1518

@@ -18,6 +21,7 @@
1821
import static org.mockito.Mockito.verify;
1922
import static org.mockito.Mockito.when;
2023
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
24+
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
2125
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
2226
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
2327

@@ -27,6 +31,9 @@
2731
})
2832
class McpControllerTest extends AbstractControllerNoneSecureTest {
2933

34+
private static final String CALL_TOOL_REQUEST_DTO_JSON_PATH = "/mcp/call_tool_request_dto.json";
35+
private static final String CALL_TOOL_RESULT_DTO_JSON_PATH = "/mcp/call_tool_result_dto.json";
36+
3037
private static final String DEPLOYMENT_ID = String.valueOf(UUID.randomUUID());
3138
private static final String NEXT_CURSOR = "some-cursor";
3239

@@ -136,4 +143,23 @@ void testGetPromptsWithCursor() throws Exception {
136143

137144
verify(mcpService).getPrompts(DEPLOYMENT_ID, NEXT_CURSOR);
138145
}
146+
147+
@Test
148+
void testCallTool() throws Exception {
149+
var callToolRequestDtoJson = ResourceUtils.readResource(CALL_TOOL_REQUEST_DTO_JSON_PATH);
150+
var callToolRequestDto = objectMapper.readValue(callToolRequestDtoJson, new TypeReference<McpSchema.CallToolRequest>() {
151+
});
152+
153+
var callToolResultDtoJson = ResourceUtils.readResource(CALL_TOOL_RESULT_DTO_JSON_PATH);
154+
var callToolResultDto = objectMapper.readValue(callToolResultDtoJson, new TypeReference<McpSchema.CallToolResult>() {
155+
});
156+
157+
when(mcpService.callTool(DEPLOYMENT_ID, callToolRequestDto)).thenReturn(callToolResultDto);
158+
159+
mockMvc.perform(post("/api/v1/deployments/mcp/{deploymentId}/call-tool", DEPLOYMENT_ID)
160+
.header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
161+
.content(callToolRequestDtoJson))
162+
.andExpect(status().isOk())
163+
.andExpect(content().json(callToolResultDtoJson, JsonCompareMode.LENIENT));
164+
}
139165
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"name": "get_simple_price",
3+
"arguments": {
4+
"vs_currencies": "usd",
5+
"ids": "bitcoin"
6+
}
7+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"content": [
3+
{
4+
"type": "text",
5+
"text": "{\n \"bitcoin\": {\n \"usd\": 93575\n }\n}"
6+
}
7+
]
8+
}

0 commit comments

Comments
 (0)