Skip to content

Commit 739d177

Browse files
authored
[dotnet] Refactor WebSocket communication for BiDi (#12614)
* [dotnet] Refactor WebSocket communication for BiDi Replaces the existing WebSocket communication mechanism with a more robust one. Rather than immediately relying on event handlers to react to events and command responses, it writes the incoming data to a queue which is read from a different thread. This eliminates the issue where the user might have multiple simultaneous sends or receives to the WebSocket while their event handler is running. It also dispatches incoming events on different threads for the same reason. This should eliminate at least some of the issues surrounding socket communication with bidirectional features, whether implemented using CDP or the WebDriver BiDi protocol. * Address review comments * Removing use of System.Threading.Channels * nitpick: fix XML doc comment * Simplify WebSocket message queue processing code * Omit added test from Firefox * revert add of now unused nuget packages
1 parent cbda4dd commit 739d177

File tree

7 files changed

+422
-152
lines changed

7 files changed

+422
-152
lines changed

dotnet/src/webdriver/DevTools/DevToolsSession.cs

Lines changed: 57 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,7 @@
1919
using System;
2020
using System.Collections.Concurrent;
2121
using System.Globalization;
22-
using System.IO;
2322
using System.Net.Http;
24-
using System.Net.WebSockets;
25-
using System.Text;
2623
using System.Threading;
2724
using System.Threading.Tasks;
2825
using Newtonsoft.Json;
@@ -50,15 +47,14 @@ public class DevToolsSession : IDevToolsSession
5047
private bool isDisposed = false;
5148
private string attachedTargetId;
5249

53-
private ClientWebSocket sessionSocket;
50+
private WebSocketConnection connection;
5451
private ConcurrentDictionary<long, DevToolsCommandData> pendingCommands = new ConcurrentDictionary<long, DevToolsCommandData>();
52+
private readonly BlockingCollection<string> messageQueue = new BlockingCollection<string>();
53+
private readonly Task messageQueueMonitorTask;
5554
private long currentCommandId = 0;
5655

5756
private DevToolsDomains domains;
5857

59-
private CancellationTokenSource receiveCancellationToken;
60-
private Task receiveTask;
61-
6258
/// <summary>
6359
/// Initializes a new instance of the DevToolsSession class, using the specified WebSocket endpoint.
6460
/// </summary>
@@ -76,6 +72,8 @@ public DevToolsSession(string endpointAddress)
7672
{
7773
this.websocketAddress = endpointAddress;
7874
}
75+
this.messageQueueMonitorTask = Task.Run(() => this.MonitorMessageQueue());
76+
this.messageQueueMonitorTask.ConfigureAwait(false);
7977
}
8078

8179
/// <summary>
@@ -213,15 +211,13 @@ public T GetVersionSpecificDomains<T>() where T : DevToolsSessionDomains
213211

214212
var message = new DevToolsCommandData(Interlocked.Increment(ref this.currentCommandId), this.ActiveSessionId, commandName, commandParameters);
215213

216-
if (this.sessionSocket != null && this.sessionSocket.State == WebSocketState.Open)
214+
if (this.connection != null && this.connection.IsActive)
217215
{
218216
LogTrace("Sending {0} {1}: {2}", message.CommandId, message.CommandName, commandParameters.ToString());
219217

220-
var contents = JsonConvert.SerializeObject(message);
221-
var contentBuffer = Encoding.UTF8.GetBytes(contents);
222-
218+
string contents = JsonConvert.SerializeObject(message);
223219
this.pendingCommands.TryAdd(message.CommandId, message);
224-
await this.sessionSocket.SendAsync(new ArraySegment<byte>(contentBuffer), WebSocketMessageType.Text, true, cancellationToken);
220+
await this.connection.SendData(contents);
225221

226222
var responseWasReceived = await Task.Run(() => message.SyncEvent.Wait(millisecondsTimeout.Value, cancellationToken));
227223

@@ -230,8 +226,7 @@ public T GetVersionSpecificDomains<T>() where T : DevToolsSessionDomains
230226
throw new InvalidOperationException($"A command response was not received: {commandName}");
231227
}
232228

