Skip to content

Commit 3c1e4f1

Browse files
committed
Explicitly close connection on ErrBadConn
1 parent 9786892 commit 3c1e4f1

File tree

5 files changed

+30
-7
lines changed

5 files changed

+30
-7
lines changed

connection.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,16 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
109109
}
110110

111111
func (mc *mysqlConn) Close() (err error) {
112-
mc.writeCommandPacket(comQuit)
112+
// Makes Close idempotent
113+
if mc.netConn != nil {
114+
mc.writeCommandPacket(comQuit)
115+
mc.netConn.Close()
116+
mc.netConn = nil
117+
}
118+
113119
mc.cfg = nil
114120
mc.buf = nil
115-
mc.netConn.Close()
116-
mc.netConn = nil
121+
117122
return
118123
}
119124

packets.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
2929
data, err = mc.buf.readNext(4)
3030
if err != nil {
3131
errLog.Print(err.Error())
32+
mc.Close()
3233
return nil, driver.ErrBadConn
3334
}
3435

@@ -37,6 +38,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
3738

3839
if pktLen < 1 {
3940
errLog.Print(errMalformPkt.Error())
41+
mc.Close()
4042
return nil, driver.ErrBadConn
4143
}
4244

@@ -51,8 +53,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
5153
mc.sequence++
5254

5355
// Read packet body [pktLen bytes]
54-
data, err = mc.buf.readNext(pktLen)
55-
if err == nil {
56+
if data, err = mc.buf.readNext(pktLen); err == nil {
5657
if pktLen < maxPacketSize {
5758
return data, nil
5859
}
@@ -66,6 +67,9 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
6667
return append(buf, data...), nil
6768
}
6869
}
70+
71+
// err case
72+
mc.Close()
6973
errLog.Print(err.Error())
7074
return nil, driver.ErrBadConn
7175
}

rows.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,15 @@ func (rows *mysqlRows) Columns() (columns []string) {
3838
func (rows *mysqlRows) Close() (err error) {
3939
// Remove unread packets from stream
4040
if !rows.eof {
41-
if rows.mc == nil {
41+
if rows.mc == nil || rows.mc.netConn == nil {
4242
return errInvalidConn
4343
}
4444

4545
err = rows.mc.readUntilEOF()
46+
47+
// explicitly set because readUntilEOF might return early in case of an
48+
// error
49+
rows.eof = true
4650
}
4751

4852
rows.mc = nil
@@ -55,7 +59,7 @@ func (rows *mysqlRows) Next(dest []driver.Value) (err error) {
5559
return io.EOF
5660
}
5761

58-
if rows.mc == nil {
62+
if rows.mc == nil || rows.mc.netConn == nil {
5963
return errInvalidConn
6064
}
6165

statement.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ type mysqlStmt struct {
2121
}
2222

2323
func (stmt *mysqlStmt) Close() (err error) {
24+
if stmt.mc == nil || stmt.mc.netConn == nil {
25+
return errInvalidConn
26+
}
27+
2428
err = stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
2529
stmt.mc = nil
2630
return

transaction.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,18 @@ type mysqlTx struct {
1414
}
1515

1616
func (tx *mysqlTx) Commit() (err error) {
17+
if tx.mc == nil {
18+
return errInvalidConn
19+
}
1720
err = tx.mc.exec("COMMIT")
1821
tx.mc = nil
1922
return
2023
}
2124

2225
func (tx *mysqlTx) Rollback() (err error) {
26+
if tx.mc == nil {
27+
return errInvalidConn
28+
}
2329
err = tx.mc.exec("ROLLBACK")
2430
tx.mc = nil
2531
return

0 commit comments

Comments
 (0)