@@ -36,7 +36,7 @@ func init() {
3636 }
3737}
3838
39- func zDecompress (src , dst []byte ) (int , error ) {
39+ func zDecompress (src []byte , dst * bytes. Buffer ) (int , error ) {
4040 br := bytes .NewReader (src )
4141 var zr io.ReadCloser
4242 var err error
@@ -51,27 +51,11 @@ func zDecompress(src, dst []byte) (int, error) {
5151 return 0 , err
5252 }
5353 }
54- defer func () {
55- zr .Close ()
56- zrPool .Put (zr )
57- }()
5854
59- lenRead := 0
60- size := len (dst )
61-
62- for lenRead < size {
63- n , err := zr .Read (dst [lenRead :])
64- lenRead += n
65-
66- if err == io .EOF {
67- if lenRead < size {
68- return lenRead , io .ErrUnexpectedEOF
69- }
70- } else if err != nil {
71- return lenRead , err
72- }
73- }
74- return lenRead , nil
55+ n , _ := dst .ReadFrom (zr ) // ignore err because zr.Close() will return it again.
56+ err = zr .Close () // zr.Close() may return chuecksum error.
57+ zrPool .Put (zr )
58+ return int (n ), err
7559}
7660
7761func zCompress (src []byte , dst io.Writer ) error {
@@ -100,7 +84,7 @@ func (c *compIO) reset() {
10084 c .buff .Reset ()
10185}
10286
103- func (c * compIO ) readNext (need int , r readwriteFunc ) ([]byte , error ) {
87+ func (c * compIO ) readNext (need int , r readerFunc ) ([]byte , error ) {
10488 for c .buff .Len () < need {
10589 if err := c .readCompressedPacket (r ); err != nil {
10690 return nil , err
@@ -110,7 +94,7 @@ func (c *compIO) readNext(need int, r readwriteFunc) ([]byte, error) {
11094 return data [:need :need ], nil // prevent caller writes into c.buff
11195}
11296
113- func (c * compIO ) readCompressedPacket (r readwriteFunc ) error {
97+ func (c * compIO ) readCompressedPacket (r readerFunc ) error {
11498 header , err := c .mc .buf .readNext (7 , r ) // size of compressed header
11599 if err != nil {
116100 return err
@@ -121,19 +105,17 @@ func (c *compIO) readCompressedPacket(r readwriteFunc) error {
121105 comprLength := getUint24 (header [0 :3 ])
122106 compressionSequence := uint8 (header [3 ])
123107 uncompressedLength := getUint24 (header [4 :7 ])
124- if debugTrace {
108+ if debug {
125109 fmt .Printf ("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n " ,
126110 comprLength , uncompressedLength , compressionSequence , c .mc .sequence )
127111 }
128- if compressionSequence != c .mc .sequence {
129- // return ErrPktSync
130- // server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes)
131- // before receiving all packets from client. In this case, seqnr is younger than expected.
132- if debugTrace {
133- fmt .Printf ("WARN: unexpected cmpress seq nr: expected %v, got %v" ,
134- c .mc .sequence , compressionSequence )
135- }
136- // TODO(methane): report error when the packet is not an error packet.
112+ // Do not return ErrPktSync here.
113+ // Server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes)
114+ // before receiving all packets from client. In this case, seqnr is younger than expected.
115+ // NOTE: Both of mariadbclient and mysqlclient do not check seqnr. Only server checks it.
116+ if debug && compressionSequence != c .mc .sequence {
117+ fmt .Printf ("WARN: unexpected cmpress seq nr: expected %v, got %v" ,
118+ c .mc .sequence , compressionSequence )
137119 }
138120 c .mc .sequence = compressionSequence + 1
139121 c .mc .compressSequence = c .mc .sequence
@@ -152,31 +134,29 @@ func (c *compIO) readCompressedPacket(r readwriteFunc) error {
152134
153135 // use existing capacity in bytesBuf if possible
154136 c .buff .Grow (uncompressedLength )
155- dec := c .buff .AvailableBuffer ()[:uncompressedLength ]
156- lenRead , err := zDecompress (comprData , dec )
137+ nread , err := zDecompress (comprData , & c .buff )
157138 if err != nil {
158139 return err
159140 }
160- if lenRead != uncompressedLength {
141+ if nread != uncompressedLength {
161142 return fmt .Errorf ("invalid compressed packet: uncompressed length in header is %d, actual %d" ,
162- uncompressedLength , lenRead )
143+ uncompressedLength , nread )
163144 }
164- c .buff .Write (dec ) // fast copy. See bytes.Buffer.AvailableBuffer() doc.
165145 return nil
166146}
167147
148+ const minCompressLength = 150
168149const maxPayloadLen = maxPacketSize - 4
169150
170151// writePackets sends one or some packets with compression.
171152// Use this instead of mc.netConn.Write() when mc.compress is true.
172153func (c * compIO ) writePackets (packets []byte ) (int , error ) {
173154 totalBytes := len (packets )
174- dataLen := len (packets )
175155 blankHeader := make ([]byte , 7 )
176156 buf := & c .buff
177157
178- for dataLen > 0 {
179- payloadLen := min (maxPayloadLen , dataLen )
158+ for len ( packets ) > 0 {
159+ payloadLen := min (maxPayloadLen , len ( packets ) )
180160 payload := packets [:payloadLen ]
181161 uncompressedLen := payloadLen
182162
@@ -190,8 +170,8 @@ func (c *compIO) writePackets(packets []byte) (int, error) {
190170 } else {
191171 zCompress (payload , buf )
192172 // do not compress if compressed data is larger than uncompressed data
193- // I intentionally miss 7 byte header in the buf; compress should compress more than 7 bytes.
194- if buf .Len () > uncompressedLen {
173+ // I intentionally miss 7 byte header in the buf; compress more than 7 bytes.
174+ if buf .Len () >= uncompressedLen {
195175 buf .Reset ()
196176 buf .Write (blankHeader )
197177 buf .Write (payload )
@@ -204,7 +184,6 @@ func (c *compIO) writePackets(packets []byte) (int, error) {
204184 // up compressed bytes that is returned by underlying Write().
205185 return totalBytes - len (packets ) + n , err
206186 }
207- dataLen -= payloadLen
208187 packets = packets [payloadLen :]
209188 }
210189
@@ -216,7 +195,7 @@ func (c *compIO) writePackets(packets []byte) (int, error) {
216195func (c * compIO ) writeCompressedPacket (data []byte , uncompressedLen int ) (int , error ) {
217196 mc := c .mc
218197 comprLength := len (data ) - 7
219- if debugTrace {
198+ if debug {
220199 fmt .Printf (
221200 "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v" ,
222201 comprLength , uncompressedLen , mc .compressSequence )
@@ -227,8 +206,8 @@ func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, e
227206 data [3 ] = mc .compressSequence
228207 putUint24 (data [4 :7 ], uncompressedLen )
229208
230- if n , err := mc .writeWithTimeout (data ); err != nil {
231- // mc.log("writing compressed packet:", err)
209+ n , err := mc .writeWithTimeout (data )
210+ if err != nil {
232211 return n , err
233212 }
234213
0 commit comments