Skip to content

Commit 8e1f894

Browse files
committed
Implement MariaDB metadata skipping.
Refactor handshake packet handling to support extended capabilities Updated the readHandshakePacket and writeHandshakeResponsePacket functions to include server capabilities and extended capabilities. Adjusted related tests and connection logic to accommodate these changes, ensuring compatibility with MariaDB and improved handling of client capabilities.
1 parent 98f445c commit 8e1f894

File tree

8 files changed

+171
-101
lines changed

8 files changed

+171
-101
lines changed

auth_test.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) {
8989
if err != nil {
9090
t.Fatal(err)
9191
}
92-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
92+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
9393
if err != nil {
9494
t.Fatal(err)
9595
}
@@ -134,7 +134,7 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) {
134134
if err != nil {
135135
t.Fatal(err)
136136
}
137-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
137+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
138138
if err != nil {
139139
t.Fatal(err)
140140
}
@@ -176,7 +176,7 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) {
176176
if err != nil {
177177
t.Fatal(err)
178178
}
179-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
179+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
180180
if err != nil {
181181
t.Fatal(err)
182182
}
@@ -232,7 +232,7 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) {
232232
if err != nil {
233233
t.Fatal(err)
234234
}
235-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
235+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
236236
if err != nil {
237237
t.Fatal(err)
238238
}
@@ -284,7 +284,7 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) {
284284
if err != nil {
285285
t.Fatal(err)
286286
}
287-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
287+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
288288
if err != nil {
289289
t.Fatal(err)
290290
}
@@ -357,7 +357,7 @@ func TestAuthFastCleartextPassword(t *testing.T) {
357357
if err != nil {
358358
t.Fatal(err)
359359
}
360-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
360+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
361361
if err != nil {
362362
t.Fatal(err)
363363
}
@@ -400,7 +400,7 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) {
400400
if err != nil {
401401
t.Fatal(err)
402402
}
403-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
403+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
404404
if err != nil {
405405
t.Fatal(err)
406406
}
@@ -459,7 +459,7 @@ func TestAuthFastNativePassword(t *testing.T) {
459459
if err != nil {
460460
t.Fatal(err)
461461
}
462-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
462+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
463463
if err != nil {
464464
t.Fatal(err)
465465
}
@@ -502,7 +502,7 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) {
502502
if err != nil {
503503
t.Fatal(err)
504504
}
505-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
505+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
506506
if err != nil {
507507
t.Fatal(err)
508508
}
@@ -544,7 +544,7 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) {
544544
if err != nil {
545545
t.Fatal(err)
546546
}
547-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
547+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
548548
if err != nil {
549549
t.Fatal(err)
550550
}
@@ -592,7 +592,7 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) {
592592
if err != nil {
593593
t.Fatal(err)
594594
}
595-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
595+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
596596
if err != nil {
597597
t.Fatal(err)
598598
}
@@ -641,7 +641,7 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) {
641641
if err != nil {
642642
t.Fatal(err)
643643
}
644-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
644+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
645645
if err != nil {
646646
t.Fatal(err)
647647
}
@@ -678,7 +678,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
678678
// unset TLS config to prevent the actual establishment of a TLS wrapper
679679
mc.cfg.TLS = nil
680680

681-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
681+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
682682
if err != nil {
683683
t.Fatal(err)
684684
}
@@ -1343,7 +1343,7 @@ func TestEd25519Auth(t *testing.T) {
13431343
if err != nil {
13441344
t.Fatal(err)
13451345
}
1346-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
1346+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
13471347
if err != nil {
13481348
t.Fatal(err)
13491349
}

connection.go

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,22 @@ import (
2424
)
2525

2626
type mysqlConn struct {
27-
buf buffer
28-
netConn net.Conn
29-
rawConn net.Conn // underlying connection when netConn is TLS connection.
30-
result mysqlResult // managed by clearResult() and handleOkPacket().
31-
compIO *compIO
32-
cfg *Config
33-
connector *connector
34-
maxAllowedPacket int
35-
maxWriteSize int
36-
flags clientFlag
37-
status statusFlag
38-
sequence uint8
39-
compressSequence uint8
40-
parseTime bool
41-
compress bool
27+
buf buffer
28+
netConn net.Conn
29+
rawConn net.Conn // underlying connection when netConn is TLS connection.
30+
result mysqlResult // managed by clearResult() and handleOkPacket().
31+
compIO *compIO
32+
cfg *Config
33+
connector *connector
34+
maxAllowedPacket int
35+
maxWriteSize int
36+
clientCapabilities capabilityFlag
37+
clientExtCapabilities extendedCapabilityFlag
38+
status statusFlag
39+
sequence uint8
40+
compressSequence uint8
41+
parseTime bool
42+
compress bool
4243

4344
// for context support (Go 1.8+)
4445
watching bool
@@ -229,7 +230,15 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
229230
}
230231

231232
if columnCount > 0 {
232-
err = mc.readUntilEOF()
233+
if mc.clientExtCapabilities&clientCacheMetadata != 0 {
234+
stmt.columns, err = mc.readColumns(int(columnCount))
235+
if err != nil {
236+
return nil, err
237+
}
238+
} else {
239+
// skip column definition packets and intermediate EOF packet
240+
err = mc.readUntilEOF()
241+
}
233242
}
234243
}
235244

@@ -370,7 +379,7 @@ func (mc *mysqlConn) exec(query string) error {
370379
}
371380

372381
// Read Result
373-
resLen, err := handleOk.readResultSetHeaderPacket()
382+
resLen, _, err := handleOk.readResultSetHeaderPacket()
374383
if err != nil {
375384
return err
376385
}
@@ -419,7 +428,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
419428

420429
// Read Result
421430
var resLen int
422-
resLen, err = handleOk.readResultSetHeaderPacket()
431+
resLen, _, err = handleOk.readResultSetHeaderPacket()
423432
if err != nil {
424433
return nil, err
425434
}
@@ -453,7 +462,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
453462
}
454463

455464
// Read Result
456-
resLen, err := handleOk.readResultSetHeaderPacket()
465+
resLen, _, err := handleOk.readResultSetHeaderPacket()
457466
if err == nil {
458467
rows := new(textRows)
459468
rows.mc = mc

connector.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
131131
mc.buf = newBuffer()
132132

133133
// Reading Handshake Initialization Packet
134-
authData, plugin, err := mc.readHandshakePacket()
134+
authData, serverCapabilities, serverExtendedCapabilities, plugin, err := mc.readHandshakePacket()
135135
if err != nil {
136136
mc.cleanup()
137137
return nil, err
@@ -153,7 +153,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
153153
return nil, err
154154
}
155155
}
156-
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
156+
if err = mc.writeHandshakeResponsePacket(authResp, serverCapabilities, serverExtendedCapabilities, plugin); err != nil {
157157
mc.cleanup()
158158
return nil, err
159159
}
@@ -167,7 +167,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
167167
return nil, err
168168
}
169169

170-
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
170+
if mc.cfg.compress && mc.clientCapabilities&clientCompress > 0 {
171171
mc.compress = true
172172
mc.compIO = newCompIO(mc)
173173
}

const.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ const (
4343
)
4444

4545
// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
46-
type clientFlag uint32
46+
type capabilityFlag uint32
4747

4848
const (
49-
clientLongPassword clientFlag = 1 << iota
49+
clientMySQL capabilityFlag = 1 << iota
5050
clientFoundRows
5151
clientLongFlag
5252
clientConnectWithDB
@@ -73,6 +73,20 @@ const (
7373
clientDeprecateEOF
7474
)
7575

76+
// https://mariadb.com/kb/en/connection/#capabilities
77+
type extendedCapabilityFlag uint32
78+
79+
const (
80+
progressIndicator extendedCapabilityFlag = 1 << iota
81+
clientComMulti
82+
clientStmtBulkOperations
83+
clientExtendedMetadata
84+
clientCacheMetadata
85+
clientUnitBulkResult
86+
)
87+
88+
// https://mariadb.com/kb/en/connection/#capabilities
89+
7690
const (
7791
comQuit byte = iota + 1
7892
comInitDB

0 commit comments

Comments
 (0)