Skip to content

Commit ab29442

Browse files
artromanraman-m
andauthored
#1509 #1683 Replace non-WS protocols for the 'ClientWebSocket' in WebSocketsProxyMiddleware (#1689)
* Update WebSocketsProxyMiddleware.cs Fix WebSocket for SignalR * Repalce url protocol after null check * small refactoring * Add error log when replacing protocol in WebSocketProxyMiddleware Co-authored-by: Raman Maksimchuk <[email protected]> * Fix build * Code review * Fix unit test * Refactor to remove hardcoded strings of schemes * Define public constants * Add unit tests --------- Co-authored-by: raman-m <[email protected]>
1 parent 190b001 commit ab29442

File tree

3 files changed

+109
-39
lines changed

3 files changed

+109
-39
lines changed

src/Ocelot/WebSockets/WebSocketsProxyMiddleware.cs

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using Ocelot.Configuration;
77
using Ocelot.Logging;
88
using Ocelot.Middleware;
9+
using Ocelot.Request.Middleware;
910
using System.Net.WebSockets;
1011

1112
namespace Ocelot.WebSockets
@@ -17,10 +18,14 @@ public class WebSocketsProxyMiddleware : OcelotMiddleware
1718
"Connection", "Host", "Upgrade",
1819
"Sec-WebSocket-Accept", "Sec-WebSocket-Protocol", "Sec-WebSocket-Key", "Sec-WebSocket-Version", "Sec-WebSocket-Extensions",
1920
};
21+
2022
private const int DefaultWebSocketBufferSize = 4096;
2123
private readonly RequestDelegate _next;
2224
private readonly IWebSocketsFactory _factory;
2325

