diff --git a/Directory.Packages.props b/Directory.Packages.props index c03c69e57..3bb00dbef 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -1,36 +1,59 @@ true - 9.0.3 10.0.0-preview.2.25163.2 - 9.0.3 9.3.0-preview.1.25161.3 + + + + + + + + + + + + + + + + + + + + + + + + + + - - - runtime; build; native; contentfiles; analyzers; buildtransitive - all - - - - - + - - + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + - - - + + + + + diff --git a/samples/QuickstartClient/QuickstartClient.csproj b/samples/QuickstartClient/QuickstartClient.csproj index b820bedc1..076e28b16 100644 --- a/samples/QuickstartClient/QuickstartClient.csproj +++ b/samples/QuickstartClient/QuickstartClient.csproj @@ -15,6 +15,7 @@ + diff --git a/src/Common/Polyfills/System/Threading/Channels/ChannelExtensions.cs b/src/Common/Polyfills/System/Threading/Channels/ChannelExtensions.cs new file mode 100644 index 000000000..89822eff1 --- /dev/null +++ b/src/Common/Polyfills/System/Threading/Channels/ChannelExtensions.cs @@ -0,0 +1,17 @@ +using System.Runtime.CompilerServices; + +namespace System.Threading.Channels; + +internal static class ChannelExtensions +{ + public static async IAsyncEnumerable ReadAllAsync(this ChannelReader reader, [EnumeratorCancellation] CancellationToken cancellationToken) + { + while (await reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + { + while (reader.TryRead(out var item)) + { + yield return item; + } + } + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj b/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj index 9028365c7..1bc4feb01 100644 --- a/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj +++ b/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj @@ -1,7 +1,7 @@  - net8.0 + net9.0;net8.0 enable enable true diff --git a/src/ModelContextProtocol/Client/McpClientExtensions.cs b/src/ModelContextProtocol/Client/McpClientExtensions.cs index 63af1cf64..c98f76191 100644 --- a/src/ModelContextProtocol/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol/Client/McpClientExtensions.cs @@ -24,7 +24,7 @@ public static Task PingAsync(this IMcpClient client, CancellationToken cancellat return client.SendRequestAsync( RequestMethods.Ping, parameters: null, - McpJsonUtilities.JsonContext.Default.Object, + McpJsonUtilities.JsonContext.Default.Object!, McpJsonUtilities.JsonContext.Default.Object, cancellationToken: cancellationToken); } @@ -52,7 +52,7 @@ public static async Task> ListToolsAsync( { var toolResults = await client.SendRequestAsync( RequestMethods.ToolsList, - CreateCursorDictionary(cursor), + CreateCursorDictionary(cursor)!, McpJsonUtilities.JsonContext.Default.DictionaryStringObject, McpJsonUtilities.JsonContext.Default.ListToolsResult, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -96,7 +96,7 @@ public static async IAsyncEnumerable EnumerateToolsAsync( { var toolResults = await client.SendRequestAsync( RequestMethods.ToolsList, - CreateCursorDictionary(cursor), + CreateCursorDictionary(cursor)!, McpJsonUtilities.JsonContext.Default.DictionaryStringObject, McpJsonUtilities.JsonContext.Default.ListToolsResult, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -128,7 +128,7 @@ public static async Task> ListPromptsAsync( { var promptResults = await client.SendRequestAsync( RequestMethods.PromptsList, - CreateCursorDictionary(cursor), + CreateCursorDictionary(cursor)!, McpJsonUtilities.JsonContext.Default.DictionaryStringObject, McpJsonUtilities.JsonContext.Default.ListPromptsResult, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -166,7 +166,7 @@ public static async IAsyncEnumerable EnumeratePromptsAsync( { var promptResults = await client.SendRequestAsync( RequestMethods.PromptsList, - CreateCursorDictionary(cursor), + CreateCursorDictionary(cursor)!, McpJsonUtilities.JsonContext.Default.DictionaryStringObject, McpJsonUtilities.JsonContext.Default.ListPromptsResult, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -230,7 +230,7 @@ public static async Task> ListResourceTemplatesAsync( { var templateResults = await client.SendRequestAsync( RequestMethods.ResourcesTemplatesList, - CreateCursorDictionary(cursor), + CreateCursorDictionary(cursor)!, McpJsonUtilities.JsonContext.Default.DictionaryStringObject, McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -271,7 +271,7 @@ public static async IAsyncEnumerable EnumerateResourceTemplate { var templateResults = await client.SendRequestAsync( RequestMethods.ResourcesTemplatesList, - CreateCursorDictionary(cursor), + CreateCursorDictionary(cursor)!, McpJsonUtilities.JsonContext.Default.DictionaryStringObject, McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -304,7 +304,7 @@ public static async Task> ListResourcesAsync( { var resourceResults = await client.SendRequestAsync( RequestMethods.ResourcesList, - CreateCursorDictionary(cursor), + CreateCursorDictionary(cursor)!, McpJsonUtilities.JsonContext.Default.DictionaryStringObject, McpJsonUtilities.JsonContext.Default.ListResourcesResult, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -345,7 +345,7 @@ public static async IAsyncEnumerable EnumerateResourcesAsync( { var resourceResults = await client.SendRequestAsync( RequestMethods.ResourcesList, - CreateCursorDictionary(cursor), + CreateCursorDictionary(cursor)!, McpJsonUtilities.JsonContext.Default.DictionaryStringObject, McpJsonUtilities.JsonContext.Default.ListResourcesResult, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -374,7 +374,7 @@ public static Task ReadResourceAsync( return client.SendRequestAsync( RequestMethods.ResourcesRead, - new Dictionary { ["uri"] = uri }, + new Dictionary { ["uri"] = uri }, McpJsonUtilities.JsonContext.Default.DictionaryStringObject, McpJsonUtilities.JsonContext.Default.ReadResourceResult, cancellationToken: cancellationToken); @@ -401,7 +401,7 @@ public static Task GetCompletionAsync(this IMcpClient client, Re return client.SendRequestAsync( RequestMethods.CompletionComplete, - new Dictionary + new Dictionary { ["ref"] = reference, ["argument"] = new Argument { Name = argumentName, Value = argumentValue } @@ -424,7 +424,7 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, return client.SendRequestAsync( RequestMethods.ResourcesSubscribe, - new Dictionary { ["uri"] = uri }, + new Dictionary { ["uri"] = uri }, McpJsonUtilities.JsonContext.Default.DictionaryStringObject, McpJsonUtilities.JsonContext.Default.EmptyResult, cancellationToken: cancellationToken); @@ -443,7 +443,7 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string u return client.SendRequestAsync( RequestMethods.ResourcesUnsubscribe, - new Dictionary { ["uri"] = uri }, + new Dictionary { ["uri"] = uri }, McpJsonUtilities.JsonContext.Default.DictionaryStringObject, McpJsonUtilities.JsonContext.Default.EmptyResult, cancellationToken: cancellationToken); @@ -629,7 +629,7 @@ public static Task SetLoggingLevel(this IMcpClient client, LoggingLevel level, C return client.SendRequestAsync( RequestMethods.LoggingSetLevel, - new Dictionary { ["level"] = level }, + new Dictionary { ["level"] = level }, McpJsonUtilities.JsonContext.Default.DictionaryStringObject, McpJsonUtilities.JsonContext.Default.EmptyResult, cancellationToken: cancellationToken); diff --git a/src/ModelContextProtocol/Diagnostics.cs b/src/ModelContextProtocol/Diagnostics.cs index c8845228d..5b4e31f4d 100644 --- a/src/ModelContextProtocol/Diagnostics.cs +++ b/src/ModelContextProtocol/Diagnostics.cs @@ -7,10 +7,20 @@ internal static class Diagnostics { internal static ActivitySource ActivitySource { get; } = new("Experimental.ModelContextProtocol"); + internal static Meter Meter { get; } = new("Experimental.ModelContextProtocol"); + + internal static Histogram CreateDurationHistogram(string name, string description, bool longBuckets) => + Diagnostics.Meter.CreateHistogram(name, "s", description +#if NET9_0_OR_GREATER + , advice: longBuckets ? LongSecondsBucketBoundaries : ShortSecondsBucketBoundaries +#endif + ); + +#if NET9_0_OR_GREATER /// /// Follows boundaries from http.server.request.duration/http.client.request.duration /// - internal static InstrumentAdvice ShortSecondsBucketBoundaries { get; } = new() + private static InstrumentAdvice ShortSecondsBucketBoundaries { get; } = new() { HistogramBucketBoundaries = [0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1, 2.5, 5, 7.5, 10], }; @@ -19,11 +29,9 @@ internal static class Diagnostics /// Not based on a standard. Larger bucket sizes for longer lasting operations, e.g. HTTP connection duration. /// See https://github.com/open-telemetry/semantic-conventions/issues/336 /// - internal static InstrumentAdvice LongSecondsBucketBoundaries { get; } = new() + private static InstrumentAdvice LongSecondsBucketBoundaries { get; } = new() { HistogramBucketBoundaries = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 5, 10, 30, 60, 120, 300], }; - - internal static Meter Meter { get; } = new("Experimental.ModelContextProtocol"); - +#endif } diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj index 6860381de..c120269bd 100644 --- a/src/ModelContextProtocol/ModelContextProtocol.csproj +++ b/src/ModelContextProtocol/ModelContextProtocol.csproj @@ -1,7 +1,7 @@  - net8.0;netstandard2.0 + net9.0;net8.0;netstandard2.0 true true ModelContextProtocol @@ -12,15 +12,8 @@ true - - - - - - - - + @@ -28,6 +21,24 @@ + + + + + + + + + + + + + + + diff --git a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs index aed3b6dfb..d4e39c8a4 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs @@ -53,7 +53,7 @@ private void WriteJsonRpcMessageToBuffer(SseItem item, IBuffer return; } - JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage); + JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage!); } /// diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs index 8359afa65..589e9078b 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs @@ -29,9 +29,16 @@ public StreamClientSessionTransport( _serverInput = serverInput; EndpointName = endpointName; - // Start reading messages in the background + // Start reading messages in the background. We use the rarer pattern of new Task + Start + // in order to ensure that the body of the task will always see _readTask initialized. + // It is then able to reliably null it out on completion. Logger.TransportReadingMessages(endpointName); - _readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None); + var readTask = new Task( + thisRef => ((StreamClientSessionTransport)thisRef!).ReadMessagesAsync(_shutdownCts.Token), + this, + TaskCreationOptions.DenyChildAttach); + _readTask = readTask.Unwrap(); + readTask.Start(); SetConnected(true); } @@ -117,6 +124,7 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken) } finally { + _readTask = null; await CleanupAsync(cancellationToken).ConfigureAwait(false); } } diff --git a/src/ModelContextProtocol/Protocol/Types/ContextInclusion.cs b/src/ModelContextProtocol/Protocol/Types/ContextInclusion.cs index db45dd48f..bb3bae905 100644 --- a/src/ModelContextProtocol/Protocol/Types/ContextInclusion.cs +++ b/src/ModelContextProtocol/Protocol/Types/ContextInclusion.cs @@ -6,7 +6,7 @@ namespace ModelContextProtocol.Protocol.Types; /// A request to include context from one or more MCP servers (including the caller), to be attached to the prompt. /// See the schema for details /// -[JsonConverter(typeof(JsonStringEnumConverter))] +[JsonConverter(typeof(CustomizableJsonStringEnumConverter))] public enum ContextInclusion { /// diff --git a/src/ModelContextProtocol/Protocol/Types/LoggingLevel.cs b/src/ModelContextProtocol/Protocol/Types/LoggingLevel.cs index df8c4c75a..8098dbbd3 100644 --- a/src/ModelContextProtocol/Protocol/Types/LoggingLevel.cs +++ b/src/ModelContextProtocol/Protocol/Types/LoggingLevel.cs @@ -7,7 +7,7 @@ namespace ModelContextProtocol.Protocol.Types; /// These map to syslog message severities, as specified in RFC-5424: /// https://datatracker.ietf.org/doc/html/rfc5424#section-6.2.1 /// -[JsonConverter(typeof(JsonStringEnumConverter))] +[JsonConverter(typeof(CustomizableJsonStringEnumConverter))] public enum LoggingLevel { /// Detailed debug information, typically only valuable to developers. diff --git a/src/ModelContextProtocol/Protocol/Types/Role.cs b/src/ModelContextProtocol/Protocol/Types/Role.cs index 1cb35ea5b..c025f61ad 100644 --- a/src/ModelContextProtocol/Protocol/Types/Role.cs +++ b/src/ModelContextProtocol/Protocol/Types/Role.cs @@ -6,7 +6,7 @@ namespace ModelContextProtocol.Protocol.Types; /// Represents the type of role in the conversation. /// See the schema for details /// -[JsonConverter(typeof(JsonStringEnumConverter))] +[JsonConverter(typeof(CustomizableJsonStringEnumConverter))] public enum Role { /// diff --git a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs index 9c6744ea8..25bffe5ed 100644 --- a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs +++ b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs @@ -133,9 +133,9 @@ public static AIFunction Create(MethodInfo method, object? target, TemporaryAIFu /// /// /// Return values are serialized to using 's - /// . Arguments that are not already of the expected type are + /// . Arguments that are not already of the expected type are /// marshaled to the expected type via JSON and using 's - /// . If the argument is a , + /// . If the argument is a , /// , or , it is deserialized directly. If the argument is anything else unknown, /// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. /// diff --git a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactoryOptions.cs b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactoryOptions.cs index e1f712d1f..91403b3b9 100644 --- a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactoryOptions.cs +++ b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactoryOptions.cs @@ -72,7 +72,7 @@ public TemporaryAIFunctionFactoryOptions() /// Gets or sets a delegate used to determine the returned by . /// /// - /// By default, the return value of invoking the method wrapped into an by + /// By default, the return value of invoking the method wrapped into an by /// is then JSON serialized, with the resulting returned from the method. /// This default behavior is ideal for the common case where the result will be passed back to an AI service. However, if the caller /// requires more control over the result's marshaling, the property may be set to a delegate that is @@ -82,7 +82,7 @@ public TemporaryAIFunctionFactoryOptions() /// /// When set, the delegate is invoked even for -returning methods, in which case it is invoked with /// a argument. By default, is returned from the - /// method for instances produced by to wrap + /// method for instances produced by to wrap /// -returning methods). /// /// diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs index e9aed2f32..d5e4f930d 100644 --- a/src/ModelContextProtocol/Shared/McpSession.cs +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -12,6 +12,7 @@ using System.Diagnostics.Metrics; using System.Text.Json; using System.Text.Json.Nodes; +using System.Threading.Channels; namespace ModelContextProtocol.Shared; @@ -20,15 +21,14 @@ namespace ModelContextProtocol.Shared; /// internal sealed class McpSession : IDisposable { - private static readonly Histogram s_clientSessionDuration = Diagnostics.Meter.CreateHistogram( - "mcp.client.session.duration", "s", "Measures the duration of a client session.", advice: Diagnostics.LongSecondsBucketBoundaries); - private static readonly Histogram s_serverSessionDuration = Diagnostics.Meter.CreateHistogram( - "mcp.server.session.duration", "s", "Measures the duration of a server session.", advice: Diagnostics.LongSecondsBucketBoundaries); - - private static readonly Histogram s_serverRequestDuration = Diagnostics.Meter.CreateHistogram( - "rpc.server.duration", "s", "Measures the duration of inbound RPC.", advice: Diagnostics.ShortSecondsBucketBoundaries); - private static readonly Histogram s_clientRequestDuration = Diagnostics.Meter.CreateHistogram( - "rpc.client.duration", "s", "Measures the duration of outbound RPC.", advice: Diagnostics.ShortSecondsBucketBoundaries); + private static readonly Histogram s_clientSessionDuration = Diagnostics.CreateDurationHistogram( + "mcp.client.session.duration", "Measures the duration of a client session.", longBuckets: true); + private static readonly Histogram s_serverSessionDuration = Diagnostics.CreateDurationHistogram( + "mcp.server.session.duration", "Measures the duration of a server session.", longBuckets: true); + private static readonly Histogram s_clientRequestDuration = Diagnostics.CreateDurationHistogram( + "rpc.client.duration", "Measures the duration of outbound RPC.", longBuckets: false); + private static readonly Histogram s_serverRequestDuration = Diagnostics.CreateDurationHistogram( + "rpc.server.duration", "Measures the duration of inbound RPC.", longBuckets: false); private readonly bool _isServer; private readonly string _transportKind; @@ -174,6 +174,14 @@ await _transport.SendMessageAsync(new JsonRpcError // Normal shutdown _logger.EndpointMessageProcessingCancelled(EndpointName); } + finally + { + // Fail any pending requests, as they'll never be satisfied. + foreach (var entry in _pendingRequests) + { + entry.Value.TrySetException(new InvalidOperationException("The server shut down unexpectedly.")); + } + } } private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken) diff --git a/src/ModelContextProtocol/Utils/Json/CustomizableJsonStringEnumConverter.cs b/src/ModelContextProtocol/Utils/Json/CustomizableJsonStringEnumConverter.cs new file mode 100644 index 000000000..e9c26f18c --- /dev/null +++ b/src/ModelContextProtocol/Utils/Json/CustomizableJsonStringEnumConverter.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Reflection; + +// NOTE: +// This is a temporary workaround for lack of System.Text.Json's JsonStringEnumConverter +// 9.x support for JsonStringEnumMemberNameAttribute. Once all builds use the System.Text.Json 9.x +// version, this whole file can be removed. + +namespace System.Text.Json.Serialization; + +internal sealed class CustomizableJsonStringEnumConverter<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum> : + JsonStringEnumConverter where TEnum : struct, Enum +{ +#if !NET9_0_OR_GREATER + public CustomizableJsonStringEnumConverter() : + base(namingPolicy: ResolveNamingPolicy()) + { + } + + private static JsonNamingPolicy? ResolveNamingPolicy() + { + var map = typeof(TEnum).GetFields(BindingFlags.Public | BindingFlags.Static) + .Select(f => (f.Name, AttributeName: f.GetCustomAttribute()?.Name)) + .Where(pair => pair.AttributeName != null) + .ToDictionary(pair => pair.Name, pair => pair.AttributeName); + + return map.Count > 0 ? new EnumMemberNamingPolicy(map!) : null; + } + + private sealed class EnumMemberNamingPolicy(Dictionary map) : JsonNamingPolicy + { + public override string ConvertName(string name) => + map.TryGetValue(name, out string? newName) ? + newName : + name; + } +#endif +} + +#if !NET9_0_OR_GREATER +/// +/// Determines the string value that should be used when serializing an enum member. +/// +[AttributeUsage(AttributeTargets.Field, AllowMultiple = false)] +internal sealed class JsonStringEnumMemberNameAttribute : Attribute +{ + /// + /// Creates new attribute instance with a specified enum member name. + /// + /// The name to apply to the current enum member. + public JsonStringEnumMemberNameAttribute(string name) + { + Name = name; + } + + /// + /// Gets the name of the enum member. + /// + public string Name { get; } +} +#endif \ No newline at end of file diff --git a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs index 5af25a458..e6245c7fa 100644 --- a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs +++ b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs @@ -78,7 +78,6 @@ internal static bool IsValidMcpToolSchema(JsonElement element) // Keep in sync with CreateDefaultOptions above. [JsonSourceGenerationOptions(JsonSerializerDefaults.Web, - UseStringEnumConverter = true, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, NumberHandling = JsonNumberHandling.AllowReadingFromString)] diff --git a/tests/ModelContextProtocol.TestServer/ModelContextProtocol.TestServer.csproj b/tests/ModelContextProtocol.TestServer/ModelContextProtocol.TestServer.csproj index 67dc6a197..fb6320a07 100644 --- a/tests/ModelContextProtocol.TestServer/ModelContextProtocol.TestServer.csproj +++ b/tests/ModelContextProtocol.TestServer/ModelContextProtocol.TestServer.csproj @@ -2,7 +2,7 @@ Exe - net8.0 + net9.0;net8.0 enable enable TestServer @@ -10,8 +10,11 @@ + + + diff --git a/tests/ModelContextProtocol.TestSseServer/ModelContextProtocol.TestSseServer.csproj b/tests/ModelContextProtocol.TestSseServer/ModelContextProtocol.TestSseServer.csproj index 6633ad4ad..3015fd554 100644 --- a/tests/ModelContextProtocol.TestSseServer/ModelContextProtocol.TestSseServer.csproj +++ b/tests/ModelContextProtocol.TestSseServer/ModelContextProtocol.TestSseServer.csproj @@ -2,7 +2,7 @@ Exe - net8.0 + net9.0;net8.0 enable enable TestSseServer @@ -10,6 +10,7 @@ + diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 424b51da6..d5a24c997 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -377,10 +377,12 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide { Console.WriteLine("Starting server..."); + int port = args.Length > 0 && uint.TryParse(args[0], out var parsedPort) ? (int)parsedPort : 3001; + var builder = WebApplication.CreateSlimBuilder(args); builder.WebHost.ConfigureKestrel(options => { - options.ListenLocalhost(3001); + options.ListenLocalhost(port); }); ConfigureSerilog(builder.Logging); diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index d131aca54..a6d0a9b61 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -542,21 +542,11 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() [MemberData(nameof(GetClients))] public async Task SetLoggingLevel_ReceivesLoggingMessages(string clientId) { - // arrange - JsonSerializerOptions jsonSerializerOptions = new(JsonSerializerDefaults.Web) - { - TypeInfoResolver = new DefaultJsonTypeInfoResolver(), - Converters = { new JsonStringEnumConverter() }, - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, - NumberHandling = JsonNumberHandling.AllowReadingFromString, - Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping, - }; - TaskCompletionSource receivedNotification = new(); await using var client = await _fixture.CreateClientAsync(clientId); client.AddNotificationHandler(NotificationMethods.LoggingMessageNotification, (notification) => { - var loggingMessageNotificationParameters = JsonSerializer.Deserialize(notification.Params, jsonSerializerOptions); + var loggingMessageNotificationParameters = JsonSerializer.Deserialize(notification.Params); if (loggingMessageNotificationParameters is not null) { receivedNotification.TrySetResult(true); diff --git a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj index 5f839d25d..c8dc1ca90 100644 --- a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj +++ b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj @@ -1,7 +1,7 @@  - net8.0 + net9.0;net8.0 enable enable Latest @@ -20,6 +20,7 @@ runtime; build; native; contentfiles; analyzers; buildtransitive all + diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index f6ecee8f9..0622e656e 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -8,11 +8,25 @@ namespace ModelContextProtocol.Tests; public class SseIntegrationTests(ITestOutputHelper outputHelper) : LoggedTest(outputHelper) { + /// Port number to be grabbed by the next test. + private static int s_nextPort = 3000; + + // If the tests run concurrently against different versions of the runtime, tests can conflict with + // each other in the ports set up for interacting with containers. Ensure that such suites running + // against different TFMs use different port numbers. + private static readonly int s_portOffset = 1000 * (Environment.Version.Major switch + { + int v when v >= 8 => Environment.Version.Major - 7, + _ => 0, + }); + + private static int CreatePortNumber() => Interlocked.Increment(ref s_nextPort) + s_portOffset; + [Fact] public async Task ConnectAndReceiveMessage_InMemoryServer() { // Arrange - await using InMemoryTestSseServer server = new(logger: LoggerFactory.CreateLogger()); + await using InMemoryTestSseServer server = new(CreatePortNumber(), LoggerFactory.CreateLogger()); await server.StartAsync(); var defaultOptions = new McpClientOptions @@ -26,7 +40,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer() Name = "In-memory Test Server", TransportType = TransportTypes.Sse, TransportOptions = [], - Location = "http://localhost:5000/sse" + Location = $"http://localhost:{server.Port}/sse" }; // Act @@ -52,7 +66,7 @@ public async Task ConnectAndReceiveMessage_EverythingServerWithSse() { Assert.SkipWhen(!EverythingSseServerFixture.IsDockerAvailable, "docker is not available"); - int port = 3001; + int port = CreatePortNumber(); await using var fixture = new EverythingSseServerFixture(port); await fixture.StartAsync(); @@ -89,7 +103,7 @@ public async Task Sampling_Sse_EverythingServer() { Assert.SkipWhen(!EverythingSseServerFixture.IsDockerAvailable, "docker is not available"); - int port = 3002; + int port = CreatePortNumber(); await using var fixture = new EverythingSseServerFixture(port); await fixture.StartAsync(); @@ -157,11 +171,10 @@ public async Task Sampling_Sse_EverythingServer() public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventUri() { // Arrange - await using InMemoryTestSseServer server = new(logger: LoggerFactory.CreateLogger()); + await using InMemoryTestSseServer server = new(CreatePortNumber(), LoggerFactory.CreateLogger()); server.UseFullUrlForEndpointEvent = true; await server.StartAsync(); - var defaultOptions = new McpClientOptions { ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } @@ -173,7 +186,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU Name = "In-memory Test Server", TransportType = TransportTypes.Sse, TransportOptions = [], - Location = "http://localhost:5000/sse" + Location = $"http://localhost:{server.Port}/sse" }; // Act @@ -197,7 +210,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU public async Task ConnectAndReceiveNotification_InMemoryServer() { // Arrange - await using InMemoryTestSseServer server = new(logger: LoggerFactory.CreateLogger()); + await using InMemoryTestSseServer server = new(CreatePortNumber(), LoggerFactory.CreateLogger()); await server.StartAsync(); @@ -212,7 +225,7 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() Name = "In-memory Test Server", TransportType = TransportTypes.Sse, TransportOptions = [], - Location = "http://localhost:5000/sse" + Location = $"http://localhost:{server.Port}/sse" }; // Act diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs index 893ac793b..238c7747c 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs @@ -17,16 +17,19 @@ public class SseServerIntegrationTestFixture : IAsyncDisposable public SseServerIntegrationTestFixture() { + // Ensure that test suites running against different TFMs and possibly concurrently use different port numbers. + int port = 3001 + Environment.Version.Major; + DefaultConfig = new McpServerConfig { Id = "test_server", Name = "TestServer", TransportType = TransportTypes.Sse, TransportOptions = [], - Location = "http://localhost:3001/sse" + Location = $"http://localhost:{port}/sse" }; - _serverTask = Program.MainAsync([], new XunitLoggerProvider(_delegatingTestOutputHelper), _stopCts.Token); + _serverTask = Program.MainAsync([port.ToString()], new XunitLoggerProvider(_delegatingTestOutputHelper), _stopCts.Token); } public static McpClientOptions CreateDefaultClientOptions() => new() diff --git a/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs b/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs index 651d85e2c..7d7122a8e 100644 --- a/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs +++ b/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs @@ -26,6 +26,8 @@ public sealed class InMemoryTestSseServer : IAsyncDisposable public InMemoryTestSseServer(int port = 5000, ILogger? logger = null) { + Port = port; + _listener = new HttpListener(); _listener.Prefixes.Add($"http://localhost:{port}/"); _cts = new CancellationTokenSource(); @@ -35,6 +37,8 @@ public InMemoryTestSseServer(int port = 5000, ILogger? lo _messagePath = "/message"; } + public int Port { get; } + /// /// This is to be able to use the full URL for the endpoint event. ///