@@ -39,167 +39,157 @@ public ValueTask<ArraySegment<byte>> ReadPayloadAsync(ArraySegmentHolder<byte> c
39
39
return ProtocolUtility . ReadPayloadAsync ( m_bufferedByteReader , compressedByteHandler , static ( ) => - 1 , cache , protocolErrorBehavior , ioBehavior ) ;
40
40
}
41
41
42
- public ValueTask < int > WritePayloadAsync ( ReadOnlyMemory < byte > payload , IOBehavior ioBehavior )
42
+ public async ValueTask < int > WritePayloadAsync ( ReadOnlyMemory < byte > payload , IOBehavior ioBehavior )
43
43
{
44
44
// break the payload up into (possibly more than one) uncompressed packets
45
- return ProtocolUtility . WritePayloadAsync ( m_uncompressedStreamByteHandler ! , GetNextUncompressedSequenceNumber , payload , ioBehavior ) . ContinueWith ( _ =>
46
- {
47
- if ( m_uncompressedStream ! . Length == 0 )
48
- return default ;
45
+ await ProtocolUtility . WritePayloadAsync ( m_uncompressedStreamByteHandler ! , GetNextUncompressedSequenceNumber , payload , ioBehavior ) . ConfigureAwait ( false ) ;
46
+
47
+ if ( m_uncompressedStream ! . Length == 0 )
48
+ return default ;
49
+
50
+ if ( ! m_uncompressedStream . TryGetBuffer ( out var uncompressedData ) )
51
+ throw new InvalidOperationException ( "Couldn't get uncompressed stream buffer." ) ;
49
52
50
- if ( ! m_uncompressedStream . TryGetBuffer ( out var uncompressedData ) )
51
- throw new InvalidOperationException ( "Couldn't get uncompressed stream buffer." ) ;
52
-
53
- return CompressAndWrite ( uncompressedData , ioBehavior )
54
- . ContinueWith ( __ =>
55
- {
56
- // reset the uncompressed stream to accept more data
57
- m_uncompressedStream . SetLength ( 0 ) ;
58
- return default ( ValueTask < int > ) ;
59
- } ) ;
60
- } ) ;
53
+ await CompressAndWrite ( uncompressedData , ioBehavior ) . ConfigureAwait ( false ) ;
54
+
55
+ // reset the uncompressed stream to accept more data
56
+ m_uncompressedStream . SetLength ( 0 ) ;
57
+ return default ;
61
58
}
62
59
63
- private ValueTask < int > ReadBytesAsync ( Memory < byte > buffer , ProtocolErrorBehavior protocolErrorBehavior , IOBehavior ioBehavior )
60
+ private async ValueTask < int > ReadBytesAsync ( Memory < byte > buffer , ProtocolErrorBehavior protocolErrorBehavior , IOBehavior ioBehavior )
64
61
{
65
62
// satisfy the read from cache if possible
63
+ int bytesToRead ;
66
64
if ( m_remainingData . Count > 0 )
67
65
{
68
- var bytesToRead = Math . Min ( m_remainingData . Count , buffer . Length ) ;
66
+ bytesToRead = Math . Min ( m_remainingData . Count , buffer . Length ) ;
69
67
m_remainingData . AsSpan ( 0 , bytesToRead ) . CopyTo ( buffer . Span ) ;
70
68
m_remainingData = m_remainingData . Slice ( bytesToRead ) ;
71
- return new ValueTask < int > ( bytesToRead ) ;
69
+ return bytesToRead ;
72
70
}
73
71
74
72
// read the compressed header (seven bytes)
75
- return m_compressedBufferedByteReader . ReadBytesAsync ( m_byteHandler ! , 7 , ioBehavior )
76
- . ContinueWith ( headerReadBytes =>
77
- {
78
- if ( headerReadBytes . Count < 7 )
79
- {
80
- return protocolErrorBehavior == ProtocolErrorBehavior . Ignore ?
81
- default :
82
- ValueTaskExtensions . FromException < int > ( new EndOfStreamException ( "Wanted to read 7 bytes but only read {0} when reading compressed packet header" . FormatInvariant ( headerReadBytes . Count ) ) ) ;
83
- }
84
-
85
- var payloadLength = ( int ) SerializationUtility . ReadUInt32 ( headerReadBytes . AsSpan ( 0 , 3 ) ) ;
86
- var packetSequenceNumber = headerReadBytes . Array ! [ headerReadBytes . Offset + 3 ] ;
87
- var uncompressedLength = ( int ) SerializationUtility . ReadUInt32 ( headerReadBytes . AsSpan ( 4 , 3 ) ) ;
88
-
89
- // verify the compressed packet sequence number
90
- var expectedSequenceNumber = GetNextCompressedSequenceNumber ( ) ;
91
- if ( packetSequenceNumber != expectedSequenceNumber )
92
- {
93
- if ( protocolErrorBehavior == ProtocolErrorBehavior . Ignore )
94
- return default ;
95
-
96
- var exception = MySqlProtocolException . CreateForPacketOutOfOrder ( expectedSequenceNumber , packetSequenceNumber ) ;
97
- return ValueTaskExtensions . FromException < int > ( exception ) ;
98
- }
99
-
100
- // MySQL protocol resets the uncompressed sequence number back to the sequence number of this compressed packet.
101
- // This isn't in the documentation, but the code explicitly notes that uncompressed packets are modified by compression:
102
- // - https://github.com/mysql/mysql-server/blob/c28e258157f39f25e044bb72e8bae1ff00989a3d/sql/net_serv.cc#L276
103
- // - https://github.com/mysql/mysql-server/blob/c28e258157f39f25e044bb72e8bae1ff00989a3d/sql/net_serv.cc#L225-L227
104
- if ( ! m_isContinuationPacket )
105
- m_uncompressedSequenceNumber = packetSequenceNumber ;
106
-
107
- // except this doesn't happen when uncompressed packets need to be broken up across multiple compressed packets
108
- m_isContinuationPacket = payloadLength == ProtocolUtility . MaxPacketSize || uncompressedLength == ProtocolUtility . MaxPacketSize ;
109
-
110
- return m_compressedBufferedByteReader . ReadBytesAsync ( m_byteHandler ! , payloadLength , ioBehavior )
111
- . ContinueWith ( payloadReadBytes =>
112
- {
113
- if ( payloadReadBytes . Count < payloadLength )
114
- {
115
- return protocolErrorBehavior == ProtocolErrorBehavior . Ignore ?
116
- default :
117
- ValueTaskExtensions . FromException < int > ( new EndOfStreamException ( "Wanted to read {0} bytes but only read {1} when reading compressed payload" . FormatInvariant ( payloadLength , payloadReadBytes . Count ) ) ) ;
118
- }
119
-
120
- if ( uncompressedLength == 0 )
121
- {
122
- // data is uncompressed
123
- m_remainingData = payloadReadBytes ;
124
- }
125
- else
126
- {
73
+ var headerReadBytes = await m_compressedBufferedByteReader . ReadBytesAsync ( m_byteHandler ! , 7 , ioBehavior ) . ConfigureAwait ( false ) ;
74
+ if ( headerReadBytes . Count < 7 )
75
+ {
76
+ if ( protocolErrorBehavior == ProtocolErrorBehavior . Ignore )
77
+ return default ;
78
+ throw new EndOfStreamException ( "Wanted to read 7 bytes but only read {0} when reading compressed packet header" . FormatInvariant ( headerReadBytes . Count ) ) ;
79
+ }
80
+
81
+ var payloadLength = ( int ) SerializationUtility . ReadUInt32 ( headerReadBytes . AsSpan ( 0 , 3 ) ) ;
82
+ var packetSequenceNumber = headerReadBytes . Array ! [ headerReadBytes . Offset + 3 ] ;
83
+ var uncompressedLength = ( int ) SerializationUtility . ReadUInt32 ( headerReadBytes . AsSpan ( 4 , 3 ) ) ;
84
+
85
+ // verify the compressed packet sequence number
86
+ var expectedSequenceNumber = GetNextCompressedSequenceNumber ( ) ;
87
+ if ( packetSequenceNumber != expectedSequenceNumber )
88
+ {
89
+ if ( protocolErrorBehavior == ProtocolErrorBehavior . Ignore )
90
+ return default ;
91
+ throw MySqlProtocolException . CreateForPacketOutOfOrder ( expectedSequenceNumber , packetSequenceNumber ) ;
92
+ }
93
+
94
+ // MySQL protocol resets the uncompressed sequence number back to the sequence number of this compressed packet.
95
+ // This isn't in the documentation, but the code explicitly notes that uncompressed packets are modified by compression:
96
+ // - https://github.com/mysql/mysql-server/blob/c28e258157f39f25e044bb72e8bae1ff00989a3d/sql/net_serv.cc#L276
97
+ // - https://github.com/mysql/mysql-server/blob/c28e258157f39f25e044bb72e8bae1ff00989a3d/sql/net_serv.cc#L225-L227
98
+ if ( ! m_isContinuationPacket )
99
+ m_uncompressedSequenceNumber = packetSequenceNumber ;
100
+
101
+ // except this doesn't happen when uncompressed packets need to be broken up across multiple compressed packets
102
+ m_isContinuationPacket = payloadLength == ProtocolUtility . MaxPacketSize || uncompressedLength == ProtocolUtility . MaxPacketSize ;
103
+
104
+ var payloadReadBytes = await m_compressedBufferedByteReader . ReadBytesAsync ( m_byteHandler ! , payloadLength , ioBehavior ) . ConfigureAwait ( false ) ;
105
+ if ( payloadReadBytes . Count < payloadLength )
106
+ {
107
+ if ( protocolErrorBehavior == ProtocolErrorBehavior . Ignore )
108
+ return default ;
109
+ throw new EndOfStreamException ( "Wanted to read {0} bytes but only read {1} when reading compressed payload" . FormatInvariant ( payloadLength , payloadReadBytes . Count ) ) ;
110
+ }
111
+
112
+ if ( uncompressedLength == 0 )
113
+ {
114
+ // data is uncompressed
115
+ m_remainingData = payloadReadBytes ;
116
+ }
117
+ else
118
+ {
127
119
#if NET6_0_OR_GREATER
128
- var uncompressedData = new byte [ uncompressedLength ] ;
129
- using var compressedStream = new MemoryStream ( payloadReadBytes . Array ! , payloadReadBytes . Offset , payloadReadBytes . Count ) ;
130
- using var decompressingStream = new ZLibStream ( compressedStream , CompressionMode . Decompress ) ;
120
+ var uncompressedData = new byte [ uncompressedLength ] ;
121
+ using var compressedStream = new MemoryStream ( payloadReadBytes . Array ! , payloadReadBytes . Offset , payloadReadBytes . Count ) ;
122
+ using var decompressingStream = new ZLibStream ( compressedStream , CompressionMode . Decompress ) ;
131
123
#if NET7_0_OR_GREATER
132
- var totalBytesRead = decompressingStream . ReadAtLeast ( uncompressedData , uncompressedLength , throwOnEndOfStream : false ) ;
124
+ var totalBytesRead = decompressingStream . ReadAtLeast ( uncompressedData , uncompressedLength , throwOnEndOfStream : false ) ;
133
125
#else
134
- int bytesRead , totalBytesRead = 0 ;
135
- do
136
- {
137
- bytesRead = decompressingStream . Read ( uncompressedData , totalBytesRead , uncompressedLength - totalBytesRead ) ;
138
- totalBytesRead += bytesRead ;
139
- } while ( bytesRead > 0 ) ;
126
+ int bytesRead , totalBytesRead = 0 ;
127
+ do
128
+ {
129
+ bytesRead = decompressingStream . Read ( uncompressedData , totalBytesRead , uncompressedLength - totalBytesRead ) ;
130
+ totalBytesRead += bytesRead ;
131
+ } while ( bytesRead > 0 ) ;
140
132
#endif
141
- if ( totalBytesRead != uncompressedLength && protocolErrorBehavior == ProtocolErrorBehavior . Throw )
142
- return ValueTaskExtensions . FromException < int > ( new InvalidOperationException ( "Expected to read {0:n0} uncompressed bytes but only read {1:n0}" . FormatInvariant ( uncompressedLength , totalBytesRead ) ) ) ;
143
- m_remainingData = new ( uncompressedData , 0 , totalBytesRead ) ;
133
+ if ( totalBytesRead != uncompressedLength && protocolErrorBehavior == ProtocolErrorBehavior . Throw )
134
+ throw new InvalidOperationException ( "Expected to read {0:n0} uncompressed bytes but only read {1:n0}" . FormatInvariant ( uncompressedLength , totalBytesRead ) ) ;
135
+ m_remainingData = new ( uncompressedData , 0 , totalBytesRead ) ;
144
136
#else
145
- // check CMF (Compression Method and Flags) and FLG (Flags) bytes for expected values
146
- var cmf = payloadReadBytes . Array ! [ payloadReadBytes . Offset ] ;
147
- var flg = payloadReadBytes . Array [ payloadReadBytes . Offset + 1 ] ;
148
- if ( cmf != 0x78 || ( ( flg & 0x20 ) == 0x20 ) || ( ( cmf * 256 + flg ) % 31 != 0 ) )
149
- {
150
- // CMF = 0x78: 32K Window Size + deflate compression
151
- // FLG & 0x20: has preset dictionary (not supported)
152
- // CMF*256+FLG is a multiple of 31: header checksum
153
- return protocolErrorBehavior == ProtocolErrorBehavior . Ignore ?
154
- default :
155
- ValueTaskExtensions . FromException < int > ( new NotSupportedException ( "Unsupported zlib header: {0:X2}{1:X2}" . FormatInvariant ( cmf , flg ) ) ) ;
156
- }
157
-
158
- // zlib format (https://www.ietf.org/rfc/rfc1950.txt) is: [two header bytes] [deflate-compressed data] [four-byte checksum]
159
- // .NET implements the middle part with DeflateStream; need to handle header and checksum explicitly
160
- const int headerSize = 2 ;
161
- const int checksumSize = 4 ;
162
- var uncompressedData = new byte [ uncompressedLength ] ;
163
- using var compressedStream = new MemoryStream ( payloadReadBytes . Array , payloadReadBytes . Offset + headerSize , payloadReadBytes . Count - headerSize - checksumSize ) ;
164
- using var decompressingStream = new DeflateStream ( compressedStream , CompressionMode . Decompress ) ;
165
- int bytesRead , totalBytesRead = 0 ;
166
- do
167
- {
168
- bytesRead = decompressingStream . Read ( uncompressedData , totalBytesRead , uncompressedLength - totalBytesRead ) ;
169
- totalBytesRead += bytesRead ;
170
- } while ( bytesRead > 0 ) ;
171
- if ( totalBytesRead != uncompressedLength && protocolErrorBehavior == ProtocolErrorBehavior . Throw )
172
- return ValueTaskExtensions . FromException < int > ( new InvalidOperationException ( "Expected to read {0:n0} uncompressed bytes but only read {1:n0}" . FormatInvariant ( uncompressedLength , totalBytesRead ) ) ) ;
173
- m_remainingData = new ( uncompressedData , 0 , totalBytesRead ) ;
174
-
175
- var checksum = Adler32 . Calculate ( uncompressedData . AsSpan ( 0 , totalBytesRead ) ) ;
176
-
177
- var adlerStartOffset = payloadReadBytes . Offset + payloadReadBytes . Count - 4 ;
178
- if ( payloadReadBytes . Array [ adlerStartOffset + 0 ] != ( ( checksum >> 24 ) & 0xFF ) ||
179
- payloadReadBytes . Array [ adlerStartOffset + 1 ] != ( ( checksum >> 16 ) & 0xFF ) ||
180
- payloadReadBytes . Array [ adlerStartOffset + 2 ] != ( ( checksum >> 8 ) & 0xFF ) ||
181
- payloadReadBytes . Array [ adlerStartOffset + 3 ] != ( checksum & 0xFF ) )
182
- {
183
- return protocolErrorBehavior == ProtocolErrorBehavior . Ignore ?
184
- default :
185
- ValueTaskExtensions . FromException < int > ( new NotSupportedException ( "Invalid Adler-32 checksum of uncompressed data." ) ) ;
186
- }
137
+ // check CMF (Compression Method and Flags) and FLG (Flags) bytes for expected values
138
+ var cmf = payloadReadBytes . Array ! [ payloadReadBytes . Offset ] ;
139
+ var flg = payloadReadBytes . Array [ payloadReadBytes . Offset + 1 ] ;
140
+ if ( cmf != 0x78 || ( ( flg & 0x20 ) == 0x20 ) || ( ( cmf * 256 + flg ) % 31 != 0 ) )
141
+ {
142
+ // CMF = 0x78: 32K Window Size + deflate compression
143
+ // FLG & 0x20: has preset dictionary (not supported)
144
+ // CMF*256+FLG is a multiple of 31: header checksum
145
+ if ( protocolErrorBehavior == ProtocolErrorBehavior . Ignore )
146
+ return default ;
147
+ throw new NotSupportedException ( "Unsupported zlib header: {0:X2}{1:X2}" . FormatInvariant ( cmf , flg ) ) ;
148
+ }
149
+
150
+ // zlib format (https://www.ietf.org/rfc/rfc1950.txt) is: [two header bytes] [deflate-compressed data] [four-byte checksum]
151
+ // .NET implements the middle part with DeflateStream; need to handle header and checksum explicitly
152
+ const int headerSize = 2 ;
153
+ const int checksumSize = 4 ;
154
+ var uncompressedData = new byte [ uncompressedLength ] ;
155
+ using var compressedStream = new MemoryStream ( payloadReadBytes . Array , payloadReadBytes . Offset + headerSize , payloadReadBytes . Count - headerSize - checksumSize ) ;
156
+ using var decompressingStream = new DeflateStream ( compressedStream , CompressionMode . Decompress ) ;
157
+ int bytesRead , totalBytesRead = 0 ;
158
+ do
159
+ {
160
+ bytesRead = decompressingStream . Read ( uncompressedData , totalBytesRead , uncompressedLength - totalBytesRead ) ;
161
+ totalBytesRead += bytesRead ;
162
+ } while ( bytesRead > 0 ) ;
163
+ if ( totalBytesRead != uncompressedLength && protocolErrorBehavior == ProtocolErrorBehavior . Throw )
164
+ throw new InvalidOperationException ( "Expected to read {0:n0} uncompressed bytes but only read {1:n0}" . FormatInvariant ( uncompressedLength , totalBytesRead ) ) ;
165
+ m_remainingData = new ( uncompressedData , 0 , totalBytesRead ) ;
166
+
167
+ var checksum = Adler32 . Calculate ( uncompressedData . AsSpan ( 0 , totalBytesRead ) ) ;
168
+
169
+ var adlerStartOffset = payloadReadBytes . Offset + payloadReadBytes . Count - 4 ;
170
+ if ( payloadReadBytes . Array [ adlerStartOffset + 0 ] != ( ( checksum >> 24 ) & 0xFF ) ||
171
+ payloadReadBytes . Array [ adlerStartOffset + 1 ] != ( ( checksum >> 16 ) & 0xFF ) ||
172
+ payloadReadBytes . Array [ adlerStartOffset + 2 ] != ( ( checksum >> 8 ) & 0xFF ) ||
173
+ payloadReadBytes . Array [ adlerStartOffset + 3 ] != ( checksum & 0xFF ) )
174
+ {
175
+ if ( protocolErrorBehavior == ProtocolErrorBehavior . Ignore )
176
+ return default ;
177
+ throw new NotSupportedException ( "Invalid Adler-32 checksum of uncompressed data." ) ;
178
+ }
187
179
#endif
188
- }
189
-
190
- var bytesToRead = Math . Min ( m_remainingData . Count , buffer . Length ) ;
191
- m_remainingData . AsSpan ( 0 , bytesToRead ) . CopyTo ( buffer . Span ) ;
192
- m_remainingData = m_remainingData . Slice ( bytesToRead ) ;
193
- return new ValueTask < int > ( bytesToRead ) ;
194
- } ) ;
195
- } ) ;
180
+ }
181
+
182
+ bytesToRead = Math . Min ( m_remainingData . Count , buffer . Length ) ;
183
+ m_remainingData . AsSpan ( 0 , bytesToRead ) . CopyTo ( buffer . Span ) ;
184
+ m_remainingData = m_remainingData . Slice ( bytesToRead ) ;
185
+ return bytesToRead ;
196
186
}
197
187
198
188
private byte GetNextCompressedSequenceNumber ( ) => m_compressedSequenceNumber ++ ;
199
189
200
190
private int GetNextUncompressedSequenceNumber ( ) => m_uncompressedSequenceNumber ++ ;
201
191
202
- private ValueTask < int > CompressAndWrite ( ArraySegment < byte > remainingUncompressedData , IOBehavior ioBehavior )
192
+ private async ValueTask < int > CompressAndWrite ( ArraySegment < byte > remainingUncompressedData , IOBehavior ioBehavior )
203
193
{
204
194
var remainingUncompressedBytes = Math . Min ( remainingUncompressedData . Count , ProtocolUtility . MaxPacketSize ) ;
205
195
@@ -248,9 +238,9 @@ private ValueTask<int> CompressAndWrite(ArraySegment<byte> remainingUncompressed
248
238
Buffer . BlockCopy ( compressedData . Array ! , compressedData . Offset , buffer , 7 , compressedData . Count ) ;
249
239
250
240
remainingUncompressedData = remainingUncompressedData . Slice ( remainingUncompressedBytes ) ;
251
- return m_byteHandler ! . WriteBytesAsync ( new ArraySegment < byte > ( buffer , 0 , buffer . Length ) , ioBehavior )
252
- . ContinueWith ( _ => remainingUncompressedData . Count == 0 ? default :
253
- CompressAndWrite ( remainingUncompressedData , ioBehavior ) ) ;
241
+ await m_byteHandler ! . WriteBytesAsync ( new ArraySegment < byte > ( buffer , 0 , buffer . Length ) , ioBehavior ) . ConfigureAwait ( false ) ;
242
+ return remainingUncompressedData . Count == 0 ? default :
243
+ await CompressAndWrite ( remainingUncompressedData , ioBehavior ) . ConfigureAwait ( false ) ;
254
244
}
255
245
256
246
// CompressedByteHandler implements IByteHandler and delegates reading bytes back to the CompressedPayloadHandler class.
0 commit comments