26+
public const string IgnoredSslWarningFormat = $"You have ignored all SSL warnings by using {nameof(DownstreamRoute.DangerousAcceptAnyServerCertificateValidator)} for this downstream route! {nameof(DownstreamRoute.UpstreamPathTemplate)}: '{{0}}', {nameof(DownstreamRoute.DownstreamPathTemplate)}: '{{1}}'.";
27+
public const string InvalidSchemeWarningFormat = "Invalid scheme has detected which will be replaced! Scheme '{0}' of the downstream '{1}'.";
28+
2429
public WebSocketsProxyMiddleware(IOcelotLoggerFactory loggerFactory,
2530
RequestDelegate next,
2631
IWebSocketsFactory factory)
@@ -73,21 +78,26 @@ private static async Task PumpWebSocket(WebSocket source, WebSocket destination,
7378

7479
public async Task Invoke(HttpContext httpContext)
7580
{
76-
var uri = httpContext.Items.DownstreamRequest().ToUri();
81+
var downstreamRequest = httpContext.Items.DownstreamRequest();
7782
var downstreamRoute = httpContext.Items.DownstreamRoute();
78-
await Proxy(httpContext, uri, downstreamRoute);
83+
await Proxy(httpContext, downstreamRequest, downstreamRoute);
7984
}
8085

81-
private async Task Proxy(HttpContext context, string serverEndpoint, DownstreamRoute downstreamRoute)
86+
private async Task Proxy(HttpContext context, DownstreamRequest request, DownstreamRoute route)
8287
{
8388
if (context == null)
8489
{
8590
throw new ArgumentNullException(nameof(context));
8691
}
8792

88-
if (serverEndpoint == null)
93+
if (request == null)
94+
{
95+
throw new ArgumentNullException(nameof(request));
96+
}
97+
98+
if (route == null)
8999
{
90-
throw new ArgumentNullException(nameof(serverEndpoint));
100+
throw new ArgumentNullException(nameof(route));
91101
}
92102

93103
if (!context.WebSockets.IsWebSocketRequest)
@@ -97,10 +107,10 @@ private async Task Proxy(HttpContext context, string serverEndpoint, DownstreamR
97107

98108
var client = _factory.CreateClient(); // new ClientWebSocket();
99109

100-
if (downstreamRoute.DangerousAcceptAnyServerCertificateValidator)
110+
if (route.DangerousAcceptAnyServerCertificateValidator)
101111
{
102112
client.Options.RemoteCertificateValidationCallback = (request, certificate, chain, errors) => true;
103-
Logger.LogWarning($"You have ignored all SSL warnings by using {nameof(DownstreamRoute.DangerousAcceptAnyServerCertificateValidator)} for this downstream route! {nameof(DownstreamRoute.UpstreamPathTemplate)}: '{downstreamRoute.UpstreamPathTemplate}', {nameof(DownstreamRoute.DownstreamPathTemplate)}: '{downstreamRoute.DownstreamPathTemplate}'.");
113+
Logger.LogWarning(string.Format(IgnoredSslWarningFormat, route.UpstreamPathTemplate, route.DownstreamPathTemplate));
104114
}
105115

106116
foreach (var protocol in context.WebSockets.WebSocketRequestedProtocols)
@@ -125,7 +135,16 @@ private async Task Proxy(HttpContext context, string serverEndpoint, DownstreamR
125135
}
126136
}
127137

128-
var destinationUri = new Uri(serverEndpoint);
138+
// Only Uris starting with 'ws://' or 'wss://' are supported in System.Net.WebSockets.ClientWebSocket
139+
var scheme = request.Scheme;
140+
if (!scheme.StartsWith(Uri.UriSchemeWs))
141+
{
142+
Logger.LogWarning(string.Format(InvalidSchemeWarningFormat, scheme, request.ToUri()));
143+
request.Scheme = scheme == Uri.UriSchemeHttp ? Uri.UriSchemeWs
144+
: scheme == Uri.UriSchemeHttps ? Uri.UriSchemeWss : scheme;
145+
}
146+
147+
var destinationUri = new Uri(request.ToUri());
129148
await client.ConnectAsync(destinationUri, context.RequestAborted);
130149

131150
using (var server = await context.WebSockets.AcceptWebSocketAsync(client.SubProtocol))

test/Ocelot.UnitTests/WebSockets/MockWebSocket.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,6 @@ protected virtual void Dispose(bool disposing)
161161
// // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
162162
// Dispose(disposing: false);
163163
// }
164-
165164
public override void Dispose()
166165
{
167166
// Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method

test/Ocelot.UnitTests/WebSockets/WebSocketsProxyMiddlewareTests.cs

Lines changed: 82 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ public class WebSocketsProxyMiddlewareTests
1919

2020
private readonly Mock<HttpContext> _context;
2121
private readonly Mock<IOcelotLogger> _logger;
22+
private readonly Mock<IClientWebSocket> _client;
2223

2324
public WebSocketsProxyMiddlewareTests()
2425
{
@@ -27,48 +28,48 @@ public WebSocketsProxyMiddlewareTests()
2728
_factory = new Mock<IWebSocketsFactory>();
2829

2930
_context = new Mock<HttpContext>();
31+
_context.SetupGet(x => x.WebSockets.IsWebSocketRequest).Returns(true);
32+
3033
_logger = new Mock<IOcelotLogger>();
3134
_loggerFactory.Setup(x => x.CreateLogger<WebSocketsProxyMiddleware>())
3235
.Returns(_logger.Object);
3336

3437
_middleware = new WebSocketsProxyMiddleware(_loggerFactory.Object, _next.Object, _factory.Object);
38+
39+
_client = new Mock<IClientWebSocket>();
40+
_factory.Setup(x => x.CreateClient()).Returns(_client.Object);
3541
}
3642

3743
[Fact]
38-
public void ShouldIgnoreAllSslWarnings_WhenDangerousAcceptAnyServerCertificateValidatorIsTrue()
44+
public void ShouldIgnoreAllSslWarningsWhenDangerousAcceptAnyServerCertificateValidatorIsTrue()
3945
{
40-
this.Given(x => x.GivenPropertyDangerousAcceptAnyServerCertificateValidator(true))
46+
List<object> actual = new();
47+
this.Given(x => x.GivenPropertyDangerousAcceptAnyServerCertificateValidator(true, actual))
4148
.And(x => x.AndDoNotSetupProtocolsAndHeaders())
42-
.And(x => x.AndDoNotConnectReally())
49+
.And(x => x.AndDoNotConnectReally(null))
4350
.When(x => x.WhenInvokeWithHttpContext())
44-
.Then(x => x.ThenIgnoredAllSslWarnings())
51+
.Then(x => x.ThenIgnoredAllSslWarnings(actual))
4552
.BDDfy();
4653
}
4754

48-
private void GivenPropertyDangerousAcceptAnyServerCertificateValidator(bool enabled)
55+
private void GivenPropertyDangerousAcceptAnyServerCertificateValidator(bool enabled, List<object> actual)
4956
{
50-
var request = new HttpRequestMessage(HttpMethod.Get, "http://localhost:80");
57+
var request = new HttpRequestMessage(HttpMethod.Get, $"{Uri.UriSchemeWs}://localhost:12345");
5158
var downstream = new DownstreamRequest(request);
5259
var route = new DownstreamRouteBuilder()
5360
.WithDangerousAcceptAnyServerCertificateValidator(enabled)
5461
.Build();
5562
_context.SetupGet(x => x.Items).Returns(new Dictionary<object, object>
56-
{
57-
{ "DownstreamRequest", downstream },
58-
{ "DownstreamRoute", route },
59-
});
60-
61-
_context.SetupGet(x => x.WebSockets.IsWebSocketRequest).Returns(true);
62-
63-
_client = new Mock<IClientWebSocket>();
64-
_factory.Setup(x => x.CreateClient()).Returns(_client.Object);
63+
{
64+
{ "DownstreamRequest", downstream },
65+
{ "DownstreamRoute", route },
66+
});
6567

6668
_client.SetupSet(x => x.Options.RemoteCertificateValidationCallback = It.IsAny<RemoteCertificateValidationCallback>())
67-
.Callback<RemoteCertificateValidationCallback>(value => _callback = value);
69+
.Callback<RemoteCertificateValidationCallback>(actual.Add);
6870

69-
_warning = string.Empty;
7071
_logger.Setup(x => x.LogWarning(It.IsAny<string>()))
71-
.Callback<string>(message => _warning = message);
72+
.Callback<string>(actual.Add);
7273
}
7374

7475
private void AndDoNotSetupProtocolsAndHeaders()
@@ -77,9 +78,11 @@ private void AndDoNotSetupProtocolsAndHeaders()
7778
_context.SetupGet(x => x.Request.Headers).Returns(new HeaderDictionary());
7879
}
7980

80-
private void AndDoNotConnectReally()
81+
private void AndDoNotConnectReally(Action<Uri, CancellationToken> callbackConnectAsync)
8182
{
82-
_client.Setup(x => x.ConnectAsync(It.IsAny<Uri>(), It.IsAny<CancellationToken>())).Verifiable();
83+
Action<Uri, CancellationToken> doNothing = (u, t) => { };
84+
_client.Setup(x => x.ConnectAsync(It.IsAny<Uri>(), It.IsAny<CancellationToken>()))
85+
.Callback(callbackConnectAsync ?? doNothing);
8386
var clientSocket = new Mock<WebSocket>();
8487
var serverSocket = new Mock<WebSocket>();
8588
_client.Setup(x => x.ToWebSocket()).Returns(clientSocket.Object);
@@ -97,28 +100,77 @@ private void AndDoNotConnectReally()
97100
serverSocket.SetupGet(x => x.CloseStatus).Returns(WebSocketCloseStatus.Empty);
98101
}
99102

100-
private Mock<IClientWebSocket> _client;
101-
private RemoteCertificateValidationCallback _callback;
102-
private string _warning;
103-
104103
private async Task WhenInvokeWithHttpContext()
105104
{
106105
await _middleware.Invoke(_context.Object);
107106
}
108107

109-
private void ThenIgnoredAllSslWarnings()
108+
private void ThenIgnoredAllSslWarnings(List<object> actual)
110109
{
111-
_context.Object.Items.DownstreamRoute().DangerousAcceptAnyServerCertificateValidator
112-
.ShouldBeTrue();
110+
var route = _context.Object.Items.DownstreamRoute();
111+
var request = _context.Object.Items.DownstreamRequest();
112+
route.DangerousAcceptAnyServerCertificateValidator.ShouldBeTrue();
113113

114114
_logger.Verify(x => x.LogWarning(It.IsAny<string>()), Times.Once());
115-
_warning.ShouldNotBeNullOrEmpty();
115+
var warning = actual.Last() as string;
116+
warning.ShouldNotBeNullOrEmpty();
117+
var expectedWarning = string.Format(WebSocketsProxyMiddleware.IgnoredSslWarningFormat, route.UpstreamPathTemplate, route.DownstreamPathTemplate);
118+
warning.ShouldBe(expectedWarning);
116119

117120
_client.VerifySet(x => x.Options.RemoteCertificateValidationCallback = It.IsAny<RemoteCertificateValidationCallback>(),
118121
Times.Once());
119122

120-
_callback.ShouldNotBeNull();
121-
var validation = _callback.Invoke(null, null, null, SslPolicyErrors.None);
123+
var callback = actual.First() as RemoteCertificateValidationCallback;
124+
callback.ShouldNotBeNull();
125+
var validation = callback.Invoke(null, null, null, SslPolicyErrors.None);
122126
validation.ShouldBeTrue();
123127
}
128+
129+
[Theory]
130+
[InlineData("http", "ws")]
131+
[InlineData("https", "wss")]
132+
[InlineData("ftp", "ftp")]
133+
public void ShouldReplaceNonWsSchemes(string scheme, string expectedScheme)
134+
{
135+
List<object> actual = new();
136+
this.Given(x => x.GivenNonWebsocketScheme(scheme, actual))
137+
.And(x => x.AndDoNotSetupProtocolsAndHeaders())
138+
.And(x => x.AndDoNotConnectReally((uri, token) => actual.Add(uri)))
139+
.When(x => x.WhenInvokeWithHttpContext())
140+
.Then(x => x.ThenNonWsSchemesAreReplaced(scheme, expectedScheme, actual))
141+
.BDDfy();
142+
}
143+
144+
private void GivenNonWebsocketScheme(string scheme, List<object> actual)
145+
{
146+
var requestMessage = new HttpRequestMessage(HttpMethod.Get, $"{scheme}://localhost:12345");
147+
var request = new DownstreamRequest(requestMessage);
148+
var route = new DownstreamRouteBuilder().Build();
149+
var items = new Dictionary<object, object>
150+
{
151+
{ "DownstreamRequest", request },
152+
{ "DownstreamRoute", route },
153+
};
154+
_context.SetupGet(x => x.Items).Returns(items);
155+
156+
_logger.Setup(x => x.LogWarning(It.IsAny<string>()))
157+
.Callback<string>(actual.Add);
158+
}
159+
160+
private void ThenNonWsSchemesAreReplaced(string scheme, string expectedScheme, List<object> actual)
161+
{
162+
var route = _context.Object.Items.DownstreamRoute();
163+
var request = _context.Object.Items.DownstreamRequest();
164+
route.DangerousAcceptAnyServerCertificateValidator.ShouldBeFalse();
165+
166+
_logger.Verify(x => x.LogWarning(It.IsAny<string>()), Times.Once());
167+
var warning = actual.First() as string;
168+
warning.ShouldNotBeNullOrEmpty();
169+
warning.ShouldContain($"'{scheme}'");
170+
var expectedWarning = string.Format(WebSocketsProxyMiddleware.InvalidSchemeWarningFormat, scheme, request.ToUri().Replace(expectedScheme, scheme));
171+
warning.ShouldBe(expectedWarning);
172+
173+
request.Scheme.ShouldBe(expectedScheme);
174+
((Uri)actual.Last()).Scheme.ShouldBe(expectedScheme);
175+
}
124176
}

0 commit comments

Comments
 (0)