diff --git a/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs b/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs index 5a5d10ffb..e00ddbabe 100644 --- a/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs @@ -146,6 +146,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken = stderrRollingLog.Enqueue(data); } + _options.StandardErrorLines?.Invoke(data); + LogReadStderr(logger, endpointName, data); } }; diff --git a/src/ModelContextProtocol.Core/Client/StdioClientTransportOptions.cs b/src/ModelContextProtocol.Core/Client/StdioClientTransportOptions.cs index 650602246..5930e9d2a 100644 --- a/src/ModelContextProtocol.Core/Client/StdioClientTransportOptions.cs +++ b/src/ModelContextProtocol.Core/Client/StdioClientTransportOptions.cs @@ -69,4 +69,9 @@ public required string Command /// /// public TimeSpan ShutdownTimeout { get; set; } = TimeSpan.FromSeconds(5); + + /// + /// Gets or sets a callback that is invoked for each line of stderr received from the server process. + /// + public Action? StandardErrorLines { get; set; } } diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs index b8d8d714b..40602a9ed 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs @@ -1,6 +1,7 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Tests.Utils; using System.Runtime.InteropServices; +using System.Text; namespace ModelContextProtocol.Tests.Transport; @@ -18,4 +19,31 @@ public async Task CreateAsync_ValidProcessInvalidServer_Throws() IOException e = await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); Assert.Contains(id, e.ToString()); } + + [Fact] + public async Task CreateAsync_ValidProcessInvalidServer_StdErrCallbackInvoked() + { + string id = Guid.NewGuid().ToString("N"); + + int count = 0; + StringBuilder sb = new(); + Action stdErrCallback = line => + { + Assert.NotNull(line); + lock (sb) + { + sb.AppendLine(line); + count++; + } + }; + + StdioClientTransport transport = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? + new(new() { Command = "cmd", Arguments = ["/C", $"echo \"{id}\" >&2"], StandardErrorLines = stdErrCallback }, LoggerFactory) : + new(new() { Command = "ls", Arguments = [id], StandardErrorLines = stdErrCallback }, LoggerFactory); + + await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + + Assert.InRange(count, 1, int.MaxValue); + Assert.Contains(id, sb.ToString()); + } }