diff --git a/pom.xml b/pom.xml index 0145c055ee..98da18ba0e 100644 --- a/pom.xml +++ b/pom.xml @@ -379,6 +379,30 @@ 5.0.1 + + org.springframework.ai + spring-ai-mcp + 1.0.1 + + + + org.springframework.ai + spring-ai-starter-mcp-server-webmvc + 1.0.1 + + + + + + + org.springframework + spring-test + + org.junit.jupiter diff --git a/src/main/java/ca/uhn/fhir/jpa/starter/mcp/CallToolResultFactory.java b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/CallToolResultFactory.java new file mode 100644 index 0000000000..a12672a405 --- /dev/null +++ b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/CallToolResultFactory.java @@ -0,0 +1,46 @@ +package ca.uhn.fhir.jpa.starter.mcp; + +import ca.uhn.fhir.context.FhirContext; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import org.hl7.fhir.instance.model.api.IBaseResource; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +import java.util.Map; + +@Component +public class CallToolResultFactory { + + @Autowired + private FhirContext fhirContext; + + public McpSchema.CallToolResult success( + String resourceType, Interaction interaction, Object response, int status) { + Map payload = Map.of( + "resourceType", resourceType, + "interaction", interaction, + "response", fhirContext.newJsonParser().encodeResourceToString((IBaseResource) response), + "status", status); + + ObjectMapper objectMapper = new ObjectMapper(); + String jacksonData = ""; + try { + jacksonData = objectMapper.writeValueAsString(payload); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + + return McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent(jacksonData)) + .build(); + } + + public McpSchema.CallToolResult failure(String message) { + return McpSchema.CallToolResult.builder() + .isError(true) + .addTextContent(message) + .build(); + } +} diff --git a/src/main/java/ca/uhn/fhir/jpa/starter/mcp/Interaction.java b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/Interaction.java new file mode 100644 index 0000000000..0f8d748a6d --- /dev/null +++ b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/Interaction.java @@ -0,0 +1,35 @@ +package ca.uhn.fhir.jpa.starter.mcp; + +import ca.uhn.fhir.rest.api.RequestTypeEnum; + +public enum Interaction { + CALL_CDS_HOOK("call-cds-hook"), + SEARCH("search"), + READ("read"), + CREATE("create"), + UPDATE("update"), + DELETE("delete"), + PATCH("patch"), + TRANSACTION("transaction"); + + private final String name; + + Interaction(String name) { + this.name = name; + } + + public String getName() { + return name; + } + + public RequestTypeEnum asRequestType() { + return switch (this) { + case SEARCH, READ -> RequestTypeEnum.GET; + case CREATE, TRANSACTION -> RequestTypeEnum.POST; + case UPDATE -> RequestTypeEnum.PUT; + case DELETE -> RequestTypeEnum.DELETE; + case PATCH -> RequestTypeEnum.PATCH; + case CALL_CDS_HOOK -> RequestTypeEnum.POST; + }; + } +} diff --git a/src/main/java/ca/uhn/fhir/jpa/starter/mcp/McpServerConfig.java b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/McpServerConfig.java new file mode 100644 index 0000000000..1867c13329 --- /dev/null +++ b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/McpServerConfig.java @@ -0,0 +1,47 @@ +package ca.uhn.fhir.jpa.starter.mcp; + +import ca.uhn.fhir.rest.server.MCPBridge; +import ca.uhn.fhir.rest.server.RestfulServer; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import org.springframework.boot.web.servlet.ServletRegistrationBean; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; + +// https://mcp-cn.ssshooter.com/sdk/java/mcp-server#sse-servlet +// https://www.baeldung.com/spring-ai-model-context-protocol-mcp +// https://github.com/spring-projects/spring-ai-examples/blob/main/model-context-protocol/weather/manual-webflux-server/src/main/java/org/springframework/ai/mcp/sample/server/McpServerConfig.java +// https://github.com/spring-projects/spring-ai-examples/tree/main/model-context-protocol/weather/starter-stdio-server/src/main/java/org/springframework/ai/mcp/sample/server +// https://github.com/spring-projects/spring-ai-examples/blob/main/model-context-protocol/sampling/mcp-weather-webmvc-server/src/main/java/org/springframework/ai/mcp/sample/server/WeatherService.java + +@Configuration +@EnableWebMvc +public class McpServerConfig implements WebMvcConfigurer { + + @Bean + public MCPBridge mcpBridge(RestfulServer restfulServer, CallToolResultFactory callToolResultFactory) { + return new MCPBridge(restfulServer, callToolResultFactory); + } + + @Bean + public McpSyncServer syncServer(McpSyncServer mcpSyncServer, MCPBridge mcpBridge) { + + mcpBridge.generateTools().stream().forEach(mcpSyncServer::addTool); + return mcpSyncServer; + } + + @Bean + public HttpServletSseServerTransportProvider servletSseServerTransportProvider() { + return new HttpServletSseServerTransportProvider(new ObjectMapper(), "/mcp/message"); + } + + @Bean + public ServletRegistrationBean customServletBean(HttpServletSseServerTransportProvider transportProvider) { + var servetRegistrationBean = new ServletRegistrationBean<>(transportProvider, "/mcp/message", "/sse"); + return servetRegistrationBean; + // return new ServletRegistrationBean(transportProvider); + } +} diff --git a/src/main/java/ca/uhn/fhir/jpa/starter/mcp/RequestBuilder.java b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/RequestBuilder.java new file mode 100644 index 0000000000..a01924cd5b --- /dev/null +++ b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/RequestBuilder.java @@ -0,0 +1,117 @@ +package ca.uhn.fhir.jpa.starter.mcp; + +import ca.uhn.fhir.context.FhirContext; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.gson.Gson; +import org.hl7.fhir.instance.model.api.IBaseResource; +import org.springframework.mock.web.MockHttpServletRequest; + +import java.nio.charset.StandardCharsets; +import java.util.Map; + +public class RequestBuilder { + + private final FhirContext fhirContext; + private final String resourceType; + private final Interaction interaction; + private final Map config; + private final ObjectMapper mapper = new ObjectMapper(); + private final String headers; + private String resource; + + public RequestBuilder(FhirContext fhirContext, Map contextMap, Interaction interaction) { + this.config = contextMap; + if (interaction == Interaction.TRANSACTION) this.resourceType = ""; + else if (contextMap.get("resourceType") instanceof String rt && !rt.isBlank()) this.resourceType = rt; + else throw new IllegalArgumentException("Missing or invalid 'resourceType' in contextMap"); + // this.resourceType = contextMap.get("resourceType") instanceof String rt ? rt : null; + this.headers = contextMap.get("headers") instanceof String h ? h : null; + this.resource = null; + try { + resource = mapper.writeValueAsString(contextMap.get("resource")); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + this.interaction = interaction; + this.fhirContext = fhirContext; + } + + public MockHttpServletRequest buildRequest() { + String basePath = "/" + resourceType; + String method; + MockHttpServletRequest req; + + switch (interaction) { + case SEARCH -> { + method = "GET"; + req = new MockHttpServletRequest(method, basePath); + if (config.get("searchParams") instanceof Map sp) { + sp.forEach((k, v) -> req.addParameter(k.toString(), v.toString())); + } + } + case READ -> { + method = "GET"; + String id = requireString("id"); + req = new MockHttpServletRequest(method, basePath + "/" + id); + } + case CREATE, TRANSACTION -> { + method = "POST"; + req = new MockHttpServletRequest(method, basePath); + applyResourceBody(req, "resource"); + } + case UPDATE -> { + method = "PUT"; + String id = requireString("id"); + req = new MockHttpServletRequest(method, basePath + "/" + id); + applyResourceBody(req, "resource"); + } + case DELETE -> { + method = "DELETE"; + String id = requireString("id"); + req = new MockHttpServletRequest(method, basePath + "/" + id); + } + case PATCH -> { + method = "PATCH"; + String id = requireString("id"); + req = new MockHttpServletRequest(method, basePath + "/" + id); + applyPatchBody(req); + } + default -> throw new IllegalArgumentException("Unsupported interaction: " + interaction); + } + + req.setContentType("application/fhir+json"); + req.addHeader("Accept", "application/fhir+json"); + return req; + } + + private void applyResourceBody(MockHttpServletRequest req, String key) { + Object resourceObj = config.get(key); + String json = new Gson().toJson(resourceObj, Map.class); + req.setContent(json.getBytes(StandardCharsets.UTF_8)); + } + + private void applyPatchBody(MockHttpServletRequest req) { + Object patchBody = config.get("patch"); + if (patchBody == null) { + throw new IllegalArgumentException("Missing 'patch' for patch interaction"); + } + String content; + if (patchBody instanceof String s) { + content = s; + } else if (patchBody instanceof IBaseResource r) { + content = fhirContext.newJsonParser().encodeResourceToString(r); + } else { + throw new IllegalArgumentException("Unsupported patch body type: " + patchBody.getClass()); + } + req.setContent(content.getBytes(StandardCharsets.UTF_8)); + } + + private String requireString(String key) { + Object val = config.get(key); + if (!(val instanceof String s) || s.isBlank()) { + throw new IllegalArgumentException("Missing or invalid '" + key + "'"); + } + return (String) val; + } +} diff --git a/src/main/java/ca/uhn/fhir/jpa/starter/mcp/ToolFactory.java b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/ToolFactory.java new file mode 100644 index 0000000000..693108deb2 --- /dev/null +++ b/src/main/java/ca/uhn/fhir/jpa/starter/mcp/ToolFactory.java @@ -0,0 +1,327 @@ +package ca.uhn.fhir.jpa.starter.mcp; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.Tool; + +public class ToolFactory { + + private static final String READ_FHIR_RESOURCE_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "type of the resource to read" + }, + "id": { + "type": "string", + "description": "id of the resource to read" + } + } + + } + """; + + private static final String CREATE_FHIR_RESOURCE_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "Type of the resource to create" + }, + "resource": { + "type": "object", + "description": "Resource content in JSON format" + }, + "headers": { + "type": "object", + "description": "Headers for create request.\\nAvailable headers: If-None-Exist header for conditional create where the value is search param string.\\nFor example: {\\"If-None-Exist\\": \\"active=false\\"}" + } + }, + "required": ["resourceType", "resource"] + } + """; + + private static final String UPDATE_FHIR_RESOURCE_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "Type of the resource to update" + }, + "id": { + "type": "string", + "description": "ID of the resource to update" + }, + "resource": { + "type": "object", + "description": "Updated resource content in JSON format" + } + }, + "required": ["resourceType", "id", "resource"] + } + """; + + private static final String CONDITIONAL_UPDATE_FHIR_RESOURCE_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "Type of the resource to update" + }, + "resource": { + "type": "object", + "description": "Updated resource content in JSON format" + }, + "query": { + "type": "string", + "description": "Query string with search params separate by \\",\\". For example: \\"_id=pt-1,name=ivan\\". Uses for conditional update." + }, + "headers": { + "type": "object", + "description": "Headers for create request.\\nAvailable headers: If-None-Match header for conditional update where the value is ETag.\\nFor example: {\\"If-None-Match\\": \\"12345\\"}" + } + }, + "required": ["resourceType", "resource"] + } + """; + + private static final String CONDITIONAL_PATCH_FHIR_RESOURCE_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "Type of the resource to patch" + }, + "resource": { + "type": "object", + "description": "Resource content to patch in JSON format" + }, + "query": { + "type": "string", + "description": "Query string with search params separate by \\",\\". For example: \\"_id=pt-1,name=ivan\\". Uses for conditional patch." + }, + "headers": { + "type": "object", + "description": "Headers for create request.\\nAvailable headers: If-None-Match header for conditional patch where the value is ETag.\\nFor example: {\\"If-None-Match\\": \\"12345\\"}" + } + }, + "required": ["resourceType", "resource"] + } + """; + + private static final String PATCH_FHIR_RESOURCE_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "Type of the resource to patch" + }, + "id": { + "type": "string", + "description": "ID of the resource to patch" + }, + "resource": { + "type": "object", + "description": "Resource content to patch in JSON format" + } + }, + "required": ["resourceType", "id", "resource"] + } + """; + + private static final String DELETE_FHIR_RESOURCE_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "Type of the resource to delete" + }, + "id": { + "type": "string", + "description": "ID of the resource to delete" + } + }, + "required": ["resourceType", "id"] + } + """; + + private static final String SEARCH_FHIR_RESOURCES_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "Type of the resource to search" + }, + "query": { + "type": "string", + "description": "Query string with search params separate by \\",\\". For example: \\"_id=pt-1,name=ivan\\"" + } + }, + "required": ["resourceType", "query"] + } + """; + + private static final String CREATE_FHIR_TRANSACTION_SCHEMA = + """ + { + "type": "object", + "properties": { + "resourceType": { + "type": "string", + "description": "A Bundle resource type with type 'transaction' containing multiple FHIR resources" + }, + "resource": { + "type": "object", + "description": "A FHIR Bundle Resource content in JSON format" + } + }, + "required": ["resourceType", "resource"] + } + """; + + // TODO Add a tool for the CDS Hooks discovery endpoint + // Alternatively, should each service be a separate tool? + + // TODO Add other fields from https://cds-hooks.hl7.org/STU2/#http-request-1 + // TODO Context here is for the patient-view hook, https://cds-hooks.hl7.org/hooks/STU1/patient-view.html#context + private static final String CALL_CDS_HOOK_SCHEMA_2_0_1 = + """ + { + "type": "object", + "properties": { + "service": { + "type": "string", + "description": "The CDS Service to call." + }, + "hook": { + "type": "string", + "description": "The hook that triggered this CDS Service call." + }, + "hookInstance": { + "type": "string", + "description": "A universally unique identifier (UUID) for this particular hook call." + }, + "hookContext": { + "type": "object", + "description": "Hook-specific contextual data that the CDS service will need.", + "properties": { + "userId": { + "type": "string", + "description": "The id of the current user. Must be in the format [ResourceType]/[id]." + }, + "patientId": { + "type": "string", + "description": "The FHIR Patient.id of the current patient in context" + }, + "encounterId": { + "type": "string", + "description": "The FHIR Encounter.id of the current encounter in context." + } + } + }, + "prefetch": { + "type": "object", + "description": "Additional data to prefetch for the CDS service call." + } + }, + "required": ["service", "hook", "hookInstance", "context"] + } + """; + + public static Tool readFhirResource() throws JsonProcessingException { + return new Tool( + "read-fhir-resource", + "Read an individual FHIR resource", + mapper.readValue(READ_FHIR_RESOURCE_SCHEMA, McpSchema.JsonSchema.class)); + } + + public static Tool createFhirResource() throws JsonProcessingException { + return new Tool( + "create-fhir-resource", + "Create a new FHIR resource", + mapper.readValue(CREATE_FHIR_RESOURCE_SCHEMA, McpSchema.JsonSchema.class)); + } + + public static Tool updateFhirResource() throws JsonProcessingException { + return new Tool( + "update-fhir-resource", + "Update an existing FHIR resource", + mapper.readValue(UPDATE_FHIR_RESOURCE_SCHEMA, McpSchema.JsonSchema.class)); + } + + public static Tool conditionalUpdateFhirResource() throws JsonProcessingException { + return new Tool( + "conditional-update-fhir-resource", + "Conditional update an existing FHIR resource", + mapper.readValue(CONDITIONAL_UPDATE_FHIR_RESOURCE_SCHEMA, McpSchema.JsonSchema.class)); + } + + public static Tool conditionalPatchFhirResource() throws JsonProcessingException { + return new Tool( + "conditional-patch-fhir-resource", + "Conditional patch an existing FHIR resource", + mapper.readValue(CONDITIONAL_PATCH_FHIR_RESOURCE_SCHEMA, McpSchema.JsonSchema.class)); + } + + public static Tool patchFhirResource() throws JsonProcessingException { + return new Tool( + "patch-fhir-resource", + "Patch an existing FHIR resource", + mapper.readValue(PATCH_FHIR_RESOURCE_SCHEMA, McpSchema.JsonSchema.class)); + } + + public static Tool deleteFhirResource() throws JsonProcessingException { + return new Tool( + "delete-fhir-resource", + "Delete an existing FHIR resource", + mapper.readValue(DELETE_FHIR_RESOURCE_SCHEMA, McpSchema.JsonSchema.class)); + } + + public static Tool searchFhirResources() throws JsonProcessingException { + return new Tool( + "search-fhir-resources", + "Search an existing FHIR resources", + mapper.readValue(SEARCH_FHIR_RESOURCES_SCHEMA, McpSchema.JsonSchema.class)); + } + + public static Tool createFhirTransaction() throws JsonProcessingException { + return new Tool( + "create-fhir-transaction", + "Create a FHIR transaction", + mapper.readValue(CREATE_FHIR_RESOURCE_SCHEMA, McpSchema.JsonSchema.class)); + } + + public static Tool callCdsHook() throws JsonProcessingException { + return new Tool( + "call-cds-hook", + "Call a CDS Hook", + mapper.readValue(CALL_CDS_HOOK_SCHEMA_2_0_1, McpSchema.JsonSchema.class)); + } + + public static final ObjectMapper mapper = new ObjectMapper() + .enable(JsonParser.Feature.ALLOW_COMMENTS) + .enable(JsonParser.Feature.ALLOW_SINGLE_QUOTES) + .enable(JsonParser.Feature.ALLOW_UNQUOTED_FIELD_NAMES) + .enable(JsonParser.Feature.INCLUDE_SOURCE_IN_LOCATION) + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); +} diff --git a/src/main/java/ca/uhn/fhir/rest/server/MCPBridge.java b/src/main/java/ca/uhn/fhir/rest/server/MCPBridge.java new file mode 100644 index 0000000000..4930e4f851 --- /dev/null +++ b/src/main/java/ca/uhn/fhir/rest/server/MCPBridge.java @@ -0,0 +1,216 @@ +package ca.uhn.fhir.rest.server; + +import static org.opencds.cqf.fhir.cr.hapi.config.test.TestCdsHooksConfig.CDS_HOOKS_OBJECT_MAPPER_FACTORY; + +import ca.uhn.fhir.context.FhirContext; +import ca.uhn.fhir.jpa.starter.cdshooks.CdsHooksRequest; +import ca.uhn.fhir.jpa.starter.mcp.CallToolResultFactory; +import ca.uhn.fhir.jpa.starter.mcp.Interaction; +import ca.uhn.fhir.jpa.starter.mcp.RequestBuilder; +import ca.uhn.fhir.jpa.starter.mcp.ToolFactory; +import ca.uhn.fhir.rest.api.server.cdshooks.CdsServiceRequestContextJson; +import ca.uhn.hapi.fhir.cdshooks.api.ICdsServiceRegistry; +import ca.uhn.hapi.fhir.cdshooks.api.json.CdsServiceResponseJson; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonParser; +import com.google.gson.JsonSyntaxException; + +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.spec.McpSchema; + +import org.apache.jena.base.Sys; +import org.hl7.fhir.instance.model.api.IBaseResource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.stereotype.Component; +import org.w3._1999.xhtml.I; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +@Component +public class MCPBridge { + + private static final Logger logger = LoggerFactory.getLogger(MCPBridge.class); + + @Autowired + ICdsServiceRegistry cdsServiceRegistry; + + @Autowired + @Qualifier(CDS_HOOKS_OBJECT_MAPPER_FACTORY) + ObjectMapper objectMapper; + + private final RestfulServer restfulServer; + private final FhirContext fhirContext; + private final CallToolResultFactory callToolResultFactory; + + public MCPBridge(RestfulServer restfulServer, CallToolResultFactory callToolResultFactory) { + this.restfulServer = restfulServer; + this.fhirContext = restfulServer.getFhirContext(); + this.callToolResultFactory = callToolResultFactory; + } + + public List generateTools() { + + try { + return List.of( + // TODO Add CDS Hooks tool only if CR & CDS Hooks are enabled (CDS Hooks depends on CR) + new McpServerFeatures.SyncToolSpecification( + ToolFactory.callCdsHook(), + (exchange, contextMap) -> getToolResult(contextMap, Interaction.CALL_CDS_HOOK)), + new McpServerFeatures.SyncToolSpecification( + ToolFactory.createFhirResource(), + (exchange, contextMap) -> getToolResult(contextMap, Interaction.CREATE)), + new McpServerFeatures.SyncToolSpecification( + ToolFactory.readFhirResource(), + (exchange, contextMap) -> getToolResult(contextMap, Interaction.READ)), + new McpServerFeatures.SyncToolSpecification( + ToolFactory.updateFhirResource(), + (exchange, contextMap) -> getToolResult(contextMap, Interaction.UPDATE)), + new McpServerFeatures.SyncToolSpecification( + ToolFactory.deleteFhirResource(), + (exchange, contextMap) -> getToolResult(contextMap, Interaction.DELETE)), + new McpServerFeatures.SyncToolSpecification( + ToolFactory.conditionalPatchFhirResource(), + (exchange, contextMap) -> getToolResult(contextMap, Interaction.PATCH)), + new McpServerFeatures.SyncToolSpecification( + ToolFactory.searchFhirResources(), + (exchange, contextMap) -> getToolResult(contextMap, Interaction.SEARCH)), + new McpServerFeatures.SyncToolSpecification( + ToolFactory.conditionalUpdateFhirResource(), + (exchange, contextMap) -> getToolResult(contextMap, Interaction.UPDATE)), + new McpServerFeatures.SyncToolSpecification( + ToolFactory.patchFhirResource(), + (exchange, contextMap) -> getToolResult(contextMap, Interaction.PATCH)), + new McpServerFeatures.SyncToolSpecification( + ToolFactory.createFhirTransaction(), + (exchange, contextMap) -> getToolResult(contextMap, Interaction.TRANSACTION))); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private McpSchema.CallToolResult getToolResult(Map contextMap, Interaction interaction) { + + if (interaction == Interaction.CALL_CDS_HOOK) { + + // Print the keys of contextMap + for (String key : contextMap.keySet()) { + System.out.println("Context map key: " + key); + } + + // TODO Build up CDS Hooks request JSON from contextMap + CdsHooksRequest request = new CdsHooksRequest(); + request.setHook(contextMap.get("hook").toString()); + request.setHookInstance(contextMap.get("hookInstance").toString()); + + // Context + CdsServiceRequestContextJson context = new CdsServiceRequestContextJson(); + Map hookContext = (Map) contextMap.get("hookContext"); + if (hookContext.containsKey("userId")) { + context.put("userId", hookContext.get("userId").toString()); + } + if (hookContext.containsKey("patientId")) { + context.put("patientId", hookContext.get("patientId").toString()); + } + if (hookContext.containsKey("encounterId")) { + context.put("encounterId", hookContext.get("encounterId").toString()); + } + request.setContext(context); + + // Prefetch + if (contextMap.containsKey("prefetch")) { + Object prefetch = contextMap.get("prefetch"); + Map prefetchMap = (Map) prefetch; + for (Map.Entry entry : prefetchMap.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); + + // Object is a String -> Object map + // Use a standard JSON library to convert it + String json = new Gson().toJson(value); + System.out.println("Prefetch data for " + key + ": " + value); + System.out.println("FHIR JSON: " + json); + IBaseResource resource = fhirContext.newJsonParser().parseResource(json); + request.addPrefetch(key, resource); + } + } + + try { + String requestString = objectMapper.writeValueAsString(request); + System.out.println("CDS Hooks request JSON: " + requestString); + } catch (JsonProcessingException e) { + e.printStackTrace(); + } + String service = contextMap.get("service").toString(); + CdsServiceResponseJson serviceResponseJson = cdsServiceRegistry.callService(service, request); + + // Copy from CdsHooksServlet, including the comment below + // Using GSON pretty print format as Jackson's is ugly + String jsonResponse = ""; + try { + jsonResponse = new GsonBuilder() + .disableHtmlEscaping() + .setPrettyPrinting() + .create() + .toJson(JsonParser.parseString(objectMapper.writeValueAsString(serviceResponseJson))); + } catch (JsonSyntaxException e) { + // TODO Return MCP Error + jsonResponse = "{\"error\": \"" + e.getMessage() + "\"}"; + logger.error(e.getMessage(), e); + e.printStackTrace(); + } catch (JsonProcessingException e) { + // TODO Return MCP Error + jsonResponse = "{\"error\": \"" + e.getMessage() + "\"}"; + logger.error(e.getMessage(), e); + e.printStackTrace(); + } + + return McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent(jsonResponse)) + .build(); + } + + MockHttpServletResponse response = new MockHttpServletResponse(); + MockHttpServletRequest request = new RequestBuilder(fhirContext, contextMap, interaction).buildRequest(); + + try { + restfulServer.handleRequest(interaction.asRequestType(), request, response); + int status = response.getStatus(); + String body = response.getContentAsString(); + + if (status >= 200 && status < 300) { + if (body.isBlank()) { + return McpSchema.CallToolResult.builder() + .isError(true) + .addTextContent("Empty successful response for " + interaction) + .build(); + } + IBaseResource parsed = fhirContext.newJsonParser().parseResource(body); + + return callToolResultFactory.success( + contextMap.get("resourceType").toString(), interaction, parsed, status); + } else { + return callToolResultFactory.failure(String.format("FHIR server error %d: %s", status, body)); + } + } catch (IOException e) { + logger.error(e.getMessage(), e); + return callToolResultFactory.failure("Dispatch error: " + e.getMessage()); + } catch (Exception e) { + logger.error(e.getMessage(), e); + return McpSchema.CallToolResult.builder() + .isError(true) + .addTextContent("Unexpected error: " + e.getMessage()) + .build(); + } + } +} diff --git a/src/main/resources/application.yaml b/src/main/resources/application.yaml index cff71e2dcd..b143c1d673 100644 --- a/src/main/resources/application.yaml +++ b/src/main/resources/application.yaml @@ -36,7 +36,39 @@ management: export: enabled: true spring: + ai: + mcp: + server: + name: FHIR MCP Server + version: 1.0.0 + type: SYNC + instructions: "This server provides access to a FHIR RESTful API. You can use it to query FHIR resources, perform operations, and retrieve data in a structured format." + sse-message-endpoint: /mcp/messages + capabilities: + tool: true + resource: true + prompt: true + completion: true + stdio: false + enabled: true + #endpoint: /mcp + + #schema: + # fhir-enabled: true + # fhir: + # base-url: http://localhost:8080/fhir + + #query: + # prompt: + # template: | + # You are a FHIR assistant. Translate the following question into a valid FHIR RESTful API query: + # "{{query}}" + # Use the provided FHIR schema: + # {{schema}} + #base-url: /api/v1 + main: + allow-bean-definition-overriding: true allow-circular-references: true flyway: enabled: false @@ -85,7 +117,7 @@ hapi: ### This flag when enabled to true, will avail evaluate measure operations from CR Module. ### Flag is false by default, can be passed as command line argument to override. cr: - enabled: false + enabled: true caregaps: reporter: "default" section_author: "default" @@ -129,7 +161,7 @@ hapi: profile_mode: DECLARED # ENFORCED, DECLARED, OPTIONAL, TRUST, OFF cdshooks: - enabled: false + enabled: true clientIdHeaderName: client_id ### This enables the swagger-ui at /fhir/swagger-ui/index.html as well as the /fhir/api-docs (see https://hapifhir.io/hapi-fhir/docs/server_plain/openapi.html) @@ -242,24 +274,25 @@ hapi: - http://loinc.org/* - https://loinc.org/* - ### Uncomment the following section, and any sub-properties you need in order to enable - ### partitioning support on this server. - partitioning: - allow_references_across_partitions: false - partitioning_include_in_search_hashes: false - default_partition_id: 0 - ### Enable the following setting to enable Database Partitioning Mode - ### See: https://hapifhir.io/hapi-fhir/docs/server_jpa_partitioning/db_partition_mode.html - database_partition_mode_enabled: true - ### Partition Style: Partitioning requires a partition interceptor which helps the server - ### select which partition(s) should be accessed for a given request. You can supply your - ### own interceptor (see https://hapifhir.io/hapi-fhir/docs/server_jpa_partitioning/partitioning.html#partition-interceptors ) - ### but the following setting can also be used to use a built-in form. - ### Patient ID Partitioning Mode uses the patient/subject ID to determine the partition - patient_id_partitioning_mode: true - ### Request tenant mode can be used for a multi-tenancy setup where the request path is - ### expected to have an additional path element, e.g. GET http://example.com/fhir/TENANT-ID/Patient/A - request_tenant_partitioning_mode: false +# ### Uncomment the following section, and any sub-properties you need in order to enable +# ### partitioning support on this server. +# partitioning: +# +# allow_references_across_partitions: false +# partitioning_include_in_search_hashes: false +# default_partition_id: 0 +# ### Enable the following setting to enable Database Partitioning Mode +# ### See: https://hapifhir.io/hapi-fhir/docs/server_jpa_partitioning/db_partition_mode.html +# database_partition_mode_enabled: true +# ### Partition Style: Partitioning requires a partition interceptor which helps the server +# ### select which partition(s) should be accessed for a given request. You can supply your +# ### own interceptor (see https://hapifhir.io/hapi-fhir/docs/server_jpa_partitioning/partitioning.html#partition-interceptors ) +# ### but the following setting can also be used to use a built-in form. +# ### Patient ID Partitioning Mode uses the patient/subject ID to determine the partition +# patient_id_partitioning_mode: true +# ### Request tenant mode can be used for a multi-tenancy setup where the request path is +# ### expected to have an additional path element, e.g. GET http://example.com/fhir/TENANT-ID/Patient/A +# request_tenant_partitioning_mode: false cors: allow_Credentials: true diff --git a/src/test/java/ca/uhn/fhir/jpa/starter/McpTests.java b/src/test/java/ca/uhn/fhir/jpa/starter/McpTests.java new file mode 100644 index 0000000000..df9a687eff --- /dev/null +++ b/src/test/java/ca/uhn/fhir/jpa/starter/McpTests.java @@ -0,0 +1,62 @@ +package ca.uhn.fhir.jpa.starter; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.util.Map; +import java.util.function.Function; + +public class McpTests { + + @Test + @Disabled + public void mcpTests() { + + // Configure sampling handler + Function samplingHandler = request -> { + // Sampling implementation that interfaces with LLM + //return new McpSchema.CreateMessageResult(response); + return null; + }; + + HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder("http://localhost:8080/sse").build(); + // Create a sync client with custom configuration + McpSchema.Role response = null; + McpSyncClient client = McpClient.sync(transport).requestTimeout(Duration.ofSeconds(10)).capabilities(McpSchema.ClientCapabilities.builder().roots(true) // Enable roots capability + .sampling() // Enable sampling capability + + .build()) + //.sampling(samplingHandler) + .build(); + +// Initialize connection + client.initialize(); + +// List available tools + McpSchema.ListToolsResult tools = client.listTools(); + +// Call a tool + McpSchema.CallToolResult result = client.callTool(new McpSchema.CallToolRequest("calculator", Map.of("operation", "add", "a", 2, "b", 3))); + +// List and read resources + McpSchema.ListResourcesResult resources = client.listResources(); + McpSchema.ReadResourceResult resource = client.readResource(new McpSchema.ReadResourceRequest("resource://uri")); + +// List and use prompts + McpSchema.ListPromptsResult prompts = client.listPrompts(); + McpSchema.GetPromptResult prompt = client.getPrompt(new McpSchema.GetPromptRequest("greeting", Map.of("name", "Spring"))); + +// Add/remove roots + //client.addRoot(new McpSchema.Root("file:///path", "description")); + //client.removeRoot("file:///path"); + + //client.callTool() +// Close client + client.closeGracefully(); + } +} diff --git a/src/test/resources/mcp/hello-patient-request.json b/src/test/resources/mcp/hello-patient-request.json new file mode 100644 index 0000000000..2a26967d24 --- /dev/null +++ b/src/test/resources/mcp/hello-patient-request.json @@ -0,0 +1,18 @@ +{ + "hook": "patient-view", + "hookInstance": "8d5a3a2e-6d8b-4f7c-bb2d-2f1b8cf1d7a1", + "context": { + "userId": "Practitioner/123", + "patientId": "123", + "encounterId": "456" + }, + "prefetch": { + "item1": { + "resourceType": "Patient", + "gender": "male", + "birthDate": "1989-10-23", + "id": "123", + "active": true + } + } +} diff --git a/src/test/resources/mcp/mcp-hookContext-object.json b/src/test/resources/mcp/mcp-hookContext-object.json new file mode 100644 index 0000000000..b4648e6583 --- /dev/null +++ b/src/test/resources/mcp/mcp-hookContext-object.json @@ -0,0 +1,5 @@ +{ + "userId": "Practitioner/123", + "patientId": "123", + "encounterId": "456" +} diff --git a/src/test/resources/mcp/mpc-prefetch-object.json b/src/test/resources/mcp/mpc-prefetch-object.json new file mode 100644 index 0000000000..4b3008872a --- /dev/null +++ b/src/test/resources/mcp/mpc-prefetch-object.json @@ -0,0 +1,9 @@ +{ + "item1": { + "resourceType": "Patient", + "gender": "male", + "birthDate": "1989-10-23", + "id": "123", + "active": true + } +} diff --git a/src/test/resources/mcp/plandefinition-hello-patient.xml b/src/test/resources/mcp/plandefinition-hello-patient.xml new file mode 100644 index 0000000000..a9c4aeea65 --- /dev/null +++ b/src/test/resources/mcp/plandefinition-hello-patient.xml @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + <type> + <coding> + <system value="http://terminology.hl7.org/CodeSystem/plan-definition-type" /> + <code value="eca-rule" /> + <display value="ECA Rule" /> + </coding> + </type> + <status value="draft" /> + <experimental value="true" /> + <date value="2024-09-28" /> + <description value="Demo PlanDefinition for Hello Patient" /> + <action> + <title value="Hello, Patient!" /> + <description value="Please state the nature of the medical emergency." /> + <trigger> + <type value="named-event" /> + <name value="patient-view" /> + </trigger> + <condition> + <kind value="applicability" /> + <expression> + <language value="text/cql" /> + <expression value="true" /> + </expression> + </condition> + </action> +</PlanDefinition>