diff --git a/Directory.Packages.props b/Directory.Packages.props index 554361cbe..c03c69e57 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -33,6 +33,8 @@ + + diff --git a/src/ModelContextProtocol/Diagnostics.cs b/src/ModelContextProtocol/Diagnostics.cs new file mode 100644 index 000000000..a85aa0643 --- /dev/null +++ b/src/ModelContextProtocol/Diagnostics.cs @@ -0,0 +1,29 @@ +using System.Diagnostics; +using System.Diagnostics.Metrics; + +namespace ModelContextProtocol; + +internal static class Diagnostics +{ + internal static ActivitySource ActivitySource { get; } = new("ModelContextProtocol"); + + /// + /// Follows boundaries from http.server.request.duration/http.client.request.duration + /// + internal static InstrumentAdvice ShortSecondsBucketBoundaries { get; } = new() + { + HistogramBucketBoundaries = [0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1, 2.5, 5, 7.5, 10], + }; + + /// + /// Not based on a standard. Larger bucket sizes for longer lasting operations, e.g. HTTP connection duration. + /// See https://github.com/open-telemetry/semantic-conventions/issues/336 + /// + internal static InstrumentAdvice LongSecondsBucketBoundaries { get; } = new() + { + HistogramBucketBoundaries = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 5, 10, 30, 60, 120, 300], + }; + + internal static Meter Meter { get; } = new("ModelContextProtocol"); + +} diff --git a/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs b/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs index 6183eb92e..b8a481086 100644 --- a/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs +++ b/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs @@ -12,15 +12,15 @@ namespace ModelContextProtocol.Protocol.Messages; [JsonConverter(typeof(Converter))] public readonly struct ProgressToken : IEquatable { - /// The id, either a string or a boxed long or null. - private readonly object? _id; + /// The token, either a string or a boxed long or null. + private readonly object? _token; /// Initializes a new instance of the with a specified value. /// The required ID value. public ProgressToken(string value) { Throw.IfNull(value); - _id = value; + _token = value; } /// Initializes a new instance of the with a specified value. @@ -28,28 +28,29 @@ public ProgressToken(string value) public ProgressToken(long value) { // Box the long. Progress tokens are almost always strings in practice, so this should be rare. - _id = value; + _token = value; } - /// Gets whether the identifier is uninitialized. - public bool IsDefault => _id is null; + /// Gets the underlying object for this token. + /// This will either be a , a boxed , or . + public object? Token => _token; /// public override string? ToString() => - _id is string stringValue ? $"\"{stringValue}\"" : - _id is long longValue ? longValue.ToString(CultureInfo.InvariantCulture) : + _token is string stringValue ? $"{stringValue}" : + _token is long longValue ? longValue.ToString(CultureInfo.InvariantCulture) : null; /// /// Compares this ProgressToken to another ProgressToken. /// - public bool Equals(ProgressToken other) => Equals(_id, other._id); + public bool Equals(ProgressToken other) => Equals(_token, other._token); /// public override bool Equals(object? obj) => obj is ProgressToken other && Equals(other); /// - public override int GetHashCode() => _id?.GetHashCode() ?? 0; + public override int GetHashCode() => _token?.GetHashCode() ?? 0; /// /// Compares two ProgressTokens for equality. @@ -83,7 +84,7 @@ public override void Write(Utf8JsonWriter writer, ProgressToken value, JsonSeria { Throw.IfNull(writer); - switch (value._id) + switch (value._token) { case string str: writer.WriteStringValue(str); diff --git a/src/ModelContextProtocol/Protocol/Messages/RequestId.cs b/src/ModelContextProtocol/Protocol/Messages/RequestId.cs index e6fc74418..550428ff9 100644 --- a/src/ModelContextProtocol/Protocol/Messages/RequestId.cs +++ b/src/ModelContextProtocol/Protocol/Messages/RequestId.cs @@ -31,12 +31,13 @@ public RequestId(long value) _id = value; } - /// Gets whether the identifier is uninitialized. - public bool IsDefault => _id is null; + /// Gets the underlying object for this id. + /// This will either be a , a boxed , or . + public object? Id => _id; /// public override string ToString() => - _id is string stringValue ? $"\"{stringValue}\"" : + _id is string stringValue ? stringValue : _id is long longValue ? longValue.ToString(CultureInfo.InvariantCulture) : string.Empty; diff --git a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs index 915dfa18e..df551c4e3 100644 --- a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs @@ -3,6 +3,7 @@ using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; using ModelContextProtocol.Utils; using System.Diagnostics.CodeAnalysis; @@ -63,7 +64,7 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella protected void StartSession(ITransport sessionTransport) { _sessionCts = new CancellationTokenSource(); - _session = new McpSession(sessionTransport, EndpointName, _requestHandlers, _notificationHandlers, _logger); + _session = new McpSession(this is IMcpServer, sessionTransport, EndpointName, _requestHandlers, _notificationHandlers, _logger); MessageProcessingTask = _session.ProcessMessagesAsync(_sessionCts.Token); } diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs index 97cbcb592..dc3d62ca0 100644 --- a/src/ModelContextProtocol/Shared/McpSession.cs +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -4,9 +4,12 @@ using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.Metrics; using System.Text.Json; namespace ModelContextProtocol.Shared; @@ -16,9 +19,22 @@ namespace ModelContextProtocol.Shared; /// internal sealed class McpSession : IDisposable { + private static readonly Histogram s_clientSessionDuration = Diagnostics.Meter.CreateHistogram( + "mcp.client.session.duration", "s", "Measures the duration of a client session.", advice: Diagnostics.LongSecondsBucketBoundaries); + private static readonly Histogram s_serverSessionDuration = Diagnostics.Meter.CreateHistogram( + "mcp.server.session.duration", "s", "Measures the duration of a server session.", advice: Diagnostics.LongSecondsBucketBoundaries); + + private static readonly Histogram s_serverRequestDuration = Diagnostics.Meter.CreateHistogram( + "rpc.server.duration", "s", "Measures the duration of inbound RPC.", advice: Diagnostics.ShortSecondsBucketBoundaries); + private static readonly Histogram s_clientRequestDuration = Diagnostics.Meter.CreateHistogram( + "rpc.client.duration", "s", "Measures the duration of outbound RPC.", advice: Diagnostics.ShortSecondsBucketBoundaries); + + private readonly bool _isServer; + private readonly string _transportKind; private readonly ITransport _transport; private readonly RequestHandlers _requestHandlers; private readonly NotificationHandlers _notificationHandlers; + private readonly long _sessionStartingTimestamp = Stopwatch.GetTimestamp(); /// Collection of requests sent on this session and waiting for responses. private readonly ConcurrentDictionary> _pendingRequests = []; @@ -36,12 +52,14 @@ internal sealed class McpSession : IDisposable /// /// Initializes a new instance of the class. /// + /// true if this is a server; false if it's a client. /// An MCP transport implementation. /// The name of the endpoint for logging and debug purposes. /// A collection of request handlers. /// A collection of notification handlers. /// The logger. public McpSession( + bool isServer, ITransport transport, string endpointName, RequestHandlers requestHandlers, @@ -50,6 +68,15 @@ public McpSession( { Throw.IfNull(transport); + _transportKind = transport switch + { + StdioClientSessionTransport or StdioServerTransport => "stdio", + StreamClientSessionTransport or StreamServerTransport => "stream", + SseClientSessionTransport or SseResponseStreamTransport => "sse", + _ => "unknownTransport" + }; + + _isServer = isServer; _transport = transport; EndpointName = endpointName; _requestHandlers = requestHandlers; @@ -121,7 +148,7 @@ await _transport.SendMessageAsync(new JsonRpcError JsonRpc = "2.0", Error = new JsonRpcErrorDetail { - Code = ErrorCodes.InternalError, + Code = (ex as McpServerException)?.ErrorCode ?? ErrorCodes.InternalError, Message = ex.Message } }, cancellationToken).ConfigureAwait(false); @@ -152,23 +179,55 @@ await _transport.SendMessageAsync(new JsonRpcError private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken) { - switch (message) + Histogram durationMetric = _isServer ? s_serverRequestDuration : s_clientRequestDuration; + string method = GetMethodName(message); + + long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; + Activity? activity = Diagnostics.ActivitySource.HasListeners() ? + Diagnostics.ActivitySource.StartActivity(CreateActivityName(method)) : + null; + + TagList tags = default; + bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; + try { - case JsonRpcRequest request: - await HandleRequest(request, cancellationToken).ConfigureAwait(false); - break; + if (addTags) + { + AddStandardTags(ref tags, method); + } - case IJsonRpcMessageWithId messageWithId: - HandleMessageWithId(message, messageWithId); - break; + switch (message) + { + case JsonRpcRequest request: + if (addTags) + { + AddRpcRequestTags(ref tags, activity, request); + } - case JsonRpcNotification notification: - await HandleNotification(notification).ConfigureAwait(false); - break; + await HandleRequest(request, cancellationToken).ConfigureAwait(false); + break; - default: - _logger.EndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name); - break; + case JsonRpcNotification notification: + await HandleNotification(notification).ConfigureAwait(false); + break; + + case IJsonRpcMessageWithId messageWithId: + HandleMessageWithId(message, messageWithId); + break; + + default: + _logger.EndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name); + break; + } + } + catch (Exception e) when (addTags) + { + AddExceptionTags(ref tags, e); + throw; + } + finally + { + FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); } } @@ -212,7 +271,7 @@ private async Task HandleNotification(JsonRpcNotification notification) private void HandleMessageWithId(IJsonRpcMessage message, IJsonRpcMessageWithId messageWithId) { - if (messageWithId.Id.IsDefault) + if (messageWithId.Id.Id is null) { _logger.RequestHasInvalidId(EndpointName); } @@ -229,22 +288,21 @@ private void HandleMessageWithId(IJsonRpcMessage message, IJsonRpcMessageWithId private async Task HandleRequest(JsonRpcRequest request, CancellationToken cancellationToken) { - if (_requestHandlers.TryGetValue(request.Method, out var handler)) - { - _logger.RequestHandlerCalled(EndpointName, request.Method); - var result = await handler(request, cancellationToken).ConfigureAwait(false); - _logger.RequestHandlerCompleted(EndpointName, request.Method); - await _transport.SendMessageAsync(new JsonRpcResponse - { - Id = request.Id, - JsonRpc = "2.0", - Result = result - }, cancellationToken).ConfigureAwait(false); - } - else + if (!_requestHandlers.TryGetValue(request.Method, out var handler)) { _logger.NoHandlerFoundForRequest(EndpointName, request.Method); + throw new McpServerException("The method does not exist or is not available.", ErrorCodes.MethodNotFound); } + + _logger.RequestHandlerCalled(EndpointName, request.Method); + var result = await handler(request, cancellationToken).ConfigureAwait(false); + _logger.RequestHandlerCompleted(EndpointName, request.Method); + await _transport.SendMessageAsync(new JsonRpcResponse + { + Id = request.Id, + JsonRpc = "2.0", + Result = result + }, cancellationToken).ConfigureAwait(false); } /// @@ -264,17 +322,33 @@ public async Task SendRequestAsync(JsonRpcRequest request, Can throw new McpClientException("Transport is not connected"); } + Histogram durationMetric = _isServer ? s_serverRequestDuration : s_clientRequestDuration; + string method = request.Method; + + long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; + using Activity? activity = Diagnostics.ActivitySource.HasListeners() ? + Diagnostics.ActivitySource.StartActivity(CreateActivityName(method)) : + null; + // Set request ID - if (request.Id.IsDefault) + if (request.Id.Id is null) { request.Id = new RequestId($"{_id}-{Interlocked.Increment(ref _nextRequestId)}"); } + TagList tags = default; + bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); _pendingRequests[request.Id] = tcs; - try { + if (addTags) + { + AddStandardTags(ref tags, method); + AddRpcRequestTags(ref tags, activity, request); + } + // Expensive logging, use the logging framework to check if the logger is enabled if (_logger.IsEnabled(LogLevel.Debug)) { @@ -319,9 +393,15 @@ public async Task SendRequestAsync(JsonRpcRequest request, Can _logger.RequestInvalidResponseType(EndpointName, request.Method); throw new McpClientException("Invalid response type"); } + catch (Exception ex) when (addTags) + { + AddExceptionTags(ref tags, ex); + throw; + } finally { _pendingRequests.TryRemove(request.Id, out _); + FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); } } @@ -335,21 +415,49 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca throw new McpClientException("Transport is not connected"); } - if (_logger.IsEnabled(LogLevel.Debug)) + Histogram durationMetric = _isServer ? s_serverRequestDuration : s_clientRequestDuration; + string method = GetMethodName(message); + + long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; + using Activity? activity = Diagnostics.ActivitySource.HasListeners() ? + Diagnostics.ActivitySource.StartActivity(CreateActivityName(method)) : + null; + + TagList tags = default; + bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; + + try { - _logger.SendingMessage(EndpointName, JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo())); - } + if (addTags) + { + AddStandardTags(ref tags, method); + } - await _transport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + if (_logger.IsEnabled(LogLevel.Debug)) + { + _logger.SendingMessage(EndpointName, JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo())); + } + + await _transport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); - // If the sent notification was a cancellation notification, cancel the pending request's await, as either the - // server won't be sending a response, or per the specification, the response should be ignored. There are inherent - // race conditions here, so it's possible and allowed for the operation to complete before we get to this point. - if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification && - GetCancelledNotificationParams(notification.Params) is CancelledNotification cn && - _pendingRequests.TryRemove(cn.RequestId, out var tcs)) + // If the sent notification was a cancellation notification, cancel the pending request's await, as either the + // server won't be sending a response, or per the specification, the response should be ignored. There are inherent + // race conditions here, so it's possible and allowed for the operation to complete before we get to this point. + if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification && + GetCancelledNotificationParams(notification.Params) is CancelledNotification cn && + _pendingRequests.TryRemove(cn.RequestId, out var tcs)) + { + tcs.TrySetCanceled(default); + } + } + catch (Exception ex) when (addTags) + { + AddExceptionTags(ref tags, ex); + throw; + } + finally { - tcs.TrySetCanceled(default); + FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); } } @@ -380,13 +488,127 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca } } + private string CreateActivityName(string method) => + $"mcp.{(_isServer ? "server" : "client")}.{_transportKind}/{method}"; + + private static string GetMethodName(IJsonRpcMessage message) => + message switch + { + JsonRpcRequest request => request.Method, + JsonRpcNotification notification => notification.Method, + _ => "unknownMethod", + }; + + private void AddStandardTags(ref TagList tags, string method) + { + tags.Add("session.id", _id); + tags.Add("rpc.system", "jsonrpc"); + tags.Add("rpc.jsonrpc.version", "2.0"); + tags.Add("rpc.method", method); + tags.Add("network.transport", _transportKind); + + // RPC spans convention also includes: + // server.address, server.port, client.address, client.port, network.peer.address, network.peer.port, network.type + } + + private static void AddRpcRequestTags(ref TagList tags, Activity? activity, JsonRpcRequest request) + { + tags.Add("rpc.jsonrpc.request_id", request.Id.ToString()); + + if (request.Params is JsonElement je) + { + switch (request.Method) + { + case RequestMethods.ToolsCall: + case RequestMethods.PromptsGet: + if (je.TryGetProperty("name", out var prop) && prop.ValueKind == JsonValueKind.String) + { + string name = prop.GetString()!; + tags.Add("mcp.request.params.name", name); + if (activity is not null) + { + activity.DisplayName = $"{request.Method}({name})"; + } + } + break; + + case RequestMethods.ResourcesRead: + if (je.TryGetProperty("uri", out prop) && prop.ValueKind == JsonValueKind.String) + { + string uri = prop.GetString()!; + tags.Add("mcp.request.params.uri", uri); + if (activity is not null) + { + activity.DisplayName = $"{request.Method}({uri})"; + } + } + break; + } + } + } + + private static void AddExceptionTags(ref TagList tags, Exception e) + { + tags.Add("error.type", e.GetType().FullName); + tags.Add("rpc.jsonrpc.error_code", + (e as McpClientException)?.ErrorCode is int clientError ? clientError : + (e as McpServerException)?.ErrorCode is int serverError ? serverError : + e is JsonException ? ErrorCodes.ParseError : + ErrorCodes.InternalError); + } + + private static void FinalizeDiagnostics( + Activity? activity, long? startingTimestamp, Histogram durationMetric, ref TagList tags) + { + try + { + if (startingTimestamp is not null) + { + durationMetric.Record(GetElapsed(startingTimestamp.Value).TotalSeconds, tags); + } + + if (activity is { IsAllDataRequested: true }) + { + foreach (var tag in tags) + { + activity.AddTag(tag.Key, tag.Value); + } + } + } + finally + { + activity?.Dispose(); + } + } + public void Dispose() { + Histogram durationMetric = _isServer ? s_serverSessionDuration : s_clientSessionDuration; + if (durationMetric.Enabled) + { + TagList tags = default; + tags.Add("session.id", _id); + tags.Add("network.transport", _transportKind); + durationMetric.Record(GetElapsed(_sessionStartingTimestamp).TotalSeconds, tags); + } + // Complete all pending requests with cancellation foreach (var entry in _pendingRequests) { entry.Value.TrySetCanceled(); } + _pendingRequests.Clear(); } + +#if !NET + private static readonly double s_timestampToTicks = TimeSpan.TicksPerSecond / (double)Stopwatch.Frequency; +#endif + + private static TimeSpan GetElapsed(long startingTimestamp) => +#if NET + Stopwatch.GetElapsedTime(startingTimestamp); +#else + new((long)(s_timestampToTicks * (Stopwatch.GetTimestamp() - startingTimestamp))); +#endif } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 2b160c52e..f86e27ac5 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -640,7 +640,7 @@ public async Task HandlesIProgressParameter() Assert.Equal(10, array.Length); for (int i = 0; i < array.Length; i++) { - Assert.Equal("\"abc123\"", array[i].ProgressToken.ToString()); + Assert.Equal("abc123", array[i].ProgressToken.ToString()); Assert.Equal(i, array[i].Progress.Progress); Assert.Equal(10, array[i].Progress.Total); Assert.Equal($"Progress {i}", array[i].Progress.Message); diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs new file mode 100644 index 000000000..42188551d --- /dev/null +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -0,0 +1,79 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; +using OpenTelemetry.Trace; +using System.Diagnostics; +using System.IO.Pipelines; + +namespace ModelContextProtocol.Tests; + +[Collection(nameof(DisableParallelization))] +public class DiagnosticTests +{ + [Fact] + public async Task Session_TracksActivities() + { + var activities = new List(); + + using (var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource("ModelContextProtocol") + .AddInMemoryExporter(activities) + .Build()) + { + await RunConnected(async (client, server) => + { + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + Assert.NotNull(tools); + Assert.NotEmpty(tools); + + var tool = tools.First(t => t.Name == "DoubleValue"); + await tool.InvokeAsync(new Dictionary() { ["amount"] = 42 }, TestContext.Current.CancellationToken); + }); + } + + Assert.NotEmpty(activities); + + Activity toolCallActivity = activities.First(a => + a.Tags.Any(t => t.Key == "rpc.method" && t.Value == "tools/call")); + Assert.Equal("DoubleValue", toolCallActivity.Tags.First(t => t.Key == "mcp.request.params.name").Value); + } + + private static async Task RunConnected(Func action) + { + Pipe clientToServerPipe = new(), serverToClientPipe = new(); + StreamServerTransport serverTransport = new(clientToServerPipe.Reader.AsStream(), serverToClientPipe.Writer.AsStream()); + StreamClientTransport clientTransport = new(clientToServerPipe.Writer.AsStream(), serverToClientPipe.Reader.AsStream()); + + Task serverTask; + + await using (IMcpServer server = McpServerFactory.Create(serverTransport, new() + { + ServerInfo = new Implementation { Name = "TestServer", Version = "1.0.0" }, + Capabilities = new() + { + Tools = new() + { + ToolCollection = [McpServerTool.Create((int amount) => amount * 2, new() { Name = "DoubleValue", Description = "Doubles the value." })], + } + } + })) + { + serverTask = server.RunAsync(TestContext.Current.CancellationToken); + + await using (IMcpClient client = await McpClientFactory.CreateAsync(new() + { + Id = "TestServer", + Name = "TestServer", + TransportType = TransportTypes.StdIo, + }, + createTransportFunc: (_, __) => clientTransport, + cancellationToken: TestContext.Current.CancellationToken)) + { + await action(client, server); + } + } + + await serverTask; + } +} diff --git a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj index 7a239ef29..5f839d25d 100644 --- a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj +++ b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj @@ -25,6 +25,8 @@ + + diff --git a/tests/ModelContextProtocol.Tests/Protocol/RequestIdTests.cs b/tests/ModelContextProtocol.Tests/Protocol/RequestIdTests.cs new file mode 100644 index 000000000..1df5ccb73 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/RequestIdTests.cs @@ -0,0 +1,38 @@ +using ModelContextProtocol.Protocol.Messages; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Protocol; + +public class RequestIdTests +{ + [Fact] + public void StringCtor_Roundtrips() + { + RequestId id = new("test-id"); + Assert.Equal("test-id", id.ToString()); + Assert.Equal("\"test-id\"", JsonSerializer.Serialize(id)); + Assert.Same("test-id", id.Id); + + Assert.True(id.Equals(new("test-id"))); + Assert.False(id.Equals(new("tEst-id"))); + Assert.Equal("test-id".GetHashCode(), id.GetHashCode()); + + Assert.Equal(id, JsonSerializer.Deserialize(JsonSerializer.Serialize(id))); + } + + [Fact] + public void Int64Ctor_Roundtrips() + { + RequestId id = new(42); + Assert.Equal("42", id.ToString()); + Assert.Equal("42", JsonSerializer.Serialize(id)); + Assert.Equal(42, Assert.IsType(id.Id)); + + Assert.True(id.Equals(new(42))); + Assert.False(id.Equals(new(43))); + Assert.False(id.Equals(new("42"))); + Assert.Equal(42L.GetHashCode(), id.GetHashCode()); + + Assert.Equal(id, JsonSerializer.Deserialize(JsonSerializer.Serialize(id))); + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 97f6b28b6..e4cf89a0e 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -684,7 +684,7 @@ await transport.SendMessageAsync(new JsonRpcNotification var notification = await notificationReceived.Task; var progress = (ProgressNotification)notification.Params!; - Assert.Equal("\"abc\"", progress.ProgressToken.ToString()); + Assert.Equal("abc", progress.ProgressToken.ToString()); Assert.Equal(50, progress.Progress.Progress); Assert.Equal(100, progress.Progress.Total); Assert.Equal("Progress message", progress.Progress.Message); diff --git a/tests/ModelContextProtocol.Tests/TestAttributes.cs b/tests/ModelContextProtocol.Tests/TestAttributes.cs index 4edbce6ec..8a0140db8 100644 --- a/tests/ModelContextProtocol.Tests/TestAttributes.cs +++ b/tests/ModelContextProtocol.Tests/TestAttributes.cs @@ -1,2 +1,9 @@ -// Uncomment to disable parallel test execution -//[assembly: CollectionBehavior(DisableTestParallelization = true)] \ No newline at end of file +// Uncomment to disable parallel test execution for the whole assembly +//[assembly: CollectionBehavior(DisableTestParallelization = true)] + +/// +/// Enables test classes to individually be attributed as [Collection(nameof(DisableParallelization))] +/// to have those tests run non-concurrently with any other tests. +/// +[CollectionDefinition(nameof(DisableParallelization), DisableParallelization = true)] +public sealed class DisableParallelization; \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index 23061cd9c..8fe8e91c1 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -213,7 +213,7 @@ public async Task ReceiveMessagesAsync_Handles_Messages() Assert.True(session.MessageReader.TryRead(out var message)); Assert.NotNull(message); Assert.IsType(message); - Assert.Equal("\"44\"", ((JsonRpcRequest)message).Id.ToString()); + Assert.Equal("44", ((JsonRpcRequest)message).Id.ToString()); } [Fact]