Skip to content

Commit cd41d9f

Browse files
arnehormannjulienschmidt
authored andcommitted
no panic on closed connection reuse
1 parent f553f33 commit cd41d9f

File tree

6 files changed

+73
-3
lines changed

6 files changed

+73
-3
lines changed

connection.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ func (mc *mysqlConn) handleParams() (err error) {
9696
}
9797

9898
func (mc *mysqlConn) Begin() (driver.Tx, error) {
99+
if mc.netConn == nil {
100+
return nil, errInvalidConn
101+
}
102+
99103
err := mc.exec("START TRANSACTION")
100104
if err == nil {
101105
return &mysqlTx{mc}, err
@@ -119,6 +123,10 @@ func (mc *mysqlConn) Close() (err error) {
119123
}
120124

121125
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
126+
if mc.netConn == nil {
127+
return nil, errInvalidConn
128+
}
129+
122130
// Send command
123131
err := mc.writeCommandPacketStr(comStmtPrepare, query)
124132
if err != nil {
@@ -148,6 +156,10 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
148156
}
149157

150158
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
159+
if mc.netConn == nil {
160+
return nil, errInvalidConn
161+
}
162+
151163
if len(args) == 0 { // no args, fastpath
152164
mc.affectedRows = 0
153165
mc.insertId = 0
@@ -191,6 +203,9 @@ func (mc *mysqlConn) exec(query string) (err error) {
191203
}
192204

193205
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
206+
if mc.netConn == nil {
207+
return nil, errInvalidConn
208+
}
194209
if len(args) == 0 { // no args, fastpath
195210
// Send command
196211
err := mc.writeCommandPacketStr(comQuery, query)

driver_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,41 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows)
101101
return rows
102102
}
103103

104+
func TestReuseClosedConnection(t *testing.T) {
105+
// this test does not use sql.database, it uses the driver directly
106+
if !available {
107+
t.Logf("MySQL-Server not running on %s. Skipping TestReuseClosedConnection", netAddr)
108+
return
109+
}
110+
driver := &mysqlDriver{}
111+
conn, err := driver.Open(dsn)
112+
if err != nil {
113+
t.Fatalf("Error connecting: %s", err.Error())
114+
}
115+
stmt, err := conn.Prepare("DO 1")
116+
if err != nil {
117+
t.Fatalf("Error preparing statement: %s", err.Error())
118+
}
119+
_, err = stmt.Exec(nil)
120+
if err != nil {
121+
t.Fatalf("Error executing statement: %s", err.Error())
122+
}
123+
err = conn.Close()
124+
if err != nil {
125+
t.Fatalf("Error closing connection: %s", err.Error())
126+
}
127+
defer func() {
128+
if err := recover(); err != nil {
129+
t.Errorf("Panic after reusing a closed connection: %v", err)
130+
}
131+
}()
132+
_, err = stmt.Exec(nil)
133+
if err != nil && err != errInvalidConn {
134+
t.Errorf("Unexpected error '%s', expected '%s'",
135+
err.Error(), errInvalidConn.Error())
136+
}
137+
}
138+
104139
func TestCharset(t *testing.T) {
105140
mustSetCharset := func(charsetParam, expected string) {
106141
db, err := sql.Open("mysql", strings.Replace(dsn, charset, charsetParam, 1))

errors.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
)
1818

1919
var (
20+
errInvalidConn = errors.New("Invalid Connection")
2021
errMalformPkt = errors.New("Malformed Packet")
2122
errPktSync = errors.New("Commands out of sync. You can't run this command now")
2223
errPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?")

rows.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ package mysql
1111

1212
import (
1313
"database/sql/driver"
14-
"errors"
1514
"io"
1615
)
1716

@@ -44,7 +43,7 @@ func (rows *mysqlRows) Close() (err error) {
4443
// Remove unread packets from stream
4544
if !rows.eof {
4645
if rows.mc == nil || rows.mc.netConn == nil {
47-
return errors.New("Invalid Connection")
46+
return errInvalidConn
4847
}
4948

5049
err = rows.mc.readUntilEOF()
@@ -63,7 +62,7 @@ func (rows *mysqlRows) Next(dest []driver.Value) error {
6362
}
6463

6564
if rows.mc == nil || rows.mc.netConn == nil {
66-
return errors.New("Invalid Connection")
65+
return errInvalidConn
6766
}
6867

6968
// Fetch next row from stream

statement.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ type mysqlStmt struct {
2121
}
2222

2323
func (stmt *mysqlStmt) Close() (err error) {
24+
if stmt.mc == nil || stmt.mc.netConn == nil {
25+
return errInvalidConn
26+
}
27+
2428
err = stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
2529
stmt.mc = nil
2630
return
@@ -31,6 +35,10 @@ func (stmt *mysqlStmt) NumInput() int {
3135
}
3236

3337
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
38+
if stmt.mc.netConn == nil {
39+
return nil, errInvalidConn
40+
}
41+
3442
stmt.mc.affectedRows = 0
3543
stmt.mc.insertId = 0
3644

@@ -66,6 +74,10 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
6674
}
6775

6876
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
77+
if stmt.mc.netConn == nil {
78+
return nil, errInvalidConn
79+
}
80+
6981
// Send command
7082
err := stmt.writeExecutePacket(args)
7183
if err != nil {

transaction.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,20 @@ type mysqlTx struct {
1414
}
1515

1616
func (tx *mysqlTx) Commit() (err error) {
17+
if tx.mc == nil || tx.mc.netConn == nil {
18+
return errInvalidConn
19+
}
20+
1721
err = tx.mc.exec("COMMIT")
1822
tx.mc = nil
1923
return
2024
}
2125

2226
func (tx *mysqlTx) Rollback() (err error) {
27+
if tx.mc == nil || tx.mc.netConn == nil {
28+
return errInvalidConn
29+
}
30+
2331
err = tx.mc.exec("ROLLBACK")
2432
tx.mc = nil
2533
return

0 commit comments

Comments
 (0)