diff --git a/src/ModelContextProtocol/Client/McpClientExtensions.cs b/src/ModelContextProtocol/Client/McpClientExtensions.cs index 9cc8ab25f..aaaa4263c 100644 --- a/src/ModelContextProtocol/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol/Client/McpClientExtensions.cs @@ -926,7 +926,7 @@ internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chat { Content = content, Model = chatResponse.ModelId ?? "unknown", - Role = lastMessage?.Role == ChatRole.User ? "user" : "assistant", + Role = lastMessage?.Role == ChatRole.User ? Role.User : Role.Assistant, StopReason = chatResponse.FinishReason == ChatFinishReason.Length ? "maxTokens" : "endTurn", }; } diff --git a/src/ModelContextProtocol/Protocol/Types/CreateMessageResult.cs b/src/ModelContextProtocol/Protocol/Types/CreateMessageResult.cs index 04397bed6..4ae689091 100644 --- a/src/ModelContextProtocol/Protocol/Types/CreateMessageResult.cs +++ b/src/ModelContextProtocol/Protocol/Types/CreateMessageResult.cs @@ -50,5 +50,5 @@ public class CreateMessageResult /// Gets or sets the role of the user who generated the message. /// [JsonPropertyName("role")] - public required string Role { get; init; } + public required Role Role { get; init; } } diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs index 17b4b35ff..aa103c357 100644 --- a/src/ModelContextProtocol/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs @@ -142,7 +142,7 @@ public static async Task RequestSamplingAsync( ModelPreferences = modelPreferences, }, cancellationToken).ConfigureAwait(false); - return new(new ChatMessage(new(result.Role), [result.Content.ToAIContent()])) + return new(new ChatMessage(result.Role is Role.User ? ChatRole.User : ChatRole.Assistant, [result.Content.ToAIContent()])) { ModelId = result.Model, FinishReason = result.StopReason switch diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs index 7913bbb23..280a757d8 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs @@ -219,7 +219,7 @@ public async Task Sampling_Sse_TestServer() return new CreateMessageResult { Model = "test-model", - Role = "assistant", + Role = Role.Assistant, Content = new Content { Type = "text", diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestServerTransport.cs index f21660143..0c9cb4377 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestServerTransport.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestServerTransport.cs @@ -72,7 +72,7 @@ private async Task Sampling(JsonRpcRequest request, CancellationToken cancellati await WriteMessageAsync(new JsonRpcResponse { Id = request.Id, - Result = JsonSerializer.SerializeToNode(new CreateMessageResult { Content = new(), Model = "model", Role = "role" }), + Result = JsonSerializer.SerializeToNode(new CreateMessageResult { Content = new(), Model = "model", Role = Role.Assistant }), }, cancellationToken); } diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index 58864a1ba..c9cb3742c 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -84,7 +84,7 @@ public async Task CreateSamplingHandler_ShouldHandleTextMessages(float? temperat Assert.NotNull(result); Assert.Equal("Hello, World!", result.Content.Text); Assert.Equal("test-model", result.Model); - Assert.Equal("assistant", result.Role); + Assert.Equal(Role.Assistant, result.Role); Assert.Equal("endTurn", result.StopReason); } @@ -139,7 +139,7 @@ public async Task CreateSamplingHandler_ShouldHandleImageMessages() Assert.NotNull(result); Assert.Equal(expectedData, result.Content.Data); Assert.Equal("test-model", result.Model); - Assert.Equal("assistant", result.Role); + Assert.Equal(Role.Assistant, result.Role); Assert.Equal("endTurn", result.StopReason); } @@ -201,7 +201,7 @@ public async Task CreateSamplingHandler_ShouldHandleResourceMessages() // Assert Assert.NotNull(result); Assert.Equal("test-model", result.Model); - Assert.Equal(ChatRole.Assistant.ToString(), result.Role); + Assert.Equal(Role.Assistant, result.Role); Assert.Equal("endTurn", result.StopReason); } diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs index 1414f6563..38e01eb62 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs @@ -44,7 +44,7 @@ public async Task CreateAsync_WithCapabilitiesOptions(Type transportType) new CreateMessageResult { Content = new Content { Text = "result" }, Model = "test-model", - Role = "test-role", + Role = Role.User, StopReason = "endTurn" }), }, diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index baca8f1c8..46e928e14 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -374,7 +374,7 @@ public async Task Sampling_Stdio(string clientId) return Task.FromResult(new CreateMessageResult { Model = "test-model", - Role = "assistant", + Role = Role.Assistant, Content = new Content { Type = "text", diff --git a/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs index b5947eaba..149c7d639 100644 --- a/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs +++ b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs @@ -83,7 +83,7 @@ public async Task Sampling_Sse_EverythingServer() return Task.FromResult(new CreateMessageResult { Model = "test-model", - Role = "assistant", + Role = Role.Assistant, Content = new Content { Type = "text", diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 054a47985..dcb688cf6 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -601,7 +601,7 @@ public Task SendRequestAsync(JsonRpcRequest request, Cancellati { Content = new() { Text = "The Eiffel Tower.", Type = "text" }, Model = "amazingmodel", - Role = "assistant", + Role = Role.Assistant, StopReason = "endTurn", }; diff --git a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs index ed9b2e04d..aae1438b9 100644 --- a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs +++ b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs @@ -72,7 +72,7 @@ private async Task Sampling(JsonRpcRequest request, CancellationToken cancellati await WriteMessageAsync(new JsonRpcResponse { Id = request.Id, - Result = JsonSerializer.SerializeToNode(new CreateMessageResult { Content = new(), Model = "model", Role = "role" }), + Result = JsonSerializer.SerializeToNode(new CreateMessageResult { Content = new(), Model = "model", Role = Role.User }), }, cancellationToken); }