Skip to content

Commit 7205943

Browse files
committed
register a io.Reader handle func instead
1 parent a2cbf81 commit 7205943

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

driver_test.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package mysql
33
import (
44
"database/sql"
55
"fmt"
6+
"io"
67
"io/ioutil"
78
"net"
89
"os"
@@ -657,11 +658,13 @@ func TestLoadData(t *testing.T) {
657658
mustExec(t, db, "TRUNCATE TABLE test")
658659

659660
// Reader
660-
file, err = os.Open(file.Name())
661-
if err != nil {
662-
t.Fatal(err)
663-
}
664-
RegisterReader("test", file)
661+
RegisterReaderHandler("test", func() io.Reader {
662+
file, err = os.Open(file.Name())
663+
if err != nil {
664+
t.Fatal(err)
665+
}
666+
return file
667+
})
665668
mustExec(t, db, "LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test")
666669
verifyLoadDataResult(t, db)
667670
// negative test
@@ -671,7 +674,6 @@ func TestLoadData(t *testing.T) {
671674
} else if err.Error() != "Reader 'doesnotexist' is not registered" {
672675
t.Fatal(err.Error())
673676
}
674-
file.Close()
675677

676678
mustExec(t, db, "DROP TABLE IF EXISTS test")
677679
}

infile.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ import (
1919

2020
var (
2121
fileRegister map[string]bool
22-
readerRegister map[string]io.Reader
22+
readerRegister map[string]func() io.Reader
2323
)
2424

2525
func init() {
2626
fileRegister = make(map[string]bool)
27-
readerRegister = make(map[string]io.Reader)
27+
readerRegister = make(map[string]func() io.Reader)
2828
}
2929

3030
// RegisterLocalFile adds the given file to the file whitelist,
@@ -38,8 +38,8 @@ func RegisterLocalFile(filepath string) {
3838
// RegisterReader registers a io.Reader so that it can be used by
3939
// "LOAD DATA LOCAL INFILE Reader::<name>".
4040
// The use of io.Reader in this context is NOT safe for concurrency!
41-
func RegisterReader(name string, rdr io.Reader) {
42-
readerRegister[name] = rdr
41+
func RegisterReaderHandler(name string, cb func() io.Reader) {
42+
readerRegister[name] = cb
4343
}
4444

4545
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
@@ -48,8 +48,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
4848

4949
if strings.HasPrefix(name, "Reader::") { // io.Reader
5050
name = name[8:]
51-
var inMap bool
52-
rdr, inMap = readerRegister[name]
51+
cb, inMap := readerRegister[name]
52+
if cb != nil {
53+
rdr = cb()
54+
}
5355
if rdr == nil {
5456
if !inMap {
5557
err = fmt.Errorf("Reader '%s' is not registered", name)

0 commit comments

Comments
 (0)