Skip to content

Commit dfe9435

Browse files
committed
CSHARP-2884: Async Socket methods do not use Socket.ReceiveTimeout.
1 parent 9a88686 commit dfe9435

File tree

3 files changed

+100
-32
lines changed

3 files changed

+100
-32
lines changed

src/MongoDB.Driver.Core/Core/Connections/BinaryConnection.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -391,14 +391,15 @@ private async Task<IByteBuffer> ReceiveBufferAsync()
391391
try
392392
{
393393
var messageSizeBytes = new byte[4];
394-
await _stream.ReadBytesAsync(messageSizeBytes, 0, 4, _backgroundTaskCancellationToken).ConfigureAwait(false);
394+
var readTimeout = _stream.CanTimeout ? TimeSpan.FromMilliseconds(_stream.ReadTimeout) : Timeout.InfiniteTimeSpan;
395+
await _stream.ReadBytesAsync(messageSizeBytes, 0, 4, readTimeout, _backgroundTaskCancellationToken).ConfigureAwait(false);
395396
var messageSize = BitConverter.ToInt32(messageSizeBytes, 0);
396397
EnsureMessageSizeIsValid(messageSize);
397398
var inputBufferChunkSource = new InputBufferChunkSource(BsonChunkPool.Default);
398399
var buffer = ByteBufferFactory.Create(inputBufferChunkSource, messageSize);
399400
buffer.Length = messageSize;
400401
buffer.SetBytes(0, messageSizeBytes, 0, 4);
401-
await _stream.ReadBytesAsync(buffer, 4, messageSize - 4, _backgroundTaskCancellationToken).ConfigureAwait(false);
402+
await _stream.ReadBytesAsync(buffer, 4, messageSize - 4, readTimeout, _backgroundTaskCancellationToken).ConfigureAwait(false);
402403
_lastUsedAtUtc = DateTime.UtcNow;
403404
buffer.MakeReadOnly();
404405
return buffer;
@@ -544,7 +545,8 @@ private async Task SendBufferAsync(IByteBuffer buffer, CancellationToken cancell
544545
try
545546
{
546547
// don't use the caller's cancellationToken because once we start writing a message we have to write the whole thing
547-
await _stream.WriteBytesAsync(buffer, 0, buffer.Length, _backgroundTaskCancellationToken).ConfigureAwait(false);
548+
var writeTimeout = _stream.CanTimeout ? TimeSpan.FromMilliseconds(_stream.WriteTimeout) : Timeout.InfiniteTimeSpan;
549+
await _stream.WriteBytesAsync(buffer, 0, buffer.Length, writeTimeout, _backgroundTaskCancellationToken).ConfigureAwait(false);
548550
_lastUsedAtUtc = DateTime.UtcNow;
549551
}
550552
catch (Exception ex)

src/MongoDB.Driver.Core/Core/Misc/StreamExtensionMethods.cs

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,40 @@ public static void EfficientCopyTo(this Stream input, Stream output)
3636
}
3737
}
3838

39+
public static async Task<int> ReadAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
40+
{
41+
var state = 1; // 1 == reading, 2 == done reading, 3 == timedout, 4 == cancelled
42+
43+
var bytesRead = 0;
44+
using (new Timer(_ => ChangeState(3), null, timeout, Timeout.InfiniteTimeSpan))
45+
using (cancellationToken.Register(() => ChangeState(4)))
46+
{
47+
try
48+
{
49+
bytesRead = await stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
50+
ChangeState(2);
51+
}
52+
catch when (state >= 3)
53+
{
54+
// a different exception will be thrown instead below
55+
}
56+
57+
if (state == 3) { throw new TimeoutException(); }
58+
if (state == 4) { throw new OperationCanceledException(); }
59+
}
60+
61+
return bytesRead;
62+
63+
void ChangeState(int to)
64+
{
65+
var from = Interlocked.CompareExchange(ref state, to, 1);
66+
if (from == 1 && to >= 3)
67+
{
68+
try { stream.Dispose(); } catch { } // disposing the stream aborts the read attempt
69+
}
70+
}
71+
}
72+
3973
public static void ReadBytes(this Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken)
4074
{
4175
Ensure.IsNotNull(stream, nameof(stream));
@@ -76,7 +110,7 @@ public static void ReadBytes(this Stream stream, IByteBuffer buffer, int offset,
76110
}
77111
}
78112