233-
DevToolsCommandData modified;
234-
if (this.pendingCommands.TryRemove(message.CommandId, out modified))
229+
if (this.pendingCommands.TryRemove(message.CommandId, out DevToolsCommandData modified))
235230
{
236231
if (modified.IsError)
237232
{
@@ -256,10 +251,7 @@ public T GetVersionSpecificDomains<T>() where T : DevToolsSessionDomains
256251
}
257252
else
258253
{
259-
if (this.sessionSocket != null)
260-
{
261-
LogTrace("WebSocket is not connected (current state is {0}); not sending {1}", this.sessionSocket.State, message.CommandName);
262-
}
254+
LogTrace("WebSocket is not connected; not sending {0}", message.CommandName);
263255
}
264256

265257
return null;
@@ -330,11 +322,7 @@ protected void Dispose(bool disposing)
330322
{
331323
this.Domains.Target.TargetDetached -= this.OnTargetDetached;
332324
this.pendingCommands.Clear();
333-
this.TerminateSocketConnection();
334-
335-
// Note: Canceling the receive task will dispose of
336-
// the underlying ClientWebSocket instance.
337-
this.CancelReceiveTask();
325+
this.TerminateSocketConnection().GetAwaiter().GetResult();
338326
}
339327

340328
this.isDisposed = true;
@@ -377,28 +365,6 @@ private async Task<int> InitializeProtocol(int requestedProtocolVersion)
377365
return protocolVersion;
378366
}
379367

380-
private async Task InitializeSocketConnection()
381-
{
382-
LogTrace("Creating WebSocket");
383-
this.sessionSocket = new ClientWebSocket();
384-
this.sessionSocket.Options.KeepAliveInterval = TimeSpan.Zero;
385-
386-
try
387-
{
388-
var timeoutTokenSource = new CancellationTokenSource(this.openConnectionWaitTimeSpan);
389-
await this.sessionSocket.ConnectAsync(new Uri(this.websocketAddress), timeoutTokenSource.Token);
390-
while (this.sessionSocket.State != WebSocketState.Open && !timeoutTokenSource.Token.IsCancellationRequested) ;
391-
}
392-
catch (OperationCanceledException e)
393-
{
394-
throw new WebDriverException(string.Format(CultureInfo.InvariantCulture, "Could not establish WebSocket connection within {0} seconds.", this.openConnectionWaitTimeSpan.TotalSeconds), e);
395-
}
396-
397-
LogTrace("WebSocket created; starting message listener");
398-
this.receiveCancellationToken = new CancellationTokenSource();
399-
this.receiveTask = Task.Run(() => ReceiveMessage().ConfigureAwait(false));
400-
}
401-
402368
private async Task InitializeSession()
403369
{
404370
LogTrace("Creating session");
@@ -445,116 +411,56 @@ private void OnTargetDetached(object sender, TargetDetachedEventArgs e)
445411
}
446412
}
447413

448-
private void TerminateSocketConnection()
414+
private async Task InitializeSocketConnection()
449415
{
450-
if (this.sessionSocket != null && this.sessionSocket.State == WebSocketState.Open)
451-
{
452-
var closeConnectionTokenSource = new CancellationTokenSource(this.closeConnectionWaitTimeSpan);
453-
try
454-
{
455-
// Since Chromium-based DevTools does not respond to the close
456-
// request with a correctly echoed WebSocket close packet, but
457-
// rather just terminates the socket connection, so we have to
458-
// catch the exception thrown when the socket is terminated
459-
// unexpectedly. Also, because we are using async, waiting for
460-
// the task to complete might throw a TaskCanceledException,
461-
// which we should also catch. Additionally, there are times
462-
// when multiple failure modes can be seen, which will throw an
463-
// AggregateException, consolidating several exceptions into one,
464-
// and this too must be caught. Finally, the call to CloseAsync
465-
// will hang even though the connection is already severed.
466-
// Wait for the task to complete for a short time (since we're
467-
// restricted to localhost, the default of 2 seconds should be
468-
// plenty; if not, change the initialization of the timeout),
469-
// and if the task is still running, then we assume the connection
470-
// is properly closed.
471-
LogTrace("Sending socket close request");
472-
Task closeTask = Task.Run(async () => await this.sessionSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, string.Empty, closeConnectionTokenSource.Token));
473-
closeTask.Wait();
474-
}
475-
catch (WebSocketException)
476-
{
477-
}
478-
catch (TaskCanceledException)
479-
{
480-
}
481-
catch (AggregateException)
482-
{
483-
}
484-
}
416+
LogTrace("Creating WebSocket");
417+
this.connection = new WebSocketConnection(this.openConnectionWaitTimeSpan, this.closeConnectionWaitTimeSpan);
418+
connection.DataReceived += OnConnectionDataReceived;
419+
await connection.Start(this.websocketAddress);
420+
LogTrace("WebSocket created");
485421
}
486422

487-
private void CancelReceiveTask()
423+
private async Task TerminateSocketConnection()
488424
{
489-
if (this.receiveTask != null)
425+
LogTrace("Closing WebSocket");
426+
if (this.connection != null && this.connection.IsActive)
490427
{
491-
// Wait for the receive task to be completely exited (for
492-
// whatever reason) before attempting to dispose it. Also
493-
// note that canceling the receive task will dispose of the
494-
// underlying WebSocket.
495-
this.receiveCancellationToken.Cancel();
496-
this.receiveTask.Wait();
497-
this.receiveTask.Dispose();
498-
this.receiveTask = null;
428+
await this.connection.Stop();
429+
await this.ShutdownMessageQueue();
499430
}
431+
LogTrace("WebSocket closed");
500432
}
501433

