Skip to content

Commit 59ce2f9

Browse files
committed
Merge pull request #42 from go-sql-driver/readerHandler
register a io.Reader handle func instead
2 parents a2cbf81 + 5398634 commit 59ce2f9

File tree

3 files changed

+34
-23
lines changed

3 files changed

+34
-23
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,9 @@ For this feature you need direct access to the package. Therefore you must chang
144144
import "github.com/go-sql-driver/mysql"
145145
```
146146

147-
Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (reccommended) or the whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` (might be insecure).
147+
Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (reccommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` (might be insecure).
148148

149-
`io.Reader`s must be registered with `mysql.RegisterReader(name, reader)`. They are available with the filepath `Reader::<name>` then.
149+
To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::<name>` then.
150150

151151
See also the [godoc of Go-MySQL-Driver](http://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation")
152152

driver_test.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package mysql
33
import (
44
"database/sql"
55
"fmt"
6+
"io"
67
"io/ioutil"
78
"net"
89
"os"
@@ -657,11 +658,13 @@ func TestLoadData(t *testing.T) {
657658
mustExec(t, db, "TRUNCATE TABLE test")
658659

659660
// Reader
660-
file, err = os.Open(file.Name())
661-
if err != nil {
662-
t.Fatal(err)
663-
}
664-
RegisterReader("test", file)
661+
RegisterReaderHandler("test", func() io.Reader {
662+
file, err = os.Open(file.Name())
663+
if err != nil {
664+
t.Fatal(err)
665+
}
666+
return file
667+
})
665668
mustExec(t, db, "LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test")
666669
verifyLoadDataResult(t, db)
667670
// negative test
@@ -671,7 +674,6 @@ func TestLoadData(t *testing.T) {
671674
} else if err.Error() != "Reader 'doesnotexist' is not registered" {
672675
t.Fatal(err.Error())
673676
}
674-
file.Close()
675677

676678
mustExec(t, db, "DROP TABLE IF EXISTS test")
677679
}

infile.go

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ import (
1919

2020
var (
2121
fileRegister map[string]bool
22-
readerRegister map[string]io.Reader
22+
readerRegister map[string]func() io.Reader
2323
)
2424

2525
func init() {
2626
fileRegister = make(map[string]bool)
27-
readerRegister = make(map[string]io.Reader)
27+
readerRegister = make(map[string]func() io.Reader)
2828
}
2929

3030
// RegisterLocalFile adds the given file to the file whitelist,
@@ -35,11 +35,13 @@ func RegisterLocalFile(filepath string) {
3535
fileRegister[filepath] = true
3636
}
3737

38-
// RegisterReader registers a io.Reader so that it can be used by
39-
// "LOAD DATA LOCAL INFILE Reader::<name>".
40-
// The use of io.Reader in this context is NOT safe for concurrency!
41-
func RegisterReader(name string, rdr io.Reader) {
42-
readerRegister[name] = rdr
38+
// RegisterReaderHandler registers a handler function which is used
39+
// to receive a io.Reader.
40+
// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::<name>".
41+
// If the handler returns a io.ReadCloser Close() is called when the
42+
// request is finished.
43+
func RegisterReaderHandler(name string, handler func() io.Reader) {
44+
readerRegister[name] = handler
4345
}
4446

4547
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
@@ -48,28 +50,35 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
4850

4951
if strings.HasPrefix(name, "Reader::") { // io.Reader
5052
name = name[8:]
51-
var inMap bool
52-
rdr, inMap = readerRegister[name]
53+
handler, inMap := readerRegister[name]
54+
if handler != nil {
55+
rdr = handler()
56+
}
5357
if rdr == nil {
5458
if !inMap {
5559
err = fmt.Errorf("Reader '%s' is not registered", name)
5660
} else {
5761
err = fmt.Errorf("Reader '%s' is <nil>", name)
5862
}
5963
}
60-
6164
} else { // File
6265
if fileRegister[name] || mc.cfg.params[`allowAllFiles`] == `true` {
63-
var file *os.File
64-
file, err = os.Open(name)
65-
defer file.Close()
66-
67-
rdr = file
66+
rdr, err = os.Open(name)
6867
} else {
6968
err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)
7069
}
7170
}
7271

72+
if rdc, ok := rdr.(io.ReadCloser); ok {
73+
defer func() {
74+
if err == nil {
75+
err = rdc.Close()
76+
} else {
77+
rdc.Close()
78+
}
79+
}()
80+
}
81+
7382
// send content packets
7483
var ioErr error
7584
if err == nil {

0 commit comments

Comments
 (0)