Skip to content
94 changes: 52 additions & 42 deletions dotnet/src/webdriver/Remote/HttpCommandExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
using System.Threading;
using System.Threading.Tasks;

#nullable enable

namespace OpenQA.Selenium.Remote
{
/// <summary>
Expand All @@ -47,7 +49,8 @@ public class HttpCommandExecutor : ICommandExecutor
private readonly TimeSpan serverResponseTimeout;
private bool isDisposed;
private CommandInfoRepository commandInfoRepository = new W3CWireProtocolCommandInfoRepository();
private HttpClient client;
private HttpClient? client;
private readonly object _createClientLock = new();

private static readonly ILogger _logger = Log.GetLogger<HttpCommandExecutor>();

Expand All @@ -56,6 +59,7 @@ public class HttpCommandExecutor : ICommandExecutor
/// </summary>
/// <param name="addressOfRemoteServer">Address of the WebDriver Server</param>
/// <param name="timeout">The timeout within which the server must respond.</param>
/// <exception cref="ArgumentNullException">If <paramref name="addressOfRemoteServer"/> is <see langword="null"/>.</exception>
public HttpCommandExecutor(Uri addressOfRemoteServer, TimeSpan timeout)
: this(addressOfRemoteServer, timeout, true)
{
Expand Down Expand Up @@ -91,14 +95,14 @@ public HttpCommandExecutor(Uri addressOfRemoteServer, TimeSpan timeout, bool ena
/// Occurs when the <see cref="HttpCommandExecutor"/> is sending an HTTP
/// request to the remote end WebDriver implementation.
/// </summary>
public event EventHandler<SendingRemoteHttpRequestEventArgs> SendingRemoteHttpRequest;
public event EventHandler<SendingRemoteHttpRequestEventArgs>? SendingRemoteHttpRequest;

/// <summary>
/// Gets or sets an <see cref="IWebProxy"/> object to be used to proxy requests
/// between this <see cref="HttpCommandExecutor"/> and the remote end WebDriver
/// implementation.
/// </summary>
public IWebProxy Proxy { get; set; }
public IWebProxy? Proxy { get; set; }

/// <summary>
/// Gets or sets a value indicating whether keep-alive is enabled for HTTP
Expand Down Expand Up @@ -167,17 +171,12 @@ public virtual async Task<Response> ExecuteAsync(Command commandToExecute)
_logger.Debug($"Executing command: {commandToExecute}");
}

HttpCommandInfo info = this.commandInfoRepository.GetCommandInfo<HttpCommandInfo>(commandToExecute.Name);
HttpCommandInfo? info = this.commandInfoRepository.GetCommandInfo<HttpCommandInfo>(commandToExecute.Name);
if (info == null)
{
throw new NotImplementedException(string.Format("The command you are attempting to execute, {0}, does not exist in the protocol dialect used by the remote end.", commandToExecute.Name));
}

if (this.client == null)
{
this.CreateHttpClient();
}

HttpRequestInfo requestInfo = new HttpRequestInfo(this.remoteServerUri, commandToExecute, info);
HttpResponseInfo responseInfo;
try
Expand Down Expand Up @@ -216,42 +215,55 @@ protected virtual void OnSendingRemoteHttpRequest(SendingRemoteHttpRequestEventA
throw new ArgumentNullException(nameof(eventArgs), "eventArgs must not be null");
}

if (this.SendingRemoteHttpRequest != null)
{
this.SendingRemoteHttpRequest(this, eventArgs);
}
this.SendingRemoteHttpRequest?.Invoke(this, eventArgs);
}

private void CreateHttpClient()
private HttpClient Client
{
HttpClientHandler httpClientHandler = new HttpClientHandler();
string userInfo = this.remoteServerUri.UserInfo;
if (!string.IsNullOrEmpty(userInfo) && userInfo.Contains(":"))
{
string[] userInfoComponents = this.remoteServerUri.UserInfo.Split(new char[] { ':' }, 2);
httpClientHandler.Credentials = new NetworkCredential(userInfoComponents[0], userInfoComponents[1]);
httpClientHandler.PreAuthenticate = true;
}

httpClientHandler.Proxy = this.Proxy;

HttpMessageHandler handler = httpClientHandler;

if (_logger.IsEnabled(LogEventLevel.Trace))
get
{
handler = new DiagnosticsHttpHandler(httpClientHandler, _logger);
}
if (this.client is null)
{
lock (_createClientLock)
{
if (this.client is null)
{
HttpClientHandler httpClientHandler = new HttpClientHandler();
string userInfo = this.remoteServerUri.UserInfo;
if (!string.IsNullOrEmpty(userInfo) && userInfo.Contains(":"))
{
string[] userInfoComponents = this.remoteServerUri.UserInfo.Split(new char[] { ':' }, 2);
httpClientHandler.Credentials = new NetworkCredential(userInfoComponents[0], userInfoComponents[1]);
httpClientHandler.PreAuthenticate = true;
}

httpClientHandler.Proxy = this.Proxy;

HttpMessageHandler handler = httpClientHandler;

if (_logger.IsEnabled(LogEventLevel.Trace))
{
handler = new DiagnosticsHttpHandler(httpClientHandler, _logger);
}

var client = new HttpClient(handler);
client.DefaultRequestHeaders.UserAgent.ParseAdd(this.UserAgent);
client.DefaultRequestHeaders.Accept.ParseAdd(RequestAcceptHeader);
client.DefaultRequestHeaders.ExpectContinue = false;
if (!this.IsKeepAliveEnabled)
{
client.DefaultRequestHeaders.Connection.ParseAdd("close");
}

client.Timeout = this.serverResponseTimeout;

this.client = client;
}
}
}

this.client = new HttpClient(handler);
this.client.DefaultRequestHeaders.UserAgent.ParseAdd(this.UserAgent);
this.client.DefaultRequestHeaders.Accept.ParseAdd(RequestAcceptHeader);
this.client.DefaultRequestHeaders.ExpectContinue = false;
if (!this.IsKeepAliveEnabled)
{
this.client.DefaultRequestHeaders.Connection.ParseAdd("close");
return this.client;
}

this.client.Timeout = this.serverResponseTimeout;
}

private async Task<HttpResponseInfo> MakeHttpRequest(HttpRequestInfo requestInfo)
Expand Down Expand Up @@ -288,7 +300,7 @@ private async Task<HttpResponseInfo> MakeHttpRequest(HttpRequestInfo requestInfo
requestMessage.Content.Headers.ContentType = contentTypeHeader;
}

using (HttpResponseMessage responseMessage = await this.client.SendAsync(requestMessage).ConfigureAwait(false))
using (HttpResponseMessage responseMessage = await this.Client.SendAsync(requestMessage).ConfigureAwait(false))
{
var responseBody = await responseMessage.Content.ReadAsStringAsync().ConfigureAwait(false);
var responseContentType = responseMessage.Content.Headers.ContentType?.ToString();
Expand Down Expand Up @@ -331,8 +343,6 @@ private Response CreateResponse(HttpResponseInfo responseInfo)
return response;
}

#nullable enable

/// <summary>
/// Releases all resources used by the <see cref="HttpCommandExecutor"/>.
/// </summary>
Expand Down
Loading