diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 08ad81d4c..598b4adff 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -10,9 +10,10 @@ namespace ModelContextProtocol.TestSseServer; public class Program { - private static ILoggerFactory CreateLoggerFactory() + private static ILoggerFactory CreateLoggerFactory() => LoggerFactory.Create(ConfigureSerilog); + + public static void ConfigureSerilog(ILoggingBuilder loggingBuilder) { - // Use serilog Log.Logger = new LoggerConfiguration() .MinimumLevel.Verbose() // Capture all log levels .WriteTo.File(Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "logs", "TestServer_.log"), @@ -21,15 +22,12 @@ private static ILoggerFactory CreateLoggerFactory() .CreateLogger(); var logsPath = Path.Combine(AppContext.BaseDirectory, "testserver.log"); - return LoggerFactory.Create(builder => - { - builder.AddSerilog(); - }); + loggingBuilder.AddSerilog(); } public static Task Main(string[] args) => MainAsync(args); - public static async Task MainAsync(string[] args, CancellationToken cancellationToken = default) + public static async Task MainAsync(string[] args, ILoggerFactory? loggerFactory = null, CancellationToken cancellationToken = default) { Console.WriteLine("Starting server..."); @@ -385,7 +383,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st }, }; - using var loggerFactory = CreateLoggerFactory(); + loggerFactory ??= CreateLoggerFactory(); server = McpServerFactory.Create(new HttpListenerSseServerTransport("TestServer", 3001, loggerFactory), options, loggerFactory); Console.WriteLine("Server initialized."); diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs index 1e054ef75..b7de29d55 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs @@ -5,9 +5,10 @@ namespace ModelContextProtocol.Tests; -public class ClientIntegrationTestFixture : IDisposable +public class ClientIntegrationTestFixture { - public ILoggerFactory LoggerFactory { get; } + private ILoggerFactory? _loggerFactory; + public McpClientOptions DefaultOptions { get; } public McpServerConfig EverythingServerConfig { get; } public McpServerConfig TestServerConfig { get; } @@ -16,10 +17,6 @@ public class ClientIntegrationTestFixture : IDisposable public ClientIntegrationTestFixture() { - LoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => - builder.AddConsole() - .SetMinimumLevel(LogLevel.Debug)); - DefaultOptions = new() { ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" }, @@ -56,17 +53,16 @@ public ClientIntegrationTestFixture() } } + public void Initialize(ILoggerFactory loggerFactory) + { + _loggerFactory = loggerFactory; + } + public Task CreateClientAsync(string clientId, McpClientOptions? clientOptions = null) => McpClientFactory.CreateAsync(clientId switch { "everything" => EverythingServerConfig, "test_server" => TestServerConfig, _ => throw new ArgumentException($"Unknown client ID: {clientId}") - }, clientOptions ?? DefaultOptions, loggerFactory: LoggerFactory); - - public void Dispose() - { - LoggerFactory?.Dispose(); - GC.SuppressFinalize(this); - } + }, clientOptions ?? DefaultOptions, loggerFactory: _loggerFactory); } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index f1996ca32..b210fb03e 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -6,6 +6,7 @@ using System.Text.Json; using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Tests.Utils; using Xunit.Sdk; using System.Text.Encodings.Web; using System.Text.Json.Serialization.Metadata; @@ -13,15 +14,17 @@ namespace ModelContextProtocol.Tests; -public class ClientIntegrationTests : IClassFixture +public class ClientIntegrationTests : LoggedTest, IClassFixture { private static readonly string? s_openAIKey = Environment.GetEnvironmentVariable("AI:OpenAI:ApiKey")!; private readonly ClientIntegrationTestFixture _fixture; - public ClientIntegrationTests(ClientIntegrationTestFixture fixture) + public ClientIntegrationTests(ClientIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) + : base(testOutputHelper) { _fixture = fixture; + _fixture.Initialize(LoggerFactory); } public static IEnumerable GetClients() => @@ -474,7 +477,7 @@ public async Task CallTool_Stdio_MemoryServer() await using var client = await McpClientFactory.CreateAsync( serverConfig, clientOptions, - loggerFactory: _fixture.LoggerFactory, + loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); // act diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs index 345c59cac..8ae723ca6 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs @@ -1,18 +1,18 @@ using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; -using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Tests.Utils; using Moq; namespace ModelContextProtocol.Tests.Server; -public class McpServerFactoryTests +public class McpServerFactoryTests : LoggedTest { private readonly Mock _serverTransport; private readonly McpServerOptions _options; - private readonly IServiceProvider _serviceProvider; - public McpServerFactoryTests() + public McpServerFactoryTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) { _serverTransport = new Mock(); _options = new McpServerOptions @@ -21,14 +21,13 @@ public McpServerFactoryTests() ProtocolVersion = "1.0", InitializationTimeout = TimeSpan.FromSeconds(30) }; - _serviceProvider = new Mock().Object; } [Fact] public async Task Create_Should_Initialize_With_Valid_Parameters() { // Arrange & Act - await using IMcpServer server = McpServerFactory.Create(_serverTransport.Object, _options, NullLoggerFactory.Instance); + await using IMcpServer server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory); // Assert Assert.NotNull(server); @@ -38,13 +37,13 @@ public async Task Create_Should_Initialize_With_Valid_Parameters() public void Constructor_Throws_For_Null_ServerTransport() { // Arrange, Act & Assert - Assert.Throws("serverTransport", () => McpServerFactory.Create(null!, _options, NullLoggerFactory.Instance)); + Assert.Throws("serverTransport", () => McpServerFactory.Create(null!, _options, LoggerFactory)); } [Fact] public void Constructor_Throws_For_Null_Options() { // Arrange, Act & Assert - Assert.Throws("serverOptions", () => McpServerFactory.Create(_serverTransport.Object, null!, NullLoggerFactory.Instance)); + Assert.Throws("serverOptions", () => McpServerFactory.Create(_serverTransport.Object, null!, LoggerFactory)); } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index a5524c8aa..cddb99264 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -11,20 +11,18 @@ namespace ModelContextProtocol.Tests.Server; -public class McpServerTests +public class McpServerTests : LoggedTest { private readonly Mock _serverTransport; - private readonly Mock _loggerFactory; private readonly Mock _logger; private readonly McpServerOptions _options; private readonly IServiceProvider _serviceProvider; - public McpServerTests() + public McpServerTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) { _serverTransport = new Mock(); - _loggerFactory = new Mock(); _logger = new Mock(); - _loggerFactory.Setup(f => f.CreateLogger(It.IsAny())).Returns(_logger.Object); _options = CreateOptions(); _serviceProvider = new Mock().Object; } @@ -44,7 +42,7 @@ private static McpServerOptions CreateOptions(ServerCapabilities? capabilities = public async Task Constructor_Should_Initialize_With_Valid_Parameters() { // Arrange & Act - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); // Assert Assert.NotNull(server); @@ -54,14 +52,14 @@ public async Task Constructor_Should_Initialize_With_Valid_Parameters() public void Constructor_Throws_For_Null_Transport() { // Arrange, Act & Assert - Assert.Throws(() => McpServerFactory.Create(null!, _options, _loggerFactory.Object, _serviceProvider)); + Assert.Throws(() => McpServerFactory.Create(null!, _options, LoggerFactory, _serviceProvider)); } [Fact] public void Constructor_Throws_For_Null_Options() { // Arrange, Act & Assert - Assert.Throws(() => McpServerFactory.Create(_serverTransport.Object, null!, _loggerFactory.Object, _serviceProvider)); + Assert.Throws(() => McpServerFactory.Create(_serverTransport.Object, null!, LoggerFactory, _serviceProvider)); } [Fact] @@ -78,7 +76,7 @@ public async Task Constructor_Does_Not_Throw_For_Null_Logger() public async Task Constructor_Does_Not_Throw_For_Null_ServiceProvider() { // Arrange & Act - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, null); + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, null); // Assert Assert.NotNull(server); @@ -88,7 +86,7 @@ public async Task Constructor_Does_Not_Throw_For_Null_ServiceProvider() public async Task StartAsync_Should_Throw_InvalidOperationException_If_Already_Initializing() { // Arrange - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); var task = server.StartAsync(TestContext.Current.CancellationToken); // Act & Assert @@ -101,7 +99,7 @@ public async Task StartAsync_Should_Throw_InvalidOperationException_If_Already_I public async Task StartAsync_Should_Do_Nothing_If_Already_Initialized() { // Arrange - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); SetInitialized(server, true); await server.StartAsync(TestContext.Current.CancellationToken); @@ -114,7 +112,7 @@ public async Task StartAsync_Should_Do_Nothing_If_Already_Initialized() public async Task StartAsync_ShouldStartListening() { // Arrange - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); // Act await server.StartAsync(TestContext.Current.CancellationToken); @@ -127,7 +125,7 @@ public async Task StartAsync_ShouldStartListening() public async Task StartAsync_Sets_Initialized_After_Transport_Responses_Initialized_Notification() { await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, _loggerFactory.Object, _serviceProvider); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, _serviceProvider); await server.StartAsync(TestContext.Current.CancellationToken); @@ -147,7 +145,7 @@ await transport.SendMessageAsync(new JsonRpcNotification public async Task RequestSamplingAsync_Should_Throw_McpServerException_If_Client_Does_Not_Support_Sampling() { // Arrange - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); SetClientCapabilities(server, new ClientCapabilities()); var action = () => server.RequestSamplingAsync(new CreateMessageRequestParams { Messages = [] }, CancellationToken.None); @@ -161,7 +159,7 @@ public async Task RequestSamplingAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, _loggerFactory.Object, _serviceProvider); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, _serviceProvider); SetClientCapabilities(server, new ClientCapabilities { Sampling = new SamplingCapability() }); await server.StartAsync(TestContext.Current.CancellationToken); @@ -179,7 +177,7 @@ public async Task RequestSamplingAsync_Should_SendRequest() public async Task RequestRootsAsync_Should_Throw_McpServerException_If_Client_Does_Not_Support_Roots() { // Arrange - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); SetClientCapabilities(server, new ClientCapabilities()); // Act & Assert @@ -191,7 +189,7 @@ public async Task RequestRootsAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, _loggerFactory.Object, _serviceProvider); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, _serviceProvider); SetClientCapabilities(server, new ClientCapabilities { Roots = new RootsCapability() }); await server.StartAsync(TestContext.Current.CancellationToken); @@ -208,7 +206,7 @@ public async Task RequestRootsAsync_Should_SendRequest() [Fact] public async Task Throws_Exception_If_Not_Connected() { - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, _loggerFactory.Object, _serviceProvider); + await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); SetClientCapabilities(server, new ClientCapabilities { Roots = new RootsCapability() }); _serverTransport.SetupGet(t => t.IsConnected).Returns(false); @@ -555,7 +553,7 @@ private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, s var options = CreateOptions(serverCapabilities); configureOptions?.Invoke(options); - await using var server = McpServerFactory.Create(transport, options, _loggerFactory.Object, _serviceProvider); + await using var server = McpServerFactory.Create(transport, options, LoggerFactory, _serviceProvider); await server.StartAsync(); @@ -587,7 +585,7 @@ private async Task Throws_Exception_If_No_Handler_Assigned(ServerCapabilities se await using var transport = new TestServerTransport(); var options = CreateOptions(serverCapabilities); - Assert.Throws(() => McpServerFactory.Create(transport, options, _loggerFactory.Object, _serviceProvider)); + Assert.Throws(() => McpServerFactory.Create(transport, options, LoggerFactory, _serviceProvider)); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs index e9a361c18..2702ed83c 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs @@ -2,23 +2,29 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Test.Utils; +using ModelContextProtocol.TestSseServer; namespace ModelContextProtocol.Tests; public class SseServerIntegrationTestFixture : IAsyncDisposable { - private readonly CancellationTokenSource _stopCts = new(); private readonly Task _serverTask; + private readonly CancellationTokenSource _stopCts = new(); + + private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper = new(); + private readonly ILoggerFactory _redirectingLoggerFactory; - public ILoggerFactory LoggerFactory { get; } public McpClientOptions DefaultOptions { get; } public McpServerConfig DefaultConfig { get; } public SseServerIntegrationTestFixture() { - LoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => - builder.AddConsole() - .SetMinimumLevel(LogLevel.Debug)); + _redirectingLoggerFactory = LoggerFactory.Create(builder => + { + Program.ConfigureSerilog(builder); + builder.AddProvider(new XunitLoggerProvider(_delegatingTestOutputHelper)); + }); DefaultOptions = new() { @@ -34,12 +40,17 @@ public SseServerIntegrationTestFixture() Location = "http://localhost:3001/sse" }; - _serverTask = TestSseServer.Program.MainAsync([], _stopCts.Token); + _serverTask = Program.MainAsync([], _redirectingLoggerFactory, _stopCts.Token); + } + + public void Initialize(ITestOutputHelper output) + { + _delegatingTestOutputHelper.CurrentTestOutputHelper = output; } public async ValueTask DisposeAsync() { - LoggerFactory.Dispose(); + _delegatingTestOutputHelper.CurrentTestOutputHelper = null; _stopCts.Cancel(); try { @@ -48,6 +59,19 @@ public async ValueTask DisposeAsync() catch (OperationCanceledException) { } + _redirectingLoggerFactory.Dispose(); _stopCts.Dispose(); } + + private class DelegatingTestOutputHelper() : ITestOutputHelper + { + public ITestOutputHelper? CurrentTestOutputHelper { get; set; } + + public string Output => CurrentTestOutputHelper?.Output ?? string.Empty; + + public void Write(string message) => CurrentTestOutputHelper?.Write(message); + public void Write(string format, params object[] args) => CurrentTestOutputHelper?.Write(format, args); + public void WriteLine(string message) => CurrentTestOutputHelper?.WriteLine(message); + public void WriteLine(string format, params object[] args) => CurrentTestOutputHelper?.WriteLine(format, args); + } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs index 97f93489d..7142ee055 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs @@ -1,15 +1,18 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Tests.Utils; namespace ModelContextProtocol.Tests; -public class SseServerIntegrationTests : IClassFixture +public class SseServerIntegrationTests : LoggedTest, IClassFixture { private readonly SseServerIntegrationTestFixture _fixture; - public SseServerIntegrationTests(SseServerIntegrationTestFixture fixture) + public SseServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) + : base(testOutputHelper) { _fixture = fixture; + _fixture.Initialize(testOutputHelper); } private Task GetClientAsync(McpClientOptions? options = null) @@ -17,7 +20,7 @@ private Task GetClientAsync(McpClientOptions? options = null) return McpClientFactory.CreateAsync( _fixture.DefaultConfig, options ?? _fixture.DefaultOptions, - loggerFactory: _fixture.LoggerFactory); + loggerFactory: LoggerFactory); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index b91090355..761c55e7a 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -1,5 +1,4 @@ -using Microsoft.Extensions.Logging.Abstractions; -using ModelContextProtocol.Configuration; +using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Tests.Utils; @@ -8,12 +7,13 @@ namespace ModelContextProtocol.Tests.Transport; -public class SseClientTransportTests +public class SseClientTransportTests : LoggedTest { private readonly McpServerConfig _serverConfig; private readonly SseClientTransportOptions _transportOptions; - public SseClientTransportTests() + public SseClientTransportTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) { _serverConfig = new McpServerConfig { @@ -39,7 +39,7 @@ public SseClientTransportTests() public async Task Constructor_Should_Initialize_With_Valid_Parameters() { // Act - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, NullLoggerFactory.Instance); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, LoggerFactory); // Assert Assert.NotNull(transport); @@ -58,21 +58,21 @@ public async Task Constructor_Should_Initialize_With_Valid_Parameters() [Fact] public void Constructor_Throws_For_Null_Options() { - var exception = Assert.Throws(() => new SseClientTransport(null!, _serverConfig, NullLoggerFactory.Instance)); + var exception = Assert.Throws(() => new SseClientTransport(null!, _serverConfig, LoggerFactory)); Assert.Equal("transportOptions", exception.ParamName); } [Fact] public void Constructor_Throws_For_Null_Config() { - var exception = Assert.Throws(() => new SseClientTransport(_transportOptions, null!, NullLoggerFactory.Instance)); + var exception = Assert.Throws(() => new SseClientTransport(_transportOptions, null!, LoggerFactory)); Assert.Equal("serverConfig", exception.ParamName); } [Fact] public void Constructor_Throws_For_Null_HttpClientg() { - var exception = Assert.Throws(() => new SseClientTransport(_transportOptions, _serverConfig, null!, NullLoggerFactory.Instance)); + var exception = Assert.Throws(() => new SseClientTransport(_transportOptions, _serverConfig, null!, LoggerFactory)); Assert.Equal("httpClient", exception.ParamName); } @@ -81,7 +81,7 @@ public async Task ConnectAsync_Should_Connect_Successfully() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, NullLoggerFactory.Instance); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); bool firstCall = true; @@ -109,7 +109,7 @@ public async Task ConnectAsync_Throws_If_Already_Connected() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, NullLoggerFactory.Instance); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); var tcsConnected = new TaskCompletionSource(); var tcsDone = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var callIndex = 0; @@ -157,7 +157,7 @@ public async Task ConnectAsync_Throws_Exception_On_Failure() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, NullLoggerFactory.Instance); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); var retries = 0; mockHttpHandler.RequestHandler = (request) => @@ -177,7 +177,7 @@ public async Task ConnectAsync_Throws_Exception_On_Failure() [Fact] public async Task SendMessageAsync_Throws_Exception_If_MessageEndpoint_Not_Set() { - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, NullLoggerFactory.Instance); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, LoggerFactory); // Assert await Assert.ThrowsAsync(() => transport.SendMessageAsync(new JsonRpcRequest() { Method = "test" }, CancellationToken.None)); @@ -188,7 +188,7 @@ public async Task SendMessageAsync_Handles_Accepted_Response() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, NullLoggerFactory.Instance); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); var firstCall = true; mockHttpHandler.RequestHandler = (request) => @@ -227,7 +227,7 @@ public async Task SendMessageAsync_Handles_Accepted_Json_RPC_Response() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, NullLoggerFactory.Instance); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); var firstCall = true; mockHttpHandler.RequestHandler = (request) => @@ -266,7 +266,7 @@ public async Task ReceiveMessagesAsync_Handles_Messages() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, NullLoggerFactory.Instance); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); var callIndex = 0; mockHttpHandler.RequestHandler = (request) => @@ -303,7 +303,7 @@ public async Task ReceiveMessagesAsync_Handles_Messages() [Fact] public async Task CloseAsync_Should_Dispose_Resources() { - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, NullLoggerFactory.Instance); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, LoggerFactory); await transport.CloseAsync(); @@ -313,7 +313,7 @@ public async Task CloseAsync_Should_Dispose_Resources() [Fact] public async Task DisposeAsync_Should_Dispose_Resources() { - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, NullLoggerFactory.Instance); + await using var transport = new SseClientTransport(_transportOptions, _serverConfig, LoggerFactory); await transport.DisposeAsync(); diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index 9a6cd64b2..94ae62728 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -1,8 +1,8 @@ -using Microsoft.Extensions.Logging.Abstractions; -using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; using ModelContextProtocol.Utils.Json; using System.IO.Pipelines; using System.Text; @@ -10,11 +10,12 @@ namespace ModelContextProtocol.Tests.Transport; -public class StdioServerTransportTests +public class StdioServerTransportTests : LoggedTest { private readonly McpServerOptions _serverOptions; - public StdioServerTransportTests() + public StdioServerTransportTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) { _serverOptions = new McpServerOptions { @@ -68,7 +69,7 @@ public async Task SendMessageAsync_Should_Send_Message() _serverOptions.ServerInfo.Name, new Pipe().Reader.AsStream(), output, - NullLoggerFactory.Instance); + LoggerFactory); await transport.StartListeningAsync(TestContext.Current.CancellationToken); @@ -122,7 +123,7 @@ public async Task ReadMessagesAsync_Should_Read_Messages() _serverOptions.ServerInfo.Name, input, Stream.Null, - NullLoggerFactory.Instance); + LoggerFactory); await transport.StartListeningAsync(TestContext.Current.CancellationToken); @@ -165,7 +166,7 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() _serverOptions.ServerInfo.Name, new Pipe().Reader.AsStream(), output, - NullLoggerFactory.Instance); + LoggerFactory); await transport.StartListeningAsync(TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs b/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs new file mode 100644 index 000000000..3cf83ac42 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Utils/LoggedTest.cs @@ -0,0 +1,18 @@ +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Test.Utils; + +namespace ModelContextProtocol.Tests.Utils; + +public class LoggedTest(ITestOutputHelper testOutputHelper) +{ + public ITestOutputHelper TestOutputHelper { get; } = testOutputHelper; + public ILoggerFactory LoggerFactory { get; } = CreateLoggerFactory(testOutputHelper); + + private static ILoggerFactory CreateLoggerFactory(ITestOutputHelper testOutputHelper) + { + return Microsoft.Extensions.Logging.LoggerFactory.Create(builder => + { + builder.AddProvider(new XunitLoggerProvider(testOutputHelper)); + }); + } +} diff --git a/tests/ModelContextProtocol.Tests/Utils/XunitLoggerProvider.cs b/tests/ModelContextProtocol.Tests/Utils/XunitLoggerProvider.cs new file mode 100644 index 000000000..c76d2649a --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Utils/XunitLoggerProvider.cs @@ -0,0 +1,52 @@ +using System.Globalization; +using System.Text; +using Microsoft.Extensions.Logging; + +namespace ModelContextProtocol.Test.Utils; + +public class XunitLoggerProvider(ITestOutputHelper output) : ILoggerProvider +{ + public ILogger CreateLogger(string categoryName) + { + return new XunitLogger(output, categoryName); + } + + public void Dispose() + { + } + + private class XunitLogger(ITestOutputHelper output, string category) : ILogger + { + public void Log( + LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) + { + var sb = new StringBuilder(); + + var timestamp = DateTimeOffset.UtcNow.ToString("s", CultureInfo.InvariantCulture); + var prefix = $"| [{timestamp}] {category} {logLevel}: "; + var lines = formatter(state, exception); + sb.Append(prefix); + sb.Append(lines); + + if (exception is not null) + { + sb.AppendLine(); + sb.Append(exception.ToString()); + } + + output.WriteLine(sb.ToString()); + } + + public bool IsEnabled(LogLevel logLevel) => true; + + public IDisposable BeginScope(TState state) where TState : notnull + => new NoopDisposable(); + + private sealed class NoopDisposable : IDisposable + { + public void Dispose() + { + } + } + } +}