Skip to content

Commit a3895fe

Browse files
neildgopherbot
authored andcommitted
database/sql: avoid closing Rows while scan is in progress
A database/sql/driver.Rows can return database-owned data from Rows.Next. The driver.Rows documentation doesn't explicitly document the lifetime guarantees for this data, but a reasonable expectation is that the caller of Next should only access it until the next call to Rows.Close or Rows.Next. Avoid violating that constraint when a query is cancelled while a call to database/sql.Rows.Scan (note the difference between the two different Rows types!) is in progress. We previously took care to avoid closing a driver.Rows while the user has access to driver-owned memory via a RawData, but we could still close a driver.Rows while a Scan call was in the process of reading previously-returned driver-owned data. Update the fake DB used in database/sql tests to invalidate returned data to help catch other places we might be incorrectly retaining it. Fixes #74831. Change-Id: Ice45b5fad51b679c38e3e1d21ef39156b56d6037 Reviewed-on: https://go-internal-review.googlesource.com/c/go/+/2540 Reviewed-by: Roland Shoemaker <[email protected]> Reviewed-by: Neal Patel <[email protected]> Reviewed-on: https://go-review.googlesource.com/c/go/+/693735 Auto-Submit: Dmitri Shuralyov <[email protected]> Reviewed-by: Roland Shoemaker <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Dmitri Shuralyov <[email protected]>
1 parent 608e9fa commit a3895fe

File tree

4 files changed

+90
-49
lines changed

4 files changed

+90
-49
lines changed

src/database/sql/convert.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,6 @@ func convertAssignRows(dest, src any, rows *Rows) error {
335335
if rows == nil {
336336
return errors.New("invalid context to convert cursor rows, missing parent *Rows")
337337
}
338-
rows.closemu.Lock()
339338
*d = Rows{
340339
dc: rows.dc,
341340
releaseConn: func(error) {},
@@ -351,7 +350,6 @@ func convertAssignRows(dest, src any, rows *Rows) error {
351350
parentCancel()
352351
}
353352
}
354-
rows.closemu.Unlock()
355353
return nil
356354
}
357355
}

src/database/sql/fakedb_test.go

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package sql
66

