diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 15dc5c98b..a5524c8aa 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -89,10 +89,12 @@ public async Task StartAsync_Should_Throw_InvalidOperationException_If_Already_I { // Arrange await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); - server.GetType().GetField("_isInitializing", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)?.SetValue(server, true); + var task = server.StartAsync(TestContext.Current.CancellationToken); // Act & Assert await Assert.ThrowsAsync(() => server.StartAsync(TestContext.Current.CancellationToken)); + + await task; } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index 53d41cc1f..b91090355 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -110,48 +110,44 @@ public async Task ConnectAsync_Throws_If_Already_Connected() using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, NullLoggerFactory.Instance); - using var mreConnected = new ManualResetEventSlim(false); - using var mreDone = new ManualResetEventSlim(false); + var tcsConnected = new TaskCompletionSource(); + var tcsDone = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var callIndex = 0; - mockHttpHandler.RequestHandler = (request) => + mockHttpHandler.RequestHandler = async (request) => { switch (callIndex++) { case 0: - return Task.FromResult(new HttpResponseMessage + return new HttpResponseMessage { StatusCode = HttpStatusCode.OK, Content = new StringContent("event: endpoint\r\ndata: http://localhost\r\n\r\n") - }); + }; case 1: - mreConnected.Set(); - mreDone.Wait(); - return Task.FromResult(new HttpResponseMessage + tcsConnected.SetResult(); + await tcsDone.Task; + return new HttpResponseMessage { StatusCode = HttpStatusCode.OK, Content = new StringContent("") - }); + }; default: - return Task.FromResult(new HttpResponseMessage + return new HttpResponseMessage { StatusCode = HttpStatusCode.OK, Content = new StringContent("") - }); + }; } }; - var task = Task.Run(async () => - { - await transport.ConnectAsync(TestContext.Current.CancellationToken); - }, TestContext.Current.CancellationToken); - - mreConnected.Wait(TestContext.Current.CancellationToken); + var task = transport.ConnectAsync(TestContext.Current.CancellationToken); + await tcsConnected.Task; Assert.True(transport.IsConnected); var action = async () => await transport.ConnectAsync(); var exception = await Assert.ThrowsAsync(action); Assert.Equal("Transport is already connected", exception.Message); - mreDone.Set(); + tcsDone.SetResult(); await transport.CloseAsync(); await task; }