Skip to content

Commit 6a6fe35

Browse files
committed
Eliminate ContinueWith.
'await' should just be used directly instead of abstracting it behind a utility method.
1 parent 2280759 commit 6a6fe35

File tree

2 files changed

+128
-140
lines changed

2 files changed

+128
-140
lines changed

src/MySqlConnector/Protocol/Serialization/CompressedPayloadHandler.cs

Lines changed: 128 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -39,167 +39,157 @@ public ValueTask<ArraySegment<byte>> ReadPayloadAsync(ArraySegmentHolder<byte> c
3939
return ProtocolUtility.ReadPayloadAsync(m_bufferedByteReader, compressedByteHandler, static () => -1, cache, protocolErrorBehavior, ioBehavior);
4040
}
4141

42-
public ValueTask<int> WritePayloadAsync(ReadOnlyMemory<byte> payload, IOBehavior ioBehavior)
42+
public async ValueTask<int> WritePayloadAsync(ReadOnlyMemory<byte> payload, IOBehavior ioBehavior)
4343
{
4444
// break the payload up into (possibly more than one) uncompressed packets
45-
return ProtocolUtility.WritePayloadAsync(m_uncompressedStreamByteHandler!, GetNextUncompressedSequenceNumber, payload, ioBehavior).ContinueWith(_ =>
46-
{
47-
if (m_uncompressedStream!.Length == 0)
48-
return default;
45+
await ProtocolUtility.WritePayloadAsync(m_uncompressedStreamByteHandler!, GetNextUncompressedSequenceNumber, payload, ioBehavior).ConfigureAwait(false);
46+
47+
if (m_uncompressedStream!.Length == 0)
48+
return default;
49+
50+
if (!m_uncompressedStream.TryGetBuffer(out var uncompressedData))
51+
throw new InvalidOperationException("Couldn't get uncompressed stream buffer.");
4952

50-
if (!m_uncompressedStream.TryGetBuffer(out var uncompressedData))
51-
throw new InvalidOperationException("Couldn't get uncompressed stream buffer.");
52-
53-
return CompressAndWrite(uncompressedData, ioBehavior)
54-
.ContinueWith(__ =>
55-
{
56-
// reset the uncompressed stream to accept more data
57-
m_uncompressedStream.SetLength(0);
58-
return default(ValueTask<int>);
59-
});
60-
});
53+
await CompressAndWrite(uncompressedData, ioBehavior).ConfigureAwait(false);
54+
55+
// reset the uncompressed stream to accept more data
56+
m_uncompressedStream.SetLength(0);
57+
return default;
6158
}
6259

