@@ -19,12 +19,12 @@ import (
19
19
20
20
var (
21
21
fileRegister map [string ]bool
22
- readerRegister map [string ]io.Reader
22
+ readerRegister map [string ]func () io.Reader
23
23
)
24
24
25
25
func init () {
26
26
fileRegister = make (map [string ]bool )
27
- readerRegister = make (map [string ]io.Reader )
27
+ readerRegister = make (map [string ]func () io.Reader )
28
28
}
29
29
30
30
// RegisterLocalFile adds the given file to the file whitelist,
@@ -38,8 +38,8 @@ func RegisterLocalFile(filepath string) {
38
38
// RegisterReader registers a io.Reader so that it can be used by
39
39
// "LOAD DATA LOCAL INFILE Reader::<name>".
40
40
// The use of io.Reader in this context is NOT safe for concurrency!
41
- func RegisterReader (name string , rdr io.Reader ) {
42
- readerRegister [name ] = rdr
41
+ func RegisterReaderHandler (name string , cb func () io.Reader ) {
42
+ readerRegister [name ] = cb
43
43
}
44
44
45
45
func (mc * mysqlConn ) handleInFileRequest (name string ) (err error ) {
@@ -48,8 +48,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
48
48
49
49
if strings .HasPrefix (name , "Reader::" ) { // io.Reader
50
50
name = name [8 :]
51
- var inMap bool
52
- rdr , inMap = readerRegister [name ]
51
+ cb , inMap := readerRegister [name ]
52
+ if cb != nil {
53
+ rdr = cb ()
54
+ }
53
55
if rdr == nil {
54
56
if ! inMap {
55
57
err = fmt .Errorf ("Reader '%s' is not registered" , name )
0 commit comments