| 
1 |  | -using ModelContextProtocol.Configuration;  | 
 | 1 | +using Microsoft.Extensions.Logging;  | 
 | 2 | +using ModelContextProtocol.Configuration;  | 
2 | 3 | using ModelContextProtocol.Logging;  | 
3 | 4 | using ModelContextProtocol.Protocol.Messages;  | 
4 | 5 | using ModelContextProtocol.Protocol.Transport;  | 
5 | 6 | using ModelContextProtocol.Protocol.Types;  | 
6 | 7 | using ModelContextProtocol.Shared;  | 
7 | 8 | using ModelContextProtocol.Utils.Json;  | 
8 |  | -using Microsoft.Extensions.Logging;  | 
9 | 9 | using System.Text.Json;  | 
10 | 10 | 
 
  | 
11 | 11 | namespace ModelContextProtocol.Client;  | 
12 | 12 | 
 
  | 
13 | 13 | /// <inheritdoc/>  | 
14 | 14 | internal sealed class McpClient : McpJsonRpcEndpoint, IMcpClient  | 
15 | 15 | {  | 
16 |  | -    private readonly McpClientOptions _options;  | 
17 | 16 |     private readonly IClientTransport _clientTransport;  | 
 | 17 | +    private readonly McpClientOptions _options;  | 
18 | 18 | 
 
  | 
19 |  | -    private volatile bool _isInitializing;  | 
 | 19 | +    private ITransport? _sessionTransport;  | 
 | 20 | +    private CancellationTokenSource? _connectCts;  | 
20 | 21 | 
 
  | 
21 | 22 |     /// <summary>  | 
22 | 23 |     /// Initializes a new instance of the <see cref="McpClient"/> class.  | 
23 | 24 |     /// </summary>  | 
24 |  | -    /// <param name="transport">The transport to use for communication with the server.</param>  | 
 | 25 | +    /// <param name="clientTransport">The transport to use for communication with the server.</param>  | 
25 | 26 |     /// <param name="options">Options for the client, defining protocol version and capabilities.</param>  | 
26 | 27 |     /// <param name="serverConfig">The server configuration.</param>  | 
27 | 28 |     /// <param name="loggerFactory">The logger factory.</param>  | 
28 |  | -    public McpClient(IClientTransport transport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)  | 
29 |  | -        : base(transport, loggerFactory)  | 
 | 29 | +    public McpClient(IClientTransport clientTransport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)  | 
 | 30 | +        : base(loggerFactory)  | 
30 | 31 |     {  | 
 | 32 | +        _clientTransport = clientTransport;  | 
31 | 33 |         _options = options;  | 
32 |  | -        _clientTransport = transport;  | 
33 | 34 | 
 
  | 
34 | 35 |         EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})";  | 
35 | 36 | 
 
  | 
@@ -70,95 +71,95 @@ public McpClient(IClientTransport transport, McpClientOptions options, McpServer  | 
70 | 71 |     /// <inheritdoc/>  | 
71 | 72 |     public override string EndpointName { get; }  | 
72 | 73 | 
 
  | 
73 |  | -    /// <inheritdoc/>  | 
74 | 74 |     public async Task ConnectAsync(CancellationToken cancellationToken = default)  | 
75 | 75 |     {  | 
76 |  | -        if (IsInitialized)  | 
77 |  | -        {  | 
78 |  | -            _logger.ClientAlreadyInitialized(EndpointName);  | 
79 |  | -            return;  | 
80 |  | -        }  | 
81 |  | - | 
82 |  | -        if (_isInitializing)  | 
83 |  | -        {  | 
84 |  | -            _logger.ClientAlreadyInitializing(EndpointName);  | 
85 |  | -            throw new InvalidOperationException("Client is already initializing");  | 
86 |  | -        }  | 
 | 76 | +        _connectCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);  | 
 | 77 | +        cancellationToken = _connectCts.Token;  | 
87 | 78 | 
 
  | 
88 |  | -        _isInitializing = true;  | 
89 | 79 |         try  | 
90 | 80 |         {  | 
91 |  | -            CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);  | 
92 |  | - | 
93 | 81 |             // Connect transport  | 
94 |  | -            await _clientTransport.ConnectAsync(CancellationTokenSource.Token).ConfigureAwait(false);  | 
95 |  | - | 
96 |  | -            // Start processing messages  | 
97 |  | -            MessageProcessingTask = ProcessMessagesAsync(CancellationTokenSource.Token);  | 
 | 82 | +            _sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);  | 
 | 83 | +            InitializeSession(_sessionTransport);  | 
 | 84 | +            // We don't want the ConnectAsync token to cancel the session after we've successfully connected.  | 
 | 85 | +            // The base class handles cleaning up the session in DisposeAsync without our help.  | 
 | 86 | +            StartSession(fullSessionCancellationToken: CancellationToken.None);  | 