63-
private ValueTask<int> ReadBytesAsync(Memory<byte> buffer, ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
60+
private async ValueTask<int> ReadBytesAsync(Memory<byte> buffer, ProtocolErrorBehavior protocolErrorBehavior, IOBehavior ioBehavior)
6461
{
6562
// satisfy the read from cache if possible
63+
int bytesToRead;
6664
if (m_remainingData.Count > 0)
6765
{
68-
var bytesToRead = Math.Min(m_remainingData.Count, buffer.Length);
66+
bytesToRead = Math.Min(m_remainingData.Count, buffer.Length);
6967
m_remainingData.AsSpan(0, bytesToRead).CopyTo(buffer.Span);
7068
m_remainingData = m_remainingData.Slice(bytesToRead);
71-
return new ValueTask<int>(bytesToRead);
69+
return bytesToRead;
7270
}
7371

7472
// read the compressed header (seven bytes)
75-
return m_compressedBufferedByteReader.ReadBytesAsync(m_byteHandler!, 7, ioBehavior)
76-
.ContinueWith(headerReadBytes =>
77-
{
78-
if (headerReadBytes.Count < 7)
79-
{
80-
return protocolErrorBehavior == ProtocolErrorBehavior.Ignore ?
81-
default :
82-
ValueTaskExtensions.FromException<int>(new EndOfStreamException("Wanted to read 7 bytes but only read {0} when reading compressed packet header".FormatInvariant(headerReadBytes.Count)));
83-
}
84-
85-
var payloadLength = (int) SerializationUtility.ReadUInt32(headerReadBytes.AsSpan(0, 3));
86-
var packetSequenceNumber = headerReadBytes.Array![headerReadBytes.Offset + 3];
87-
var uncompressedLength = (int) SerializationUtility.ReadUInt32(headerReadBytes.AsSpan(4, 3));
88-
89-
// verify the compressed packet sequence number
90-
var expectedSequenceNumber = GetNextCompressedSequenceNumber();
91-
if (packetSequenceNumber != expectedSequenceNumber)
92-
{
93-
if (protocolErrorBehavior == ProtocolErrorBehavior.Ignore)
94-
return default;
95-
96-
var exception = MySqlProtocolException.CreateForPacketOutOfOrder(expectedSequenceNumber, packetSequenceNumber);
97-
return ValueTaskExtensions.FromException<int>(exception);
98-
}
99-
100-
// MySQL protocol resets the uncompressed sequence number back to the sequence number of this compressed packet.
101-
// This isn't in the documentation, but the code explicitly notes that uncompressed packets are modified by compression:
102-
// - https://github.com/mysql/mysql-server/blob/c28e258157f39f25e044bb72e8bae1ff00989a3d/sql/net_serv.cc#L276
103-
// - https://github.com/mysql/mysql-server/blob/c28e258157f39f25e044bb72e8bae1ff00989a3d/sql/net_serv.cc#L225-L227
104-
if (!m_isContinuationPacket)
105-
m_uncompressedSequenceNumber = packetSequenceNumber;
106-
107-
// except this doesn't happen when uncompressed packets need to be broken up across multiple compressed packets
108-
m_isContinuationPacket = payloadLength == ProtocolUtility.MaxPacketSize || uncompressedLength == ProtocolUtility.MaxPacketSize;
109-
110-
return m_compressedBufferedByteReader.ReadBytesAsync(m_byteHandler!, payloadLength, ioBehavior)
111-
.ContinueWith(payloadReadBytes =>
112-
{
113-
if (payloadReadBytes.Count < payloadLength)
114-
{
115-
return protocolErrorBehavior == ProtocolErrorBehavior.Ignore ?
116-
default :
117-
ValueTaskExtensions.FromException<int>(new EndOfStreamException("Wanted to read {0} bytes but only read {1} when reading compressed payload".FormatInvariant(payloadLength, payloadReadBytes.Count)));
118-
}
119-
120-
if (uncompressedLength == 0)
121-
{
122-
// data is uncompressed
123-
m_remainingData = payloadReadBytes;
124-
}
125-
else
126-
{
73+
var headerReadBytes = await m_compressedBufferedByteReader.ReadBytesAsync(m_byteHandler!, 7, ioBehavior).ConfigureAwait(false);
74+
if (headerReadBytes.Count < 7)
75+
{
76+
if (protocolErrorBehavior == ProtocolErrorBehavior.Ignore)
77+
return default;
78+
throw new EndOfStreamException("Wanted to read 7 bytes but only read {0} when reading compressed packet header".FormatInvariant(headerReadBytes.Count));
79+
}
80+
81+
var payloadLength = (int) SerializationUtility.ReadUInt32(headerReadBytes.AsSpan(0, 3));
82+
var packetSequenceNumber = headerReadBytes.Array![headerReadBytes.Offset + 3];
83+
var uncompressedLength = (int) SerializationUtility.ReadUInt32(headerReadBytes.AsSpan(4, 3));
84+
85+
// verify the compressed packet sequence number
86+
var expectedSequenceNumber = GetNextCompressedSequenceNumber();
87+
if (packetSequenceNumber != expectedSequenceNumber)
88+
{
89+
if (protocolErrorBehavior == ProtocolErrorBehavior.Ignore)
90+
return default;
91+
throw MySqlProtocolException.CreateForPacketOutOfOrder(expectedSequenceNumber, packetSequenceNumber);
92+
}
93+
94+
// MySQL protocol resets the uncompressed sequence number back to the sequence number of this compressed packet.
95+
// This isn't in the documentation, but the code explicitly notes that uncompressed packets are modified by compression:
96+
// - https://github.com/mysql/mysql-server/blob/c28e258157f39f25e044bb72e8bae1ff00989a3d/sql/net_serv.cc#L276
97+
// - https://github.com/mysql/mysql-server/blob/c28e258157f39f25e044bb72e8bae1ff00989a3d/sql/net_serv.cc#L225-L227
98+
if (!m_isContinuationPacket)
99+
m_uncompressedSequenceNumber = packetSequenceNumber;
100+
101+
// except this doesn't happen when uncompressed packets need to be broken up across multiple compressed packets
102+
m_isContinuationPacket = payloadLength == ProtocolUtility.MaxPacketSize || uncompressedLength == ProtocolUtility.MaxPacketSize;
103+
104+
var payloadReadBytes = await m_compressedBufferedByteReader.ReadBytesAsync(m_byteHandler!, payloadLength, ioBehavior).ConfigureAwait(false);
105+
if (payloadReadBytes.Count < payloadLength)
106+
{
107+
if (protocolErrorBehavior == ProtocolErrorBehavior.Ignore)
108+
return default;
109+
throw new EndOfStreamException("Wanted to read {0} bytes but only read {1} when reading compressed payload".FormatInvariant(payloadLength, payloadReadBytes.Count));
110+
}
111+
112+
if (uncompressedLength == 0)
113+
{
114+
// data is uncompressed
115+
m_remainingData = payloadReadBytes;
116+
}
117+
else
118+
{
127119
#if NET6_0_OR_GREATER
128-
var uncompressedData = new byte[uncompressedLength];
129-
using var compressedStream = new MemoryStream(payloadReadBytes.Array!, payloadReadBytes.Offset, payloadReadBytes.Count);
130-
using var decompressingStream = new ZLibStream(compressedStream, CompressionMode.Decompress);
120+
var uncompressedData = new byte[uncompressedLength];
121+
using var compressedStream = new MemoryStream(payloadReadBytes.Array!, payloadReadBytes.Offset, payloadReadBytes.Count);
122+
using var decompressingStream = new ZLibStream(compressedStream, CompressionMode.Decompress);
131123
#if NET7_0_OR_GREATER
132-
var totalBytesRead = decompressingStream.ReadAtLeast(uncompressedData, uncompressedLength, throwOnEndOfStream: false);
124+
var totalBytesRead = decompressingStream.ReadAtLeast(uncompressedData, uncompressedLength, throwOnEndOfStream: false);
133125
#else
134-
int bytesRead, totalBytesRead = 0;
135-
do
136-
{
137-
bytesRead = decompressingStream.Read(uncompressedData, totalBytesRead, uncompressedLength - totalBytesRead);
138-
totalBytesRead += bytesRead;
139-
} while (bytesRead > 0);
126+
int bytesRead, totalBytesRead = 0;
127+
do
128+
{
129+
bytesRead = decompressingStream.Read(uncompressedData, totalBytesRead, uncompressedLength - totalBytesRead);
130+
totalBytesRead += bytesRead;
131+
} while (bytesRead > 0);
140132
#endif
141-
if (totalBytesRead != uncompressedLength && protocolErrorBehavior == ProtocolErrorBehavior.Throw)
142-
return ValueTaskExtensions.FromException<int>(new InvalidOperationException("Expected to read {0:n0} uncompressed bytes but only read {1:n0}".FormatInvariant(uncompressedLength, totalBytesRead)));
143-
m_remainingData = new(uncompressedData, 0, totalBytesRead);
133+
if (totalBytesRead != uncompressedLength && protocolErrorBehavior == ProtocolErrorBehavior.Throw)
134+
throw new InvalidOperationException("Expected to read {0:n0} uncompressed bytes but only read {1:n0}".FormatInvariant(uncompressedLength, totalBytesRead));
135+
m_remainingData = new(uncompressedData, 0, totalBytesRead);
144136
#else
145-
// check CMF (Compression Method and Flags) and FLG (Flags) bytes for expected values
146-
var cmf = payloadReadBytes.Array![payloadReadBytes.Offset];
147-
var flg = payloadReadBytes.Array[payloadReadBytes.Offset + 1];
148-
if (cmf != 0x78 || ((flg & 0x20) == 0x20) || ((cmf * 256 + flg) % 31 != 0))
149-
{
150-
// CMF = 0x78: 32K Window Size + deflate compression
151-
// FLG & 0x20: has preset dictionary (not supported)
152-
// CMF*256+FLG is a multiple of 31: header checksum
153-
return protocolErrorBehavior == ProtocolErrorBehavior.Ignore ?
154-
default :
155-
ValueTaskExtensions.FromException<int>(new NotSupportedException("Unsupported zlib header: {0:X2}{1:X2}".FormatInvariant(cmf, flg)));
156-
}
157-
158-
// zlib format (https://www.ietf.org/rfc/rfc1950.txt) is: [two header bytes] [deflate-compressed data] [four-byte checksum]
159-
// .NET implements the middle part with DeflateStream; need to handle header and checksum explicitly
160-
const int headerSize = 2;
161-
const int checksumSize = 4;
162-
var uncompressedData = new byte[uncompressedLength];
163-
using var compressedStream = new MemoryStream(payloadReadBytes.Array, payloadReadBytes.Offset + headerSize, payloadReadBytes.Count - headerSize - checksumSize);
164-
using var decompressingStream = new DeflateStream(compressedStream, CompressionMode.Decompress);
165-
int bytesRead, totalBytesRead = 0;
166-
do
167-
{
168-
bytesRead = decompressingStream.Read(uncompressedData, totalBytesRead, uncompressedLength - totalBytesRead);
169-
totalBytesRead += bytesRead;
170-
} while (bytesRead > 0);
171-
if (totalBytesRead != uncompressedLength && protocolErrorBehavior == ProtocolErrorBehavior.Throw)
172-
return ValueTaskExtensions.FromException<int>(new InvalidOperationException("Expected to read {0:n0} uncompressed bytes but only read {1:n0}".FormatInvariant(uncompressedLength, totalBytesRead)));
173-
m_remainingData = new(uncompressedData, 0, totalBytesRead);
174-
175-
var checksum = Adler32.Calculate(uncompressedData.AsSpan(0, totalBytesRead));
176-
177-
var adlerStartOffset = payloadReadBytes.Offset + payloadReadBytes.Count - 4;
178-
if (payloadReadBytes.Array[adlerStartOffset + 0] != ((checksum >> 24) & 0xFF) ||
179-
payloadReadBytes.Array[adlerStartOffset + 1] != ((checksum >> 16) & 0xFF) ||
180-
payloadReadBytes.Array[adlerStartOffset + 2] != ((checksum >> 8) & 0xFF) ||
181-
payloadReadBytes.Array[adlerStartOffset + 3] != (checksum & 0xFF))
182-
{
183-
return protocolErrorBehavior == ProtocolErrorBehavior.Ignore ?
184-
default :
185-
ValueTaskExtensions.FromException<int>(new NotSupportedException("Invalid Adler-32 checksum of uncompressed data."));
186-
}
137+
// check CMF (Compression Method and Flags) and FLG (Flags) bytes for expected values
138+
var cmf = payloadReadBytes.Array![payloadReadBytes.Offset];
139+
var flg = payloadReadBytes.Array[payloadReadBytes.Offset + 1];
140+
if (cmf != 0x78 || ((flg & 0x20) == 0x20) || ((cmf * 256 + flg) % 31 != 0))
141+
{
142+
// CMF = 0x78: 32K Window Size + deflate compression
143+
// FLG & 0x20: has preset dictionary (not supported)
144+
// CMF*256+FLG is a multiple of 31: header checksum
145+
if (protocolErrorBehavior == ProtocolErrorBehavior.Ignore)
146+
return default;
147+
throw new NotSupportedException("Unsupported zlib header: {0:X2}{1:X2}".FormatInvariant(cmf, flg));
148+
}
149+
150+
// zlib format (https://www.ietf.org/rfc/rfc1950.txt) is: [two header bytes] [deflate-compressed data] [four-byte checksum]
151+
// .NET implements the middle part with DeflateStream; need to handle header and checksum explicitly
152+
const int headerSize = 2;
153+
const int checksumSize = 4;
154+
var uncompressedData = new byte[uncompressedLength];
155+
using var compressedStream = new MemoryStream(payloadReadBytes.Array, payloadReadBytes.Offset + headerSize, payloadReadBytes.Count - headerSize - checksumSize);
156+
using var decompressingStream = new DeflateStream(compressedStream, CompressionMode.Decompress);
157+
int bytesRead, totalBytesRead = 0;
158+
do
159+
{
160+
bytesRead = decompressingStream.Read(uncompressedData, totalBytesRead, uncompressedLength - totalBytesRead);
161+
totalBytesRead += bytesRead;
162+
} while (bytesRead > 0);
163+
if (totalBytesRead != uncompressedLength && protocolErrorBehavior == ProtocolErrorBehavior.Throw)
164+
throw new InvalidOperationException("Expected to read {0:n0} uncompressed bytes but only read {1:n0}".FormatInvariant(uncompressedLength, totalBytesRead));
165+
m_remainingData = new(uncompressedData, 0, totalBytesRead);
166+
167+
var checksum = Adler32.Calculate(uncompressedData.AsSpan(0, totalBytesRead));
168+
169+
var adlerStartOffset = payloadReadBytes.Offset + payloadReadBytes.Count - 4;
170+
if (payloadReadBytes.Array[adlerStartOffset + 0] != ((checksum >> 24) & 0xFF) ||
171+
payloadReadBytes.Array[adlerStartOffset + 1] != ((checksum >> 16) & 0xFF) ||
172+
payloadReadBytes.Array[adlerStartOffset + 2] != ((checksum >> 8) & 0xFF) ||
173+
payloadReadBytes.Array[adlerStartOffset + 3] != (checksum & 0xFF))
174+
{
175+
if (protocolErrorBehavior == ProtocolErrorBehavior.Ignore)
176+
return default;
177+
throw new NotSupportedException("Invalid Adler-32 checksum of uncompressed data.");
178+
}
187179
#endif
188-
}
189-
190-
var bytesToRead = Math.Min(m_remainingData.Count, buffer.Length);
191-
m_remainingData.AsSpan(0, bytesToRead).CopyTo(buffer.Span);
192-
m_remainingData = m_remainingData.Slice(bytesToRead);
193-
return new ValueTask<int>(bytesToRead);
194-
});
195-
});
180+
}
181+
182+
bytesToRead = Math.Min(m_remainingData.Count, buffer.Length);
183+
m_remainingData.AsSpan(0, bytesToRead).CopyTo(buffer.Span);
184+
m_remainingData = m_remainingData.Slice(bytesToRead);
185+
return bytesToRead;
196186
}
197187

198188
private byte GetNextCompressedSequenceNumber() => m_compressedSequenceNumber++;
199189

200190
private int GetNextUncompressedSequenceNumber() => m_uncompressedSequenceNumber++;
201191

202-
private ValueTask<int> CompressAndWrite(ArraySegment<byte> remainingUncompressedData, IOBehavior ioBehavior)
192+
private async ValueTask<int> CompressAndWrite(ArraySegment<byte> remainingUncompressedData, IOBehavior ioBehavior)
203193
{
204194
var remainingUncompressedBytes = Math.Min(remainingUncompressedData.Count, ProtocolUtility.MaxPacketSize);
205195

@@ -248,9 +238,9 @@ private ValueTask<int> CompressAndWrite(ArraySegment<byte> remainingUncompressed
248238
Buffer.BlockCopy(compressedData.Array!, compressedData.Offset, buffer, 7, compressedData.Count);
249239

250240
remainingUncompressedData = remainingUncompressedData.Slice(remainingUncompressedBytes);
251-
return m_byteHandler!.WriteBytesAsync(new ArraySegment<byte>(buffer, 0, buffer.Length), ioBehavior)
252-
.ContinueWith(_ => remainingUncompressedData.Count == 0 ? default :
253-
CompressAndWrite(remainingUncompressedData, ioBehavior));
241+
await m_byteHandler!.WriteBytesAsync(new ArraySegment<byte>(buffer, 0, buffer.Length), ioBehavior).ConfigureAwait(false);
242+
return remainingUncompressedData.Count == 0 ? default :
243+
await CompressAndWrite(remainingUncompressedData, ioBehavior).ConfigureAwait(false);
254244
}
255245

256246
// CompressedByteHandler implements IByteHandler and delegates reading bytes back to the CompressedPayloadHandler class.

src/MySqlConnector/Utilities/ValueTaskExtensions.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,5 @@ namespace MySqlConnector.Utilities;
22

33
internal static class ValueTaskExtensions
44
{
5-
public static async ValueTask<TResult> ContinueWith<T, TResult>(this ValueTask<T> valueTask, Func<T, ValueTask<TResult>> continuation) => await continuation(await valueTask.ConfigureAwait(false)).ConfigureAwait(false);
6-
75
public static ValueTask<T> FromException<T>(Exception exception) => new ValueTask<T>(Task.FromException<T>(exception));
86
}

0 commit comments

Comments
 (0)