79-
public static async Task ReadBytesAsync(this Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken)
113+
public static async Task ReadBytesAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
80114
{
81115
Ensure.IsNotNull(stream, nameof(stream));
82116
Ensure.IsNotNull(buffer, nameof(buffer));
@@ -85,7 +119,7 @@ public static async Task ReadBytesAsync(this Stream stream, byte[] buffer, int o
85119

86120
while (count > 0)
87121
{
88-
var bytesRead = await stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
122+
var bytesRead = await stream.ReadAsync(buffer, offset, count, timeout, cancellationToken).ConfigureAwait(false);
89123
if (bytesRead == 0)
90124
{
91125
throw new EndOfStreamException();
@@ -95,7 +129,7 @@ public static async Task ReadBytesAsync(this Stream stream, byte[] buffer, int o
95129
}
96130
}
97131

98-
public static async Task ReadBytesAsync(this Stream stream, IByteBuffer buffer, int offset, int count, CancellationToken cancellationToken)
132+
public static async Task ReadBytesAsync(this Stream stream, IByteBuffer buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
99133
{
100134
Ensure.IsNotNull(stream, nameof(stream));
101135
Ensure.IsNotNull(buffer, nameof(buffer));
@@ -106,7 +140,7 @@ public static async Task ReadBytesAsync(this Stream stream, IByteBuffer buffer,
106140
{
107141
var backingBytes = buffer.AccessBackingBytes(offset);
108142
var bytesToRead = Math.Min(count, backingBytes.Count);
109-
var bytesRead = await stream.ReadAsync(backingBytes.Array, backingBytes.Offset, bytesToRead, cancellationToken).ConfigureAwait(false);
143+
var bytesRead = await stream.ReadAsync(backingBytes.Array, backingBytes.Offset, bytesToRead, timeout, cancellationToken).ConfigureAwait(false);
110144
if (bytesRead == 0)
111145
{
112146
throw new EndOfStreamException();
@@ -116,6 +150,38 @@ public static async Task ReadBytesAsync(this Stream stream, IByteBuffer buffer,
116150
}
117151
}
118152

153+
154+
public static async Task WriteAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
155+
{
156+
var state = 1; // 1 == writing, 2 == done writing, 3 == timedout, 4 == cancelled
157+
158+
using (new Timer(_ => ChangeState(3), null, timeout, Timeout.InfiniteTimeSpan))
159+
using (cancellationToken.Register(() => ChangeState(4)))
160+
{
161+
try
162+
{
163+
await stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
164+
ChangeState(2);
165+
}
166+
catch when (state >= 3)
167+
{
168+
// a different exception will be thrown instead below
169+
}
170+
171+
if (state == 3) { throw new TimeoutException(); }
172+
if (state == 4) { throw new OperationCanceledException(); }
173+
}
174+
175+
void ChangeState(int to)
176+
{
177+
var from = Interlocked.CompareExchange(ref state, to, 1);
178+
if (from == 1 && to >= 3)
179+
{
180+
try { stream.Dispose(); } catch { } // disposing the stream aborts the write attempt
181+
}
182+
}
183+
}
184+
119185
public static void WriteBytes(this Stream stream, IByteBuffer buffer, int offset, int count, CancellationToken cancellationToken)
120186
{
121187
Ensure.IsNotNull(stream, nameof(stream));
@@ -134,7 +200,7 @@ public static void WriteBytes(this Stream stream, IByteBuffer buffer, int offset
134200
}
135201
}
136202

137-
public static async Task WriteBytesAsync(this Stream stream, IByteBuffer buffer, int offset, int count, CancellationToken cancellationToken)
203+
public static async Task WriteBytesAsync(this Stream stream, IByteBuffer buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken)
138204
{
139205
Ensure.IsNotNull(stream, nameof(stream));
140206
Ensure.IsNotNull(buffer, nameof(buffer));
@@ -145,7 +211,7 @@ public static async Task WriteBytesAsync(this Stream stream, IByteBuffer buffer,
145211
{
146212
var backingBytes = buffer.AccessBackingBytes(offset);
147213
var bytesToWrite = Math.Min(count, backingBytes.Count);
148-
await stream.WriteAsync(backingBytes.Array, backingBytes.Offset, bytesToWrite, cancellationToken).ConfigureAwait(false);
214+
await stream.WriteAsync(backingBytes.Array, backingBytes.Offset, bytesToWrite, timeout, cancellationToken).ConfigureAwait(false);
149215
offset += bytesToWrite;
150216
count -= bytesToWrite;
151217
}

0 commit comments

Comments
 (0)