98 | 87 | 
 
  | 
99 | 88 |             // Perform initialization sequence  | 
100 |  | -            await InitializeAsync(CancellationTokenSource.Token).ConfigureAwait(false);  | 
 | 89 | +            using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);  | 
 | 90 | +            initializationCts.CancelAfter(_options.InitializationTimeout);  | 
101 | 91 | 
 
  | 
102 |  | -            IsInitialized = true;  | 
 | 92 | +            try  | 
 | 93 | +            {  | 
 | 94 | +                // Send initialize request  | 
 | 95 | +                var initializeResponse = await SendRequestAsync<InitializeResult>(  | 
 | 96 | +                    new JsonRpcRequest  | 
 | 97 | +                    {  | 
 | 98 | +                        Method = "initialize",  | 
 | 99 | +                        Params = new InitializeRequestParams()  | 
 | 100 | +                        {  | 
 | 101 | +                            ProtocolVersion = _options.ProtocolVersion,  | 
 | 102 | +                            Capabilities = _options.Capabilities ?? new ClientCapabilities(),  | 
 | 103 | +                            ClientInfo = _options.ClientInfo,  | 
 | 104 | +                        }  | 
 | 105 | +                    },  | 
 | 106 | +                    initializationCts.Token).ConfigureAwait(false);  | 
 | 107 | + | 
 | 108 | +                // Store server information  | 
 | 109 | +                _logger.ServerCapabilitiesReceived(EndpointName,  | 
 | 110 | +                    capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities),  | 
 | 111 | +                    serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation));  | 
 | 112 | + | 
 | 113 | +                ServerCapabilities = initializeResponse.Capabilities;  | 
 | 114 | +                ServerInfo = initializeResponse.ServerInfo;  | 
 | 115 | +                ServerInstructions = initializeResponse.Instructions;  | 
 | 116 | + | 
 | 117 | +                // Validate protocol version  | 
 | 118 | +                if (initializeResponse.ProtocolVersion != _options.ProtocolVersion)  | 
 | 119 | +                {  | 
 | 120 | +                    _logger.ServerProtocolVersionMismatch(EndpointName, _options.ProtocolVersion, initializeResponse.ProtocolVersion);  | 
 | 121 | +                    throw new McpClientException($"Server protocol version mismatch. Expected {_options.ProtocolVersion}, got {initializeResponse.ProtocolVersion}");  | 
 | 122 | +                }  | 
 | 123 | + | 
 | 124 | +                // Send initialized notification  | 
 | 125 | +                await SendMessageAsync(  | 
 | 126 | +                    new JsonRpcNotification { Method = "notifications/initialized" },  | 
 | 127 | +                    initializationCts.Token).ConfigureAwait(false);  | 
 | 128 | +            }  | 
 | 129 | +            catch (OperationCanceledException) when (initializationCts.IsCancellationRequested)  | 
 | 130 | +            {  | 
 | 131 | +                _logger.ClientInitializationTimeout(EndpointName);  | 
 | 132 | +                throw new McpClientException("Initialization timed out");  | 
 | 133 | +            }  | 
