Skip to content

Commit 65e64d2

Browse files
CSHARP-4552: Fix speculative authentication. (#1037)
1 parent cbc4610 commit 65e64d2

File tree

6 files changed

+159
-49
lines changed

6 files changed

+159
-49
lines changed

src/MongoDB.Driver.Core/Core/Connections/BinaryConnection.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,9 @@ private void OpenHelper(CancellationToken cancellationToken)
267267
helper.OpeningConnection();
268268
_stream = _streamFactory.CreateStream(_endPoint, cancellationToken);
269269
helper.InitializingConnection();
270-
handshakeDescription = _connectionInitializer.SendHello(this, cancellationToken);
271-
_description = _connectionInitializer.Authenticate(this, handshakeDescription, cancellationToken);
270+
var connectionInitializerContext = _connectionInitializer.SendHello(this, cancellationToken);
271+
handshakeDescription = connectionInitializerContext.Description;
272+
_description = _connectionInitializer.Authenticate(this, connectionInitializerContext, cancellationToken);
272273
_sendCompressorType = ChooseSendCompressorTypeIfAny(_description);
273274

274275
helper.OpenedConnection();
@@ -292,8 +293,9 @@ private async Task OpenHelperAsync(CancellationToken cancellationToken)
292293
helper.OpeningConnection();
293294
_stream = await _streamFactory.CreateStreamAsync(_endPoint, cancellationToken).ConfigureAwait(false);
294295
helper.InitializingConnection();
295-
handshakeDescription = await _connectionInitializer.SendHelloAsync(this, cancellationToken).ConfigureAwait(false);
296-
_description = await _connectionInitializer.AuthenticateAsync(this, handshakeDescription, cancellationToken).ConfigureAwait(false);
296+
var connectionInitializerContext = await _connectionInitializer.SendHelloAsync(this, cancellationToken).ConfigureAwait(false);
297+
handshakeDescription = connectionInitializerContext.Description;
298+
_description = await _connectionInitializer.AuthenticateAsync(this, connectionInitializerContext, cancellationToken).ConfigureAwait(false);
297299
_sendCompressorType = ChooseSendCompressorTypeIfAny(_description);
298300

299301
helper.OpenedConnection();

src/MongoDB.Driver.Core/Core/Connections/ConnectionInitializer.cs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,13 @@ public ConnectionInitializer(
4646
_serverApi = serverApi;
4747
}
4848

49-
public ConnectionDescription Authenticate(IConnection connection, ConnectionDescription description, CancellationToken cancellationToken)
49+
public ConnectionDescription Authenticate(IConnection connection, ConnectionInitializerContext connectionInitializerContext, CancellationToken cancellationToken)
5050
{
5151
Ensure.IsNotNull(connection, nameof(connection));
52-
Ensure.IsNotNull(description, nameof(description));
52+
Ensure.IsNotNull(connectionInitializerContext, nameof(connectionInitializerContext));
53+
var authenticators = Ensure.IsNotNull(connectionInitializerContext.Authenticators, nameof(connectionInitializerContext.Authenticators));
54+
var description = Ensure.IsNotNull(connectionInitializerContext.Description, nameof(connectionInitializerContext.Description));
5355

54-
var authenticators = GetAuthenticators(connection.Settings);
5556
AuthenticationHelper.Authenticate(connection, description, authenticators, cancellationToken);
5657

5758
var connectionIdServerValue = description.HelloResult.ConnectionIdServerValue;
@@ -77,12 +78,13 @@ public ConnectionDescription Authenticate(IConnection connection, ConnectionDesc
7778
return description;
7879
}
7980

80-
public async Task<ConnectionDescription> AuthenticateAsync(IConnection connection, ConnectionDescription description, CancellationToken cancellationToken)
81+
public async Task<ConnectionDescription> AuthenticateAsync(IConnection connection, ConnectionInitializerContext connectionInitializerContext, CancellationToken cancellationToken)
8182
{
8283
Ensure.IsNotNull(connection, nameof(connection));
83-
Ensure.IsNotNull(description, nameof(description));
84+
Ensure.IsNotNull(connectionInitializerContext, nameof(connectionInitializerContext));
85+
var authenticators = Ensure.IsNotNull(connectionInitializerContext.Authenticators, nameof(connectionInitializerContext.Authenticators));
86+
var description = Ensure.IsNotNull(connectionInitializerContext.Description, nameof(connectionInitializerContext.Description));
8487

85-
var authenticators = GetAuthenticators(connection.Settings);
8688
await AuthenticationHelper.AuthenticateAsync(connection, description, authenticators, cancellationToken).ConfigureAwait(false);
8789

8890
var connectionIdServerValue = description.HelloResult.ConnectionIdServerValue;
@@ -110,7 +112,7 @@ public async Task<ConnectionDescription> AuthenticateAsync(IConnection connectio
110112
return description;
111113
}
112114

113-
public ConnectionDescription SendHello(IConnection connection, CancellationToken cancellationToken)
115+
public ConnectionInitializerContext SendHello(IConnection connection, CancellationToken cancellationToken)
114116
{
115117
Ensure.IsNotNull(connection, nameof(connection));
116118
var authenticators = GetAuthenticators(connection.Settings);
@@ -122,10 +124,10 @@ public ConnectionDescription SendHello(IConnection connection, CancellationToken
122124
throw new InvalidOperationException("Driver attempted to initialize in load balancing mode, but the server does not support this mode.");
123125
}
124126

125-
return new ConnectionDescription(connection.ConnectionId, helloResult);
127+
return new (new ConnectionDescription(connection.ConnectionId, helloResult), authenticators);
126128
}
127129

128-
public async Task<ConnectionDescription> SendHelloAsync(IConnection connection, CancellationToken cancellationToken)
130+
public async Task<ConnectionInitializerContext> SendHelloAsync(IConnection connection, CancellationToken cancellationToken)
129131
{
130132
Ensure.IsNotNull(connection, nameof(connection));
131133
var authenticators = GetAuthenticators(connection.Settings);
@@ -137,7 +139,7 @@ public async Task<ConnectionDescription> SendHelloAsync(IConnection connection,
137139
throw new InvalidOperationException("Driver attempted to initialize in load balancing mode, but the server does not support this mode.");
138140
}
139141

140-
return new ConnectionDescription(connection.ConnectionId, helloResult);
142+
return new (new ConnectionDescription(connection.ConnectionId, helloResult), authenticators);
141143
}
142144

143145
// private methods

src/MongoDB.Driver.Core/Core/Connections/IConnectionInitializer.cs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2013-present MongoDB Inc.
1+
/* Copyright 2010-present MongoDB Inc.
22
*
33
* Licensed under the Apache License, Version 2.0 (the "License");
44
* you may not use this file except in compliance with the License.
@@ -13,16 +13,31 @@
1313
* limitations under the License.
1414
*/
1515

16+
using System.Collections.Generic;
1617
using System.Threading;
1718
using System.Threading.Tasks;
19+
using MongoDB.Driver.Core.Authentication;
20+
using MongoDB.Driver.Core.Misc;
1821

1922
namespace MongoDB.Driver.Core.Connections
2023
{
24+
internal sealed class ConnectionInitializerContext
25+
{
26+
public ConnectionInitializerContext(ConnectionDescription description, IReadOnlyList<IAuthenticator> authenticators)
27+
{
28+
Description = Ensure.IsNotNull(description, nameof(description));
29+
Authenticators = Ensure.IsNotNull(authenticators, nameof(authenticators));
30+
}
31+
32+
public IReadOnlyList<IAuthenticator> Authenticators { get; }
33+
public ConnectionDescription Description { get; }
34+
}
35+
2136
internal interface IConnectionInitializer
2237
{
23-
ConnectionDescription Authenticate(IConnection connection, ConnectionDescription description, CancellationToken cancellationToken);
24-
Task<ConnectionDescription> AuthenticateAsync(IConnection connection, ConnectionDescription description, CancellationToken cancellationToken);
25-
ConnectionDescription SendHello(IConnection connection, CancellationToken cancellationToken);
26-
Task<ConnectionDescription> SendHelloAsync(IConnection connection, CancellationToken cancellationToken);
38+
ConnectionDescription Authenticate(IConnection connection, ConnectionInitializerContext connectionInitializerContext, CancellationToken cancellationToken);
39+
Task<ConnectionDescription> AuthenticateAsync(IConnection connection, ConnectionInitializerContext connectionInitializerContext, CancellationToken cancellationToken);
40+
ConnectionInitializerContext SendHello(IConnection connection, CancellationToken cancellationToken);
41+
Task<ConnectionInitializerContext> SendHelloAsync(IConnection connection, CancellationToken cancellationToken);
2742
}
2843
}

tests/MongoDB.Driver.Core.Tests/Core/Connections/BinaryConnectionTests.cs

Lines changed: 94 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2013-present MongoDB Inc.
1+
/* Copyright 2010-present MongoDB Inc.
22
*
33
* Licensed under the Apache License, Version 2.0 (the "License");
44
* you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
1414
*/
1515

1616
using System;
17+
using System.Collections.Generic;
1718
using System.IO;
1819
using System.Net;
1920
using System.Net.Sockets;
@@ -23,6 +24,7 @@
2324
using MongoDB.Bson;
2425
using MongoDB.Bson.Serialization.Serializers;
2526
using MongoDB.TestHelpers.XunitExtensions;
27+
using MongoDB.Driver.Core.Authentication;
2628
using MongoDB.Driver.Core.Clusters;
2729
using MongoDB.Driver.Core.Configuration;
2830
using MongoDB.Driver.Core.Events;
@@ -32,19 +34,24 @@
3234
using MongoDB.Driver.Core.TestHelpers.Logging;
3335
using MongoDB.Driver.Core.WireProtocol.Messages;
3436
using MongoDB.Driver.Core.WireProtocol.Messages.Encoders;
37+
using MongoDB.Driver.Core.WireProtocol.Messages.Encoders.BinaryEncoders;
3538
using Moq;
3639
using Xunit;
3740

3841
namespace MongoDB.Driver.Core.Connections
3942
{
4043
public class BinaryConnectionTests : LoggableTestClass
4144
{
45+
private ConnectionInitializerContext _connectionInitializerContext;
4246
private Mock<IConnectionInitializer> _mockConnectionInitializer;
4347
private ConnectionDescription _connectionDescription;
48+
private readonly IReadOnlyList<IAuthenticator> __emptyAuthenticators = new IAuthenticator[0];
4449
private DnsEndPoint _endPoint;
4550
private EventCapturer _capturedEvents;
4651
private MessageEncoderSettings _messageEncoderSettings = new MessageEncoderSettings();
4752
private Mock<IStreamFactory> _mockStreamFactory;
53+
private readonly ServerId _serverId;
54+
4855
private BinaryConnection _subject;
4956

5057
public BinaryConnectionTests(Xunit.Abstractions.ITestOutputHelper output) : base(output)
@@ -53,27 +60,28 @@ public BinaryConnectionTests(Xunit.Abstractions.ITestOutputHelper output) : base
5360
_mockStreamFactory = new Mock<IStreamFactory>();
5461

5562
_endPoint = new DnsEndPoint("localhost", 27017);
56-
var serverId = new ServerId(new ClusterId(), _endPoint);
57-
var connectionId = new ConnectionId(serverId);
63+
_serverId = new ServerId(new ClusterId(), _endPoint);
64+
var connectionId = new ConnectionId(_serverId);
5865
var helloResult = new HelloResult(new BsonDocument { { "ok", 1 }, { "maxMessageSizeBytes", 48000000 }, { "maxWireVersion", WireVersion.Server36 } });
5966
_connectionDescription = new ConnectionDescription(connectionId, helloResult);
67+
_connectionInitializerContext = new ConnectionInitializerContext(_connectionDescription, __emptyAuthenticators);
6068

6169
_mockConnectionInitializer = new Mock<IConnectionInitializer>();
6270
_mockConnectionInitializer
6371
.Setup(i => i.SendHello(It.IsAny<IConnection>(), CancellationToken.None))
64-
.Returns(_connectionDescription);
72+
.Returns(_connectionInitializerContext);
6573
_mockConnectionInitializer
66-
.Setup(i => i.Authenticate(It.IsAny<IConnection>(), It.IsAny<ConnectionDescription>(), CancellationToken.None))
74+
.Setup(i => i.Authenticate(It.IsAny<IConnection>(), It.IsAny<ConnectionInitializerContext>(), CancellationToken.None))
6775
.Returns(_connectionDescription);
6876
_mockConnectionInitializer
6977
.Setup(i => i.SendHelloAsync(It.IsAny<IConnection>(), CancellationToken.None))
70-
.ReturnsAsync(_connectionDescription);
78+
.ReturnsAsync(_connectionInitializerContext);
7179
_mockConnectionInitializer
72-
.Setup(i => i.AuthenticateAsync(It.IsAny<IConnection>(), It.IsAny<ConnectionDescription>(), CancellationToken.None))
80+
.Setup(i => i.AuthenticateAsync(It.IsAny<IConnection>(), It.IsAny<ConnectionInitializerContext>(), CancellationToken.None))
7381
.ReturnsAsync(_connectionDescription);
7482

7583
_subject = new BinaryConnection(
76-
serverId: serverId,
84+
serverId: _serverId,
7785
endPoint: _endPoint,
7886
settings: new ConnectionSettings(),
7987
streamFactory: _mockStreamFactory.Object,
@@ -94,8 +102,7 @@ public void Dispose_should_raise_the_correct_events()
94102

95103
[Theory]
96104
[ParameterAttributeData]
97-
public void Open_should_always_create_description_if_handshake_was_successful(
98-
[Values(false, true)] bool async)
105+
public void Open_should_always_create_description_if_handshake_was_successful([Values(false, true)] bool async)
99106
{
100107
var serviceId = ObjectId.GenerateNewId();
101108
var connectionDescription = new ConnectionDescription(
@@ -105,15 +112,15 @@ public void Open_should_always_create_description_if_handshake_was_successful(
105112
var socketException = new SocketException();
106113
_mockConnectionInitializer
107114
.Setup(i => i.SendHello(It.IsAny<IConnection>(), CancellationToken.None))
108-
.Returns(connectionDescription);
115+
.Returns(new ConnectionInitializerContext(connectionDescription, __emptyAuthenticators));
109116
_mockConnectionInitializer
110117
.Setup(i => i.SendHelloAsync(It.IsAny<IConnection>(), CancellationToken.None))
111-
.ReturnsAsync(connectionDescription);
118+
.ReturnsAsync(new ConnectionInitializerContext(connectionDescription, __emptyAuthenticators));
112119
_mockConnectionInitializer
113-
.Setup(i => i.Authenticate(It.IsAny<IConnection>(), It.IsAny<ConnectionDescription>(), CancellationToken.None))
120+
.Setup(i => i.Authenticate(It.IsAny<IConnection>(), It.IsAny<ConnectionInitializerContext>(), CancellationToken.None))
114121
.Throws(socketException);
115122
_mockConnectionInitializer
116-
.Setup(i => i.AuthenticateAsync(It.IsAny<IConnection>(), It.IsAny<ConnectionDescription>(), CancellationToken.None))
123+
.Setup(i => i.AuthenticateAsync(It.IsAny<IConnection>(), It.IsAny<ConnectionInitializerContext>(), CancellationToken.None))
117124
.ThrowsAsync(socketException);
118125

119126
Exception exception;
@@ -131,6 +138,68 @@ public void Open_should_always_create_description_if_handshake_was_successful(
131138
ex.InnerException.Should().BeOfType<SocketException>();
132139
}
133140

141+
[Theory]
142+
[ParameterAttributeData]
143+
public async Task Open_should_create_authenticators_only_once(
144+
[Values(false, true)] bool async)
145+
{
146+
var connectionDescription = new ConnectionDescription(
147+
new ConnectionId(new ServerId(new ClusterId(), _endPoint)),
148+
new HelloResult(new BsonDocument()));
149+
150+
using var memoryStream = new MemoryStream();
151+
var clonedMessageEncoderSettings = _messageEncoderSettings.Clone();
152+
var encoderFactory = new BinaryMessageEncoderFactory(memoryStream, clonedMessageEncoderSettings, compressorSource: null);
153+
var encoder = encoderFactory.GetCommandResponseMessageEncoder();
154+
encoder.WriteMessage(CreateResponseMessage());
155+
var mockStreamFactory = new Mock<IStreamFactory>();
156+
using var stream = new IgnoreWritesMemoryStream(memoryStream.ToArray());
157+
mockStreamFactory
158+
.Setup(s => s.CreateStream(It.IsAny<EndPoint>(), CancellationToken.None))
159+
.Returns(stream);
160+
mockStreamFactory
161+
.Setup(s => s.CreateStreamAsync(It.IsAny<EndPoint>(), CancellationToken.None))
162+
.ReturnsAsync(stream);
163+
164+
var connectionInitializer = new ConnectionInitializer(
165+
null,
166+
new CompressorConfiguration[0],
167+
new ServerApi(ServerApiVersion.V1)); // use serverApi to choose command message protocol
168+
var authenticatorFactoryMock = new Mock<IAuthenticatorFactory>();
169+
authenticatorFactoryMock
170+
.Setup(a => a.Create())
171+
.Returns(Mock.Of<IAuthenticator>(a => a.CustomizeInitialHelloCommand(It.IsAny<BsonDocument>()) == new BsonDocument(OppressiveLanguageConstants.LegacyHelloCommandName, 1)));
172+
173+
using var subject = new BinaryConnection(
174+
serverId: _serverId,
175+
endPoint: _endPoint,
176+
settings: new ConnectionSettings(new[] { authenticatorFactoryMock.Object }),
177+
streamFactory: mockStreamFactory.Object,
178+
connectionInitializer: connectionInitializer,
179+
eventSubscriber: _capturedEvents,
180+
LoggerFactory);
181+
182+
if (async)
183+
{
184+
await subject.OpenAsync(CancellationToken.None);
185+
}
186+
else
187+
{
188+
subject.Open(CancellationToken.None);
189+
}
190+
191+
authenticatorFactoryMock.Verify(f => f.Create(), Times.Once());
192+
193+
ResponseMessage CreateResponseMessage()
194+
{
195+
var section0Document = $"{{ {OppressiveLanguageConstants.LegacyHelloResponseIsWritablePrimaryFieldName} : true, ok : 1, connectionId : 1 }}";
196+
var section0 = new Type0CommandMessageSection<RawBsonDocument>(
197+
new RawBsonDocument(BsonDocument.Parse(section0Document).ToBson()),
198+
RawBsonDocumentSerializer.Instance);
199+
return new CommandResponseMessage(new CommandMessage(1, RequestMessage.CurrentGlobalRequestId + 1, new[] { section0 }, false));
200+
}
201+
}
202+
134203
[Theory]
135204
[ParameterAttributeData]
136205
public void Open_should_throw_an_ObjectDisposedException_if_the_connection_is_disposed(
@@ -161,7 +230,7 @@ public void Open_should_raise_the_correct_events_upon_failure(
161230
Action act;
162231
if (async)
163232
{
164-
var result = new TaskCompletionSource<ConnectionDescription>();
233+
var result = new TaskCompletionSource<ConnectionInitializerContext>();
165234
result.SetException(new SocketException());
166235
_mockConnectionInitializer.Setup(i => i.SendHelloAsync(It.IsAny<IConnection>(), It.IsAny<CancellationToken>()))
167236
.Returns(result.Task);
@@ -803,5 +872,15 @@ public void SendMessages_should_put_the_messages_on_the_stream_and_raise_the_cor
803872
_capturedEvents.Any().Should().BeFalse();
804873
}
805874
}
875+
876+
// nested type
877+
private sealed class IgnoreWritesMemoryStream : MemoryStream
878+
{
879+
public IgnoreWritesMemoryStream(byte[] bytes) : base(bytes) { }
880+
public override void Write(byte[] buffer, int offset, int count)
881+
{
882+
// Do nothing
883+
}
884+
}
806885
}
807886
}

0 commit comments

Comments
 (0)