502-
private async Task ReceiveMessage()
434+
private async Task ShutdownMessageQueue()
503435
{
504-
var cancellationToken = this.receiveCancellationToken.Token;
505-
try
506-
{
507-
var buffer = WebSocket.CreateClientBuffer(1024, 1024);
508-
while (this.sessionSocket.State != WebSocketState.Closed && !cancellationToken.IsCancellationRequested)
509-
{
510-
WebSocketReceiveResult result = await this.sessionSocket.ReceiveAsync(buffer, cancellationToken);
511-
if (!cancellationToken.IsCancellationRequested)
512-
{
513-
if (result.MessageType == WebSocketMessageType.Close && this.sessionSocket.State == WebSocketState.CloseReceived)
514-
{
515-
LogTrace("Got WebSocket close message from browser");
516-
await this.sessionSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, string.Empty, cancellationToken);
517-
}
518-
}
519-
520-
if (this.sessionSocket.State == WebSocketState.Open && result.MessageType != WebSocketMessageType.Close)
521-
{
522-
using (var stream = new MemoryStream())
523-
{
524-
stream.Write(buffer.Array, 0, result.Count);
525-
while (!result.EndOfMessage)
526-
{
527-
result = await this.sessionSocket.ReceiveAsync(buffer, cancellationToken);
528-
stream.Write(buffer.Array, 0, result.Count);
529-
}
530-
531-
stream.Seek(0, SeekOrigin.Begin);
532-
using (var reader = new StreamReader(stream, Encoding.UTF8))
533-
{
534-
string message = reader.ReadToEnd();
535-
536-
// fire and forget
537-
// TODO: we need implement some kind of queue
538-
Task.Run(() => ProcessIncomingMessage(message));
539-
}
540-
}
541-
}
542-
}
543-
}
544-
catch (OperationCanceledException)
436+
// The WebSocket connection is always closed before this method
437+
// is called, so there will eventually be no more data written
438+
// into the message queue, meaning this loop should be guaranteed
439+
// to complete.
440+
while (this.connection.IsActive)
545441
{
442+
await Task.Delay(TimeSpan.FromMilliseconds(10));
546443
}
547-
catch (WebSocketException)
548-
{
549-
}
550-
finally
444+
445+
this.messageQueue.CompleteAdding();
446+
await this.messageQueueMonitorTask;
447+
}
448+
449+
private void MonitorMessageQueue()
450+
{
451+
// GetConsumingEnumerable blocks while BlockingCollection.IsCompleted
452+
// is false (i.e., is still able to be written to), and there are no items
453+
// in the collection. Once any items are added to the collection, the method
454+
// unblocks and we can process any items in the collection at that moment.
455+
// Once IsCompleted is true, the method unblocks with no items returned
456+
// in the IEnumerable, meaning the foreach loop will terminate gracefully.
457+
foreach (string message in this.messageQueue.GetConsumingEnumerable())
551458
{
552-
this.sessionSocket.Dispose();
553-
this.sessionSocket = null;
459+
this.ProcessMessage(message);
554460
}
555461
}
556462

557-
private void ProcessIncomingMessage(string message)
463+
private void ProcessMessage(string message)
558464
{
559465
var messageObject = JObject.Parse(message);
560466

@@ -594,7 +500,12 @@ private void ProcessIncomingMessage(string message)
594500

595501
LogTrace("Recieved Event {0}: {1}", method, eventData.ToString());
596502

597-
OnDevToolsEventReceived(new DevToolsEventReceivedEventArgs(methodParts[0], methodParts[1], eventData));
503+
// Dispatch the event on a new thread so that any event handlers
504+
// responding to the event will not block this thread from processing
505+
// DevTools commands that may be sent in the body of the attached
506+
// event handler. If thread pool starvation seems to become a problem,
507+
// we can switch to a channel-based queue.
508+
Task.Run(() => OnDevToolsEventReceived(new DevToolsEventReceivedEventArgs(methodParts[0], methodParts[1], eventData)));
598509

599510
return;
600511
}
@@ -610,6 +521,11 @@ private void OnDevToolsEventReceived(DevToolsEventReceivedEventArgs e)
610521
}
611522
}
612523

524+
private void OnConnectionDataReceived(object sender, WebSocketConnectionDataReceivedEventArgs e)
525+
{
526+
this.messageQueue.Add(e.Data);
527+
}
528+
613529
private void LogTrace(string message, params object[] args)
614530
{
615531
if (LogMessage != null)

0 commit comments

Comments
 (0)