77
import (
8+
"bytes"
89
"context"
910
"database/sql/driver"
1011
"errors"
@@ -15,7 +16,6 @@ import (
1516
"strconv"
1617
"strings"
1718
"sync"
18-
"sync/atomic"
1919
"testing"
2020
"time"
2121
)
@@ -91,8 +91,6 @@ func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
9191
type fakeDB struct {
9292
name string
9393

94-
useRawBytes atomic.Bool
95-
9694
mu sync.Mutex
9795
tables map[string]*table
9896
badConn bool
@@ -684,8 +682,6 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
684682
switch cmd {
685683
case "WIPE":
686684
// Nothing
687-
case "USE_RAWBYTES":
688-
c.db.useRawBytes.Store(true)
689685
case "SELECT":
690686
stmt, err = c.prepareSelect(stmt, parts)
691687
case "CREATE":
@@ -789,9 +785,6 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
789785
case "WIPE":
790786
db.wipe()
791787
return driver.ResultNoRows, nil
792-
case "USE_RAWBYTES":
793-
s.c.db.useRawBytes.Store(true)
794-
return driver.ResultNoRows, nil
795788
case "CREATE":
796789
if err := db.createTable(s.table, s.colName, s.colType); err != nil {
797790
return nil, err
@@ -1076,10 +1069,9 @@ type rowsCursor struct {
10761069
errPos int
10771070
err error
10781071

1079-
// a clone of slices to give out to clients, indexed by the
1080-
// original slice's first byte address. we clone them
1081-
// just so we're able to corrupt them on close.
1082-
bytesClone map[*byte][]byte
1072+
// Data returned to clients.
1073+
// We clone and stash it here so it can be invalidated by Close and Next.
1074+
driverOwnedMemory [][]byte
10831075

10841076
// Every operation writes to line to enable the race detector
10851077
// check for data races.
@@ -1096,9 +1088,19 @@ func (rc *rowsCursor) touchMem() {
10961088
rc.line++
10971089
}
10981090

1091+
func (rc *rowsCursor) invalidateDriverOwnedMemory() {
1092+
for _, buf := range rc.driverOwnedMemory {
1093+
for i := range buf {
1094+
buf[i] = 'x'
1095+
}
1096+
}
1097+
rc.driverOwnedMemory = nil
1098+
}
1099+
10991100
func (rc *rowsCursor) Close() error {
11001101
rc.touchMem()
11011102
rc.parentMem.touchMem()
1103+
rc.invalidateDriverOwnedMemory()
11021104
rc.closed = true
11031105
return rc.closeErr
11041106
}
@@ -1129,27 +1131,22 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
11291131
if rc.posRow >= len(rc.rows[rc.posSet]) {
11301132
return io.EOF // per interface spec
11311133
}
1134+
// Corrupt any previously returned bytes.
1135+
rc.invalidateDriverOwnedMemory()
11321136
for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
11331137
// TODO(bradfitz): convert to subset types? naah, I
11341138
// think the subset types should only be input to
11351139
// driver, but the sql package should be able to handle
11361140
// a wider range of types coming out of drivers. all
11371141
// for ease of drivers, and to prevent drivers from
11381142
// messing up conversions or doing them differently.
1139-
dest[i] = v
1140-
1141-
if bs, ok := v.([]byte); ok && !rc.db.useRawBytes.Load() {
1142-
if rc.bytesClone == nil {
1143-
rc.bytesClone = make(map[*byte][]byte)
1144-
}
1145-
clone, ok := rc.bytesClone[&bs[0]]
1146-
if !ok {
1147-
clone = make([]byte, len(bs))
1148-
copy(clone, bs)
1149-
rc.bytesClone[&bs[0]] = clone
1150-
}
1151-
dest[i] = clone
1143+
if bs, ok := v.([]byte); ok {
1144+
// Clone []bytes and stash for later invalidation.
1145+
bs = bytes.Clone(bs)
1146+
rc.driverOwnedMemory = append(rc.driverOwnedMemory, bs)
1147+
v = bs
11521148
}
1149+
dest[i] = v
11531150
}
11541151
return nil
11551152
}

src/database/sql/sql.go

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3368,38 +3368,36 @@ func (rs *Rows) Scan(dest ...any) error {
33683368
// without calling Next.
33693369
return fmt.Errorf("sql: Scan called without calling Next (closemuScanHold)")
33703370
}
3371+
33713372
rs.closemu.RLock()
3373+
rs.raw = rs.raw[:0]
3374+
err := rs.scanLocked(dest...)
3375+
if err == nil && scanArgsContainRawBytes(dest) {
3376+
rs.closemuScanHold = true
3377+
} else {
3378+
rs.closemu.RUnlock()
3379+
}
3380+
return err
3381+
}
33723382

3383+
func (rs *Rows) scanLocked(dest ...any) error {
33733384
if rs.lasterr != nil && rs.lasterr != io.EOF {
3374-
rs.closemu.RUnlock()
33753385
return rs.lasterr
33763386
}
33773387
if rs.closed {
3378-
err := rs.lasterrOrErrLocked(errRowsClosed)
3379-
rs.closemu.RUnlock()
3380-
return err
3381-
}
3382-
3383-
if scanArgsContainRawBytes(dest) {
3384-
rs.closemuScanHold = true
3385-
rs.raw = rs.raw[:0]
3386-
} else {
3387-
rs.closemu.RUnlock()
3388+
return rs.lasterrOrErrLocked(errRowsClosed)
33883389
}
33893390

33903391
if rs.lastcols == nil {
3391-
rs.closemuRUnlockIfHeldByScan()
33923392
return errors.New("sql: Scan called without calling Next")
33933393
}
33943394
if len(dest) != len(rs.lastcols) {
3395-
rs.closemuRUnlockIfHeldByScan()
33963395
return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
33973396
}
33983397

33993398
for i, sv := range rs.lastcols {
34003399
err := convertAssignRows(dest[i], sv, rs)
34013400
if err != nil {
3402-
rs.closemuRUnlockIfHeldByScan()
34033401
return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
34043402
}
34053403
}

src/database/sql/sql_test.go

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package sql
66

77
import (
8+
"bytes"
89
"context"
910
"database/sql/driver"
1011
"errors"
@@ -4434,10 +4435,6 @@ func testContextCancelDuringRawBytesScan(t *testing.T, mode string) {
44344435
db := newTestDB(t, "people")
44354436
defer closeDB(t, db)
44364437

4437-
if _, err := db.Exec("USE_RAWBYTES"); err != nil {
4438-
t.Fatal(err)
4439-
}
4440-
44414438
// cancel used to call close asynchronously.
44424439
// This test checks that it waits so as not to interfere with RawBytes.
44434440
ctx, cancel := context.WithCancel(context.Background())
@@ -4529,6 +4526,61 @@ func TestContextCancelBetweenNextAndErr(t *testing.T) {
45294526
}
45304527
}
45314528

4529+
type testScanner struct {
4530+
scanf func(src any) error
4531+
}
4532+
4533+
func (ts testScanner) Scan(src any) error { return ts.scanf(src) }
4534+
4535+
func TestContextCancelDuringScan(t *testing.T) {
4536+
db := newTestDB(t, "people")
4537+
defer closeDB(t, db)
4538+
4539+
ctx, cancel := context.WithCancel(context.Background())
4540+
defer cancel()
4541+
4542+
scanStart := make(chan any)
4543+
scanEnd := make(chan error)
4544+
scanner := &testScanner{
4545+
scanf: func(src any) error {
4546+
scanStart <- src
4547+
return <-scanEnd
4548+
},
4549+
}
4550+
4551+
// Start a query, and pause it mid-scan.
4552+
want := []byte("Alice")
4553+
r, err := db.QueryContext(ctx, "SELECT|people|name|name=?", string(want))
4554+
if err != nil {
4555+
t.Fatal(err)
4556+
}
4557+
if !r.Next() {
4558+
t.Fatalf("r.Next() = false, want true")
4559+
}
4560+
go func() {
4561+
r.Scan(scanner)
4562+
}()
4563+
got := <-scanStart
4564+
defer close(scanEnd)
4565+
gotBytes, ok := got.([]byte)
4566+
if !ok {
4567+
t.Fatalf("r.Scan returned %T, want []byte", got)
4568+
}
4569+
if !bytes.Equal(gotBytes, want) {
4570+
t.Fatalf("before cancel: r.Scan returned %q, want %q", gotBytes, want)
4571+
}
4572+
4573+
// Cancel the query.
4574+
// Sleep to give it a chance to finish canceling.
4575+
cancel()
4576+
time.Sleep(10 * time.Millisecond)
4577+
4578+
// Cancelling the query should not have changed the result.
4579+
if !bytes.Equal(gotBytes, want) {
4580+
t.Fatalf("after cancel: r.Scan result is now %q, want %q", gotBytes, want)
4581+
}
4582+
}
4583+
45324584
func TestNilErrorAfterClose(t *testing.T) {
45334585
db := newTestDB(t, "people")
45344586
defer closeDB(t, db)
@@ -4562,10 +4614,6 @@ func TestRawBytesReuse(t *testing.T) {
45624614
db := newTestDB(t, "people")
45634615
defer closeDB(t, db)
45644616

4565-
if _, err := db.Exec("USE_RAWBYTES"); err != nil {
4566-
t.Fatal(err)
4567-
}
4568-
45694617
var raw RawBytes
45704618

45714619
// The RawBytes in this query aliases driver-owned memory.

0 commit comments

Comments
 (0)