diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 57bcd191b..2d9d055f3 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -24,6 +24,7 @@ import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; import io.modelcontextprotocol.spec.McpSchema.Role; import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; @@ -45,6 +46,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.awaitility.Awaitility.await; +import static org.junit.Assert.assertThat; import static org.mockito.Mockito.mock; public class WebFluxSseIntegrationTests { @@ -142,13 +144,16 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var messages = List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var craeteMessageRequest = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), - Map.of()); + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { assertThat(result).isNotNull(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 7ba9ccc19..3ff755ca9 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -20,6 +20,7 @@ import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; import io.modelcontextprotocol.spec.McpSchema.Role; import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; @@ -45,6 +46,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.awaitility.Awaitility.await; +import static org.junit.Assert.assertThat; import static org.mockito.Mockito.mock; public class WebMvcSseIntegrationTests { @@ -199,13 +201,16 @@ void testCreateMessageSuccess() throws InterruptedException { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var messages = List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var craeteMessageRequest = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), - Map.of()); + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { assertThat(result).isNotNull(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 232b7bfdb..37d9e0c0a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.spec; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -763,15 +764,61 @@ public record CallToolResult( // @formatter:off @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record ModelPreferences(// @formatter:off - @JsonProperty("hints") List hints, - @JsonProperty("costPriority") Double costPriority, - @JsonProperty("speedPriority") Double speedPriority, - @JsonProperty("intelligencePriority") Double intelligencePriority) { - } // @formatter:on + @JsonProperty("hints") List hints, + @JsonProperty("costPriority") Double costPriority, + @JsonProperty("speedPriority") Double speedPriority, + @JsonProperty("intelligencePriority") Double intelligencePriority) { + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private List hints; + private Double costPriority; + private Double speedPriority; + private Double intelligencePriority; + + public Builder hints(List hints) { + this.hints = hints; + return this; + } + + public Builder addHint(String name) { + if (this.hints == null) { + this.hints = new ArrayList<>(); + } + this.hints.add(new ModelHint(name)); + return this; + } + + public Builder costPriority(Double costPriority) { + this.costPriority = costPriority; + return this; + } + + public Builder speedPriority(Double speedPriority) { + this.speedPriority = speedPriority; + return this; + } + + public Builder intelligencePriority(Double intelligencePriority) { + this.intelligencePriority = intelligencePriority; + return this; + } + + public ModelPreferences build() { + return new ModelPreferences(hints, costPriority, speedPriority, intelligencePriority); + } + } +} // @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record ModelHint(@JsonProperty("name") String name) { + public static ModelHint of(String name) { + return new ModelHint(name); + } } @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -799,6 +846,66 @@ public enum ContextInclusionStrategy { @JsonProperty("thisServer") THIS_SERVER, @JsonProperty("allServers") ALL_SERVERS } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private List messages; + private ModelPreferences modelPreferences; + private String systemPrompt; + private ContextInclusionStrategy includeContext; + private Double temperature; + private int maxTokens; + private List stopSequences; + private Map metadata; + + public Builder messages(List messages) { + this.messages = messages; + return this; + } + + public Builder modelPreferences(ModelPreferences modelPreferences) { + this.modelPreferences = modelPreferences; + return this; + } + + public Builder systemPrompt(String systemPrompt) { + this.systemPrompt = systemPrompt; + return this; + } + + public Builder includeContext(ContextInclusionStrategy includeContext) { + this.includeContext = includeContext; + return this; + } + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder stopSequences(List stopSequences) { + this.stopSequences = stopSequences; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public CreateMessageRequest build() { + return new CreateMessageRequest(messages, modelPreferences, systemPrompt, + includeContext, temperature, maxTokens, stopSequences, metadata); + } + } }// @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index 290141bb2..fd8a4e9f9 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -21,6 +21,7 @@ import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; import io.modelcontextprotocol.spec.McpSchema.Role; import io.modelcontextprotocol.spec.McpSchema.Root; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; @@ -162,13 +163,16 @@ void testCreateMessageSuccess() throws InterruptedException { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var messages = List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message"))); - var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0); - - var craeteMessageRequest = new McpSchema.CreateMessageRequest(messages, modelPrefs, null, - McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(), - Map.of()); + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { assertThat(result).isNotNull(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index e18c23c4c..1b8adc33b 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -524,10 +524,16 @@ void testCreateMessageRequest() throws Exception { Map metadata = new HashMap<>(); metadata.put("session", "test-session"); - McpSchema.CreateMessageRequest request = new McpSchema.CreateMessageRequest(Collections.singletonList(message), - preferences, "You are a helpful assistant", - McpSchema.CreateMessageRequest.ContextInclusionStrategy.THIS_SERVER, 0.7, 1000, - Arrays.asList("STOP", "END"), metadata); + McpSchema.CreateMessageRequest request = McpSchema.CreateMessageRequest.builder() + .messages(Collections.singletonList(message)) + .modelPreferences(preferences) + .systemPrompt("You are a helpful assistant") + .includeContext(McpSchema.CreateMessageRequest.ContextInclusionStrategy.THIS_SERVER) + .temperature(0.7) + .maxTokens(1000) + .stopSequences(Arrays.asList("STOP", "END")) + .metadata(metadata) + .build(); String value = mapper.writeValueAsString(request); @@ -543,8 +549,12 @@ void testCreateMessageRequest() throws Exception { void testCreateMessageResult() throws Exception { McpSchema.TextContent content = new McpSchema.TextContent("Assistant response"); - McpSchema.CreateMessageResult result = new McpSchema.CreateMessageResult(McpSchema.Role.ASSISTANT, content, - "gpt-4", McpSchema.CreateMessageResult.StopReason.END_TURN); + McpSchema.CreateMessageResult result = McpSchema.CreateMessageResult.builder() + .role(McpSchema.Role.ASSISTANT) + .content(content) + .model("gpt-4") + .stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN) + .build(); String value = mapper.writeValueAsString(result);