103 | 134 |         }  | 
104 | 135 |         catch (Exception e)  | 
105 | 136 |         {  | 
106 | 137 |             _logger.ClientInitializationError(EndpointName, e);  | 
107 |  | -            await CleanupAsync().ConfigureAwait(false);  | 
 | 138 | +            await DisposeAsync().ConfigureAwait(false);  | 
108 | 139 |             throw;  | 
109 | 140 |         }  | 
110 |  | -        finally  | 
111 |  | -        {  | 
112 |  | -            _isInitializing = false;  | 
113 |  | -        }  | 
114 | 141 |     }  | 
115 | 142 | 
 
  | 
116 |  | -    private async Task InitializeAsync(CancellationToken cancellationToken)  | 
 | 143 | +    /// <inheritdoc/>  | 
 | 144 | +    public override async ValueTask DisposeUnsynchronizedAsync()  | 
117 | 145 |     {  | 
118 |  | -        using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);  | 
119 |  | -        initializationCts.CancelAfter(_options.InitializationTimeout);  | 
 | 146 | +        if (_connectCts is not null)  | 
 | 147 | +        {  | 
 | 148 | +            await _connectCts.CancelAsync().ConfigureAwait(false);  | 
 | 149 | +        }  | 
120 | 150 | 
 
  | 
121 | 151 |         try  | 
122 | 152 |         {  | 
123 |  | -            // Send initialize request  | 
124 |  | -            var initializeResponse = await SendRequestAsync<InitializeResult>(  | 
125 |  | -                new JsonRpcRequest  | 
126 |  | -                {  | 
127 |  | -                    Method = "initialize",  | 
128 |  | -                    Params = new InitializeRequestParams()  | 
129 |  | -                    {  | 
130 |  | -                        ProtocolVersion = _options.ProtocolVersion,  | 
131 |  | -                        Capabilities = _options.Capabilities ?? new ClientCapabilities(),  | 
132 |  | -                        ClientInfo = _options.ClientInfo  | 
133 |  | -                    }  | 
134 |  | -                },  | 
135 |  | -                initializationCts.Token).ConfigureAwait(false);  | 
136 |  | - | 
137 |  | -            // Store server information  | 
138 |  | -            _logger.ServerCapabilitiesReceived(EndpointName,   | 
139 |  | -                capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities),  | 
140 |  | -                serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation));  | 
141 |  | - | 
142 |  | -            ServerCapabilities = initializeResponse.Capabilities;  | 
143 |  | -            ServerInfo = initializeResponse.ServerInfo;  | 
144 |  | -            ServerInstructions = initializeResponse.Instructions;  | 
145 |  | - | 
146 |  | -            // Validate protocol version  | 
147 |  | -            if (initializeResponse.ProtocolVersion != _options.ProtocolVersion)  | 
 | 153 | +            await base.DisposeUnsynchronizedAsync().ConfigureAwait(false);  | 
 | 154 | +        }  | 
 | 155 | +        finally  | 
 | 156 | +        {  | 
 | 157 | +            if (_sessionTransport is not null)  | 
148 | 158 |             {  | 
149 |  | -                _logger.ServerProtocolVersionMismatch(EndpointName, _options.ProtocolVersion, initializeResponse.ProtocolVersion);  | 
150 |  | -                throw new McpClientException($"Server protocol version mismatch. Expected {_options.ProtocolVersion}, got {initializeResponse.ProtocolVersion}");  | 
 | 159 | +                await _sessionTransport.DisposeAsync().ConfigureAwait(false);  | 
151 | 160 |             }  | 
152 | 161 | 
 
  | 
153 |  | -            // Send initialized notification  | 
154 |  | -            await SendMessageAsync(  | 
155 |  | -                new JsonRpcNotification { Method = "notifications/initialized" },  | 
156 |  | -                initializationCts.Token).ConfigureAwait(false);  | 
157 |  | -        }  | 
158 |  | -        catch (OperationCanceledException) when (initializationCts.IsCancellationRequested)  | 
159 |  | -        {  | 
160 |  | -            _logger.ClientInitializationTimeout(EndpointName);  | 
161 |  | -            throw new McpClientException("Initialization timed out");  | 
 | 162 | +            _connectCts?.Dispose();  | 
162 | 163 |         }  | 
163 | 164 |     }  | 
164 | 165 | }  | 
0 commit comments