diff --git a/Directory.Packages.props b/Directory.Packages.props index 0874fa268..67d3e5274 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -35,7 +35,7 @@ - - + + \ No newline at end of file diff --git a/ModelContextProtocol.sln b/ModelContextProtocol.sln index 8bd8f6273..4b9250c15 100644 --- a/ModelContextProtocol.sln +++ b/ModelContextProtocol.sln @@ -31,9 +31,11 @@ EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{B6FB2B28-D5DE-4654-BE9A-45E305DE4852}" ProjectSection(SolutionItems) = preProject Directory.Build.props = Directory.Build.props + Directory.Packages.props = Directory.Packages.props global.json = global.json LICENSE = LICENSE logo.png = logo.png + nuget.config = nuget.config README.MD = README.MD version.json = version.json EndProjectSection diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index 889b53bcd..ab28ef3af 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -1,5 +1,4 @@ using System.Text; -using System.Text.Json; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs index bf24770b7..ce4c038a9 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs @@ -17,34 +17,28 @@ public class McpClientFactoryTests [Fact] public async Task CreateAsync_WithInvalidArgs_Throws() { - await Assert.ThrowsAsync("serverConfig", () => McpClientFactory.CreateAsync((McpServerConfig)null!, _defaultOptions)); + await Assert.ThrowsAsync("serverConfig", () => McpClientFactory.CreateAsync((McpServerConfig)null!, _defaultOptions, cancellationToken: TestContext.Current.CancellationToken)); - await Assert.ThrowsAsync("clientOptions", () => McpClientFactory.CreateAsync( - new McpServerConfig() + await Assert.ThrowsAsync("clientOptions", () => McpClientFactory.CreateAsync(new McpServerConfig() { Name = "name", Id = "id", TransportType = TransportTypes.StdIo, - }, (McpClientOptions)null!)); + }, (McpClientOptions)null!, cancellationToken: TestContext.Current.CancellationToken)); - await Assert.ThrowsAsync("serverConfig", () => McpClientFactory.CreateAsync( - new McpServerConfig() + await Assert.ThrowsAsync("serverConfig", () => McpClientFactory.CreateAsync(new McpServerConfig() { Name = "name", Id = "id", TransportType = "somethingunsupported", - }, - _defaultOptions)); + }, _defaultOptions, cancellationToken: TestContext.Current.CancellationToken)); - await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync( - new McpServerConfig() + await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(new McpServerConfig() { Name = "name", Id = "id", TransportType = TransportTypes.StdIo, - }, - _defaultOptions, - (_, __) => null!)); + }, _defaultOptions, (_, __) => null!, cancellationToken: TestContext.Current.CancellationToken)); } [Fact] @@ -68,7 +62,8 @@ public async Task CreateAsync_WithValidStdioConfig_CreatesNewClient() await using var client = await McpClientFactory.CreateAsync( serverConfig, _defaultOptions, - (_, __) => new NopTransport()); + (_, __) => new NopTransport(), + cancellationToken: TestContext.Current.CancellationToken); // Assert Assert.NotNull(client); @@ -91,7 +86,8 @@ public async Task CreateAsync_WithNoTransportOptions_CreatesNewClient() await using var client = await McpClientFactory.CreateAsync( serverConfig, _defaultOptions, - (_, __) => new NopTransport()); + (_, __) => new NopTransport(), + cancellationToken: TestContext.Current.CancellationToken); // Assert Assert.NotNull(client); @@ -114,7 +110,8 @@ public async Task CreateAsync_WithValidSseConfig_CreatesNewClient() await using var client = await McpClientFactory.CreateAsync( serverConfig, _defaultOptions, - (_, __) => new NopTransport()); + (_, __) => new NopTransport(), + cancellationToken: TestContext.Current.CancellationToken); // Assert Assert.NotNull(client); @@ -144,7 +141,8 @@ public async Task CreateAsync_WithSse_CreatesCorrectTransportOptions() await using var client = await McpClientFactory.CreateAsync( serverConfig, _defaultOptions, - (_, __) => new NopTransport()); + (_, __) => new NopTransport(), + cancellationToken: TestContext.Current.CancellationToken); // Assert Assert.NotNull(client); @@ -171,7 +169,7 @@ public async Task McpFactory_WithInvalidTransportOptions_ThrowsFormatException(s }; // act & assert - await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(config, _defaultOptions)); + await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(config, _defaultOptions, cancellationToken: TestContext.Current.CancellationToken)); } private sealed class NopTransport : IClientTransport @@ -191,7 +189,7 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella { switch (message) { - case JsonRpcRequest request: + case JsonRpcRequest: _channel.Writer.TryWrite(new JsonRpcResponse { Id = ((JsonRpcRequest)message).Id, diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 4fda215d2..af13cc795 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -1,11 +1,12 @@ using ModelContextProtocol.Client; -using ModelContextProtocol.Configuration; -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Protocol.Types; using Microsoft.Extensions.AI; using OpenAI; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Protocol.Messages; using System.Text.Json; +using ModelContextProtocol.Configuration; +using ModelContextProtocol.Protocol.Transport; +using Xunit.Sdk; namespace ModelContextProtocol.Tests; @@ -61,8 +62,8 @@ public async Task ListTools_Stdio(string clientId) // act await using var client = await _fixture.CreateClientAsync(clientId); - var tools = await client.ListToolsAsync().ToListAsync(); - var aiFunctions = await client.GetAIFunctionsAsync(); + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken).ToListAsync(TestContext.Current.CancellationToken); + var aiFunctions = await client.GetAIFunctionsAsync(TestContext.Current.CancellationToken); // assert Assert.NotEmpty(tools); @@ -102,9 +103,9 @@ public async Task CallTool_Stdio_ViaAIFunction_EchoServer(string clientId) // act await using var client = await _fixture.CreateClientAsync(clientId); - var aiFunctions = await client.GetAIFunctionsAsync(); + var aiFunctions = await client.GetAIFunctionsAsync(TestContext.Current.CancellationToken); var echo = aiFunctions.Single(t => t.Name == "echo"); - var result = await echo.InvokeAsync([new KeyValuePair("message", "Hello MCP!")]); + var result = await echo.InvokeAsync([new KeyValuePair("message", "Hello MCP!")], TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -119,7 +120,7 @@ public async Task ListPrompts_Stdio(string clientId) // act await using var client = await _fixture.CreateClientAsync(clientId); - var prompts = await client.ListPromptsAsync().ToListAsync(); + var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken).ToListAsync(TestContext.Current.CancellationToken); // assert Assert.NotEmpty(prompts); @@ -251,7 +252,7 @@ public async Task SubscribeResource_Stdio() await client.SubscribeToResourceAsync("test://static/resource/1", CancellationToken.None); // notifications happen every 5 seconds, so we wait for 10 seconds to ensure we get at least one notification - await Task.Delay(10000); + await Task.Delay(10000, TestContext.Current.CancellationToken); // assert Assert.True(counter > 0); @@ -276,17 +277,17 @@ public async Task UnsubscribeResource_Stdio() await client.SubscribeToResourceAsync("test://static/resource/1", CancellationToken.None); // notifications happen every 5 seconds, so we wait for 10 seconds to ensure we get at least one notification - await Task.Delay(10000); + await Task.Delay(10000, TestContext.Current.CancellationToken); // reset counter int counterAfterSubscribe = counter; - + // unsubscribe await client.UnsubscribeFromResourceAsync("test://static/resource/1", CancellationToken.None); counter = 0; // notifications happen every 5 seconds, so we wait for 10 seconds to ensure we would've gotten at least one notification - await Task.Delay(10000); + await Task.Delay(10000, TestContext.Current.CancellationToken); // assert Assert.True(counterAfterSubscribe > 0); @@ -340,7 +341,7 @@ public async Task GetCompletion_Stdio_PromptReference(string clientId) [Theory] [MemberData(nameof(GetClients))] public async Task Sampling_Stdio(string clientId) - { + { // Set up the sampling handler int samplingHandlerCalls = 0; await using var client = await _fixture.CreateClientAsync(clientId, new() @@ -375,8 +376,8 @@ public async Task Sampling_Stdio(string clientId) { ["prompt"] = "Test prompt", ["maxTokens"] = 100 - } - ); + }, + TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -423,8 +424,8 @@ public async Task Notifications_Stdio(string clientId) await using var client = await _fixture.CreateClientAsync(clientId); // Verify we can send notifications without errors - await client.SendNotificationAsync(NotificationMethods.RootsUpdatedNotification); - await client.SendNotificationAsync("test/notification", new { test = true }); + await client.SendNotificationAsync(NotificationMethods.RootsUpdatedNotification, cancellationToken: TestContext.Current.CancellationToken); + await client.SendNotificationAsync("test/notification", new { test = true }, TestContext.Current.CancellationToken); // assert // no response to check, if no exception is thrown, it's a success @@ -452,13 +453,17 @@ public async Task CallTool_Stdio_MemoryServer() ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } }; - await using var client = await McpClientFactory.CreateAsync(serverConfig, clientOptions, loggerFactory: _fixture.LoggerFactory); + await using var client = await McpClientFactory.CreateAsync( + serverConfig, + clientOptions, + loggerFactory: _fixture.LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken); // act var result = await client.CallToolAsync( "read_graph", - [] - ); + [], + TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -471,14 +476,14 @@ public async Task CallTool_Stdio_MemoryServer() [Fact] public async Task GetAIFunctionsAsync_UsingEverythingServer_ToolsAreProperlyCalled() { - if (s_openAIKey is null) - { - return; // Skip the test if the OpenAI key is not provided - } + SkipTestIfNoOpenAIKey(); // Get the MCP client and tools from it. - await using var client = await McpClientFactory.CreateAsync(_fixture.EverythingServerConfig, _fixture.DefaultOptions); ; - var mappedTools = await client.GetAIFunctionsAsync(); + await using var client = await McpClientFactory.CreateAsync( + _fixture.EverythingServerConfig, + _fixture.DefaultOptions, + cancellationToken: TestContext.Current.CancellationToken); + var mappedTools = await client.GetAIFunctionsAsync(TestContext.Current.CancellationToken); // Create the chat client. using IChatClient chatClient = new OpenAIClient(s_openAIKey).AsChatClient("gpt-4o-mini") @@ -495,7 +500,7 @@ public async Task GetAIFunctionsAsync_UsingEverythingServer_ToolsAreProperlyCall messages.Add(new(ChatRole.User, "Please call the echo tool with the string 'Hello MCP!' and output the response ad verbatim.")); // Call the chat client - var response = await chatClient.GetResponseAsync(messages, new() { Tools = [.. mappedTools], Temperature = 0 }); + var response = await chatClient.GetResponseAsync(messages, new() { Tools = [.. mappedTools], Temperature = 0 }, TestContext.Current.CancellationToken); // Assert Assert.Contains("Echo: Hello MCP!", response.Text); @@ -504,10 +509,7 @@ public async Task GetAIFunctionsAsync_UsingEverythingServer_ToolsAreProperlyCall [Fact] public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() { - if (s_openAIKey is null) - { - return; // Skip the test if the OpenAI key is not provided - } + SkipTestIfNoOpenAIKey(); await using var client = await McpClientFactory.CreateAsync(_fixture.EverythingServerConfig, new() { @@ -519,12 +521,12 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() SamplingHandler = new OpenAIClient(s_openAIKey).AsChatClient("gpt-4o-mini").CreateSamplingHandler(), }, }, - }); + }, cancellationToken: TestContext.Current.CancellationToken); var result = await client.CallToolAsync("sampleLLM", new() { ["prompt"] = "In just a few words, what is the most famous tower in Paris?", - }); + }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.NotEmpty(result.Content); @@ -532,4 +534,9 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() Assert.Contains("LLM sampling result:", result.Content[0].Text); Assert.Contains("Eiffel", result.Content[0].Text); } + + private static void SkipTestIfNoOpenAIKey() + { + Assert.SkipWhen(s_openAIKey is null, "No OpenAI key provided. Skipping test."); + } } diff --git a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj index 6d31fac76..530f57c21 100644 --- a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj +++ b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj @@ -18,7 +18,7 @@ - + runtime; build; native; contentfiles; analyzers; buildtransitive all diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 76616de91..e8ca56bc4 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -7,7 +7,6 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Moq; -using System.Diagnostics; namespace ModelContextProtocol.Tests.Server; @@ -100,7 +99,7 @@ public async Task StartAsync_Should_Throw_InvalidOperationException_If_Already_I server.GetType().GetField("_isInitializing", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)?.SetValue(server, true); // Act & Assert - await Assert.ThrowsAsync(() => server.StartAsync()); + await Assert.ThrowsAsync(() => server.StartAsync(TestContext.Current.CancellationToken)); } [Fact] @@ -110,7 +109,7 @@ public async Task StartAsync_Should_Do_Nothing_If_Already_Initialized() await using var server = new McpServer(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); server.IsInitialized = true; - await server.StartAsync(); + await server.StartAsync(TestContext.Current.CancellationToken); // Assert _serverTransport.Verify(t => t.StartListeningAsync(It.IsAny()), Times.Never); @@ -123,7 +122,7 @@ public async Task StartAsync_ShouldStartListening() await using var server = new McpServer(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); // Act - await server.StartAsync(); + await server.StartAsync(TestContext.Current.CancellationToken); // Assert _serverTransport.Verify(t => t.StartListeningAsync(It.IsAny()), Times.Once); @@ -135,17 +134,16 @@ public async Task StartAsync_Sets_Initialized_After_Transport_Responses_Initiali await using var transport = new TestServerTransport(); await using var server = new McpServer(transport, _options, _loggerFactory.Object, _serviceProvider); - await server.StartAsync(); + await server.StartAsync(TestContext.Current.CancellationToken); // Send initialized notification - await transport.SendMessageAsync( - new JsonRpcNotification + await transport.SendMessageAsync(new JsonRpcNotification { Method = "notifications/initialized" } - ); +, TestContext.Current.CancellationToken); - await Task.Delay(50); + await Task.Delay(50, TestContext.Current.CancellationToken); Assert.True(server.IsInitialized); } @@ -171,7 +169,7 @@ public async Task RequestSamplingAsync_Should_SendRequest() await using var server = new McpServer(transport, _options, _loggerFactory.Object, _serviceProvider); server.ClientCapabilities = new ClientCapabilities { Sampling = new SamplingCapability() }; - await server.StartAsync(); + await server.StartAsync(TestContext.Current.CancellationToken); // Act var result = await server.RequestSamplingAsync(new CreateMessageRequestParams { Messages = [] }, CancellationToken.None); @@ -200,7 +198,7 @@ public async Task RequestRootsAsync_Should_SendRequest() await using var transport = new TestServerTransport(); await using var server = new McpServer(transport, _options, _loggerFactory.Object, _serviceProvider); server.ClientCapabilities = new ClientCapabilities { Roots = new RootsCapability() }; - await server.StartAsync(); + await server.StartAsync(TestContext.Current.CancellationToken); // Act var result = await server.RequestRootsAsync(new ListRootsRequestParams(), CancellationToken.None); @@ -588,7 +586,7 @@ public async Task AsSamplingChatClient_HandlesRequestResponse() Temperature = 0.75f, MaxOutputTokens = 42, StopSequences = ["."], - }); + }, TestContext.Current.CancellationToken); Assert.Equal("amazingmodel", response.ModelId); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index 6451e9e37..48f91dcc5 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -37,13 +37,17 @@ public async Task ConnectAndReceiveMessage_InMemoryServer() }; // Act - await using var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); + await using var client = await McpClientFactory.CreateAsync( + defaultConfig, + defaultOptions, + loggerFactory: loggerFactory, + cancellationToken: TestContext.Current.CancellationToken); // Wait for SSE connection to be established await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); // Send a test message through POST endpoint - await client.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }); + await client.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, TestContext.Current.CancellationToken); // Assert Assert.True(true); @@ -53,10 +57,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer() [Trait("Execution", "Manual")] public async Task ConnectAndReceiveMessage_EverythingServerWithSse() { - if (!EverythingSseServerFixture.IsDockerAvailable) - { - return; - } + Assert.SkipWhen(!EverythingSseServerFixture.IsDockerAvailable, "docker is not available"); using var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole() @@ -82,8 +83,12 @@ public async Task ConnectAndReceiveMessage_EverythingServerWithSse() }; // Create client and run tests - await using var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); - var tools = await client.ListToolsAsync().ToListAsync(); + await using var client = await McpClientFactory.CreateAsync( + defaultConfig, + defaultOptions, + loggerFactory: loggerFactory, + cancellationToken: TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken).ToListAsync(TestContext.Current.CancellationToken); // assert Assert.NotEmpty(tools); @@ -93,10 +98,7 @@ public async Task ConnectAndReceiveMessage_EverythingServerWithSse() [Trait("Execution", "Manual")] public async Task Sampling_Sse_EverythingServer() { - if (!EverythingSseServerFixture.IsDockerAvailable) - { - return; - } + Assert.SkipWhen(!EverythingSseServerFixture.IsDockerAvailable, "docker is not available"); // arrange using var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => @@ -147,17 +149,19 @@ public async Task Sampling_Sse_EverythingServer() }, }; - await using var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); + await using var client = await McpClientFactory.CreateAsync( + defaultConfig, + defaultOptions, + loggerFactory: loggerFactory, + cancellationToken: TestContext.Current.CancellationToken); // Call the server's sampleLLM tool which should trigger our sampling handler - var result = await client.CallToolAsync( - "sampleLLM", - new Dictionary + var result = await client.CallToolAsync("sampleLLM", new Dictionary { ["prompt"] = "Test prompt", ["maxTokens"] = 100 } - ); +, TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -194,13 +198,17 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU }; // Act - await using var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); + await using var client = await McpClientFactory.CreateAsync( + defaultConfig, + defaultOptions, + loggerFactory: loggerFactory, + cancellationToken: TestContext.Current.CancellationToken); // Wait for SSE connection to be established await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); // Send a test message through POST endpoint - await client.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }); + await client.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, TestContext.Current.CancellationToken); // Assert Assert.True(true); @@ -233,7 +241,11 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() }; // Act - await using var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); + await using var client = await McpClientFactory.CreateAsync( + defaultConfig, + defaultOptions, + loggerFactory: loggerFactory, + cancellationToken: TestContext.Current.CancellationToken); // Wait for SSE connection to be established await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); @@ -251,7 +263,7 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() await server.SendTestNotificationAsync("Hello from server!"); // Assert - var message = await receivedNotification.Task.WaitAsync(TimeSpan.FromSeconds(10)); + var message = await receivedNotification.Task.WaitAsync(TimeSpan.FromSeconds(10), TestContext.Current.CancellationToken); Assert.Equal("Hello from server!", message); } @@ -282,7 +294,11 @@ public async Task ConnectTwice_Throws() }; // Act - await using var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); + await using var client = await McpClientFactory.CreateAsync( + defaultConfig, + defaultOptions, + loggerFactory: loggerFactory, + cancellationToken: TestContext.Current.CancellationToken); var mcpClient = (McpClient)client; var transport = (SseClientTransport)mcpClient.Transport; @@ -290,6 +306,6 @@ public async Task ConnectTwice_Throws() await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); // Assert - await Assert.ThrowsAsync(async () => await transport.ConnectAsync()); + await Assert.ThrowsAsync(async () => await transport.ConnectAsync(TestContext.Current.CancellationToken)); } } diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs index 186fee2a1..97f93489d 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs @@ -53,7 +53,7 @@ public async Task ListTools_Sse_TestServer() // act var client = await GetClientAsync(); - var tools = await client.ListToolsAsync().ToListAsync(); + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken).ToListAsync(TestContext.Current.CancellationToken); // assert Assert.NotNull(tools); @@ -145,7 +145,7 @@ public async Task ListPrompts_Sse_TestServer() // act var client = await GetClientAsync(); - var prompts = await client.ListPromptsAsync().ToListAsync(); + var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken).ToListAsync(TestContext.Current.CancellationToken); // assert Assert.NotNull(prompts); @@ -233,14 +233,12 @@ public async Task Sampling_Sse_TestServer() #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously // Call the server's sampleLLM tool which should trigger our sampling handler - var result = await client.CallToolAsync( - "sampleLLM", - new Dictionary + var result = await client.CallToolAsync("sampleLLM", new Dictionary { ["prompt"] = "Test prompt", ["maxTokens"] = 100 - } - ); + }, + TestContext.Current.CancellationToken); // assert Assert.NotNull(result); diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index 67fd84cf8..26aaecd07 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -95,7 +95,7 @@ public async Task ConnectAsync_Should_Connect_Successfully() }; }; - await transport.ConnectAsync(); + await transport.ConnectAsync(TestContext.Current.CancellationToken); } [Fact] @@ -174,7 +174,7 @@ public async Task SendMessageAsync_Handles_Accepted_Response() } }; - await transport.ConnectAsync(); + await transport.ConnectAsync(TestContext.Current.CancellationToken); await transport.SendMessageAsync(new JsonRpcRequest() { Method = "initialize", Id = RequestId.FromNumber(44) }, CancellationToken.None); Assert.True(true); @@ -213,7 +213,7 @@ public async Task SendMessageAsync_Handles_Accepted_Json_RPC_Response() } }; - await transport.ConnectAsync(); + await transport.ConnectAsync(TestContext.Current.CancellationToken); await transport.SendMessageAsync(new JsonRpcRequest() { Method = "initialize", Id = RequestId.FromNumber(44) }, CancellationToken.None); Assert.True(true); @@ -251,7 +251,7 @@ public async Task ReceiveMessagesAsync_Handles_Messages() throw new IOException("Abort"); }; - await transport.ConnectAsync(); + await transport.ConnectAsync(TestContext.Current.CancellationToken); Assert.True(transport.MessageReader.TryRead(out var message)); Assert.NotNull(message); Assert.IsType(message); diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index 44f60cd79..184cd8041 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -52,7 +52,7 @@ public async Task StartListeningAsync_Should_Set_Connected_State() { await using var transport = new StdioServerTransport(_serverOptions); - await transport.StartListeningAsync(); + await transport.StartListeningAsync(TestContext.Current.CancellationToken); Assert.True(transport.IsConnected); } @@ -70,12 +70,12 @@ public async Task SendMessageAsync_Should_Send_Message() Console.SetOut(output); await using var transport = new StdioServerTransport(_serverOptions, NullLoggerFactory.Instance); - await transport.StartListeningAsync(); + await transport.StartListeningAsync(TestContext.Current.CancellationToken); var message = new JsonRpcRequest { Method = "test", Id = RequestId.FromNumber(44) }; - await transport.SendMessageAsync(message); + await transport.SendMessageAsync(message, TestContext.Current.CancellationToken); var result = output.ToString()?.Trim(); var expected = JsonSerializer.Serialize(message, JsonSerializerOptionsExtensions.DefaultOptions); @@ -96,7 +96,7 @@ public async Task SendMessageAsync_Throws_Exception_If_Not_Connected() var message = new JsonRpcRequest { Method = "test" }; - await Assert.ThrowsAsync(() => transport.SendMessageAsync(message)); + await Assert.ThrowsAsync(() => transport.SendMessageAsync(message, TestContext.Current.CancellationToken)); } [Fact] @@ -123,9 +123,9 @@ public async Task ReadMessagesAsync_Should_Read_Messages() Console.SetOut(new StringWriter()); await using var transport = new StdioServerTransport(_serverOptions); - await transport.StartListeningAsync(); + await transport.StartListeningAsync(TestContext.Current.CancellationToken); - var canRead = await transport.MessageReader.WaitToReadAsync(); + var canRead = await transport.MessageReader.WaitToReadAsync(TestContext.Current.CancellationToken); Assert.True(canRead, "Nothing to read here from transport message reader"); Assert.True(transport.MessageReader.TryPeek(out var readMessage)); @@ -144,7 +144,7 @@ public async Task ReadMessagesAsync_Should_Read_Messages() public async Task CleanupAsync_Should_Cleanup_Resources() { var transport = new StdioServerTransport(_serverOptions); - await transport.StartListeningAsync(); + await transport.StartListeningAsync(TestContext.Current.CancellationToken); await transport.DisposeAsync();