Skip to content

Commit 5d25a76

Browse files
committed
Merge pull request #174 from go-sql-driver/infile
infile: refactoring
2 parents c418c1b + 34f105c commit 5d25a76

File tree

1 file changed

+54
-40
lines changed

1 file changed

+54
-40
lines changed

infile.go

Lines changed: 54 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
package mysql
1010

1111
import (
12-
"database/sql/driver"
1312
"fmt"
1413
"io"
1514
"os"
@@ -21,11 +20,6 @@ var (
2120
readerRegister map[string]func() io.Reader
2221
)
2322

24-
func init() {
25-
fileRegister = make(map[string]bool)
26-
readerRegister = make(map[string]func() io.Reader)
27-
}
28-
2923
// RegisterLocalFile adds the given file to the file whitelist,
3024
// so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
3125
// Alternatively you can allow the use of all local files with
@@ -38,6 +32,11 @@ func init() {
3832
// ...
3933
//
4034
func RegisterLocalFile(filePath string) {
35+
// lazy map init
36+
if fileRegister == nil {
37+
fileRegister = make(map[string]bool)
38+
}
39+
4140
fileRegister[strings.Trim(filePath, `"`)] = true
4241
}
4342

@@ -62,6 +61,11 @@ func DeregisterLocalFile(filePath string) {
6261
// ...
6362
//
6463
func RegisterReaderHandler(name string, handler func() io.Reader) {
64+
// lazy map init
65+
if readerRegister == nil {
66+
readerRegister = make(map[string]func() io.Reader)
67+
}
68+
6569
readerRegister[name] = handler
6670
}
6771

@@ -71,71 +75,81 @@ func DeregisterReaderHandler(name string) {
7175
delete(readerRegister, name)
7276
}
7377

78+
func deferredClose(err *error, closer io.Closer) {
79+
closeErr := closer.Close()
80+
if *err == nil {
81+
*err = closeErr
82+
}
83+
}
84+
7485
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
7586
var rdr io.Reader
76-
data := make([]byte, 4+mc.maxWriteSize)
87+
var data []byte
7788

7889
if strings.HasPrefix(name, "Reader::") { // io.Reader
7990
name = name[8:]
80-
handler, inMap := readerRegister[name]
81-
if handler != nil {
91+
if handler, inMap := readerRegister[name]; inMap {
8292
rdr = handler()
83-
}
84-
if rdr == nil {
85-
if !inMap {
86-
err = fmt.Errorf("Reader '%s' is not registered", name)
93+
if rdr != nil {
94+
data = make([]byte, 4+mc.maxWriteSize)
95+
96+
if cl, ok := rdr.(io.Closer); ok {
97+
defer deferredClose(&err, cl)
98+
}
8799
} else {
88100
err = fmt.Errorf("Reader '%s' is <nil>", name)
89101
}
102+
} else {
103+
err = fmt.Errorf("Reader '%s' is not registered", name)
90104
}
91105
} else { // File
92106
name = strings.Trim(name, `"`)
93107
if mc.cfg.allowAllFiles || fileRegister[name] {
94-
rdr, err = os.Open(name)
108+
var file *os.File
109+
var fi os.FileInfo
110+
111+
if file, err = os.Open(name); err == nil {
112+
defer deferredClose(&err, file)
113+
114+
// get file size
115+
if fi, err = file.Stat(); err == nil {
116+
rdr = file
117+
if fileSize := int(fi.Size()); fileSize <= mc.maxWriteSize {
118+
data = make([]byte, 4+fileSize)
119+
} else if fileSize <= mc.maxPacketAllowed {
120+
data = make([]byte, 4+mc.maxWriteSize)
121+
} else {
122+
err = fmt.Errorf("Local File '%s' too large: Size: %d, Max: %d", name, fileSize, mc.maxPacketAllowed)
123+
}
124+
}
125+
}
95126
} else {
96127
err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)
97128
}
98129
}
99130

100-
if rdc, ok := rdr.(io.ReadCloser); ok {
101-
defer func() {
102-
if err == nil {
103-
err = rdc.Close()
104-
} else {
105-
rdc.Close()
106-
}
107-
}()
108-
}
109-
110131
// send content packets
111-
var ioErr error
112132
if err == nil {
113133
var n int
114-
for err == nil && ioErr == nil {
134+
for err == nil {
115135
n, err = rdr.Read(data[4:])
116136
if n > 0 {
117-
ioErr = mc.writePacket(data[:4+n])
137+
if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
138+
return ioErr
139+
}
118140
}
119141
}
120142
if err == io.EOF {
121143
err = nil
122144
}
123-
if ioErr != nil {
124-
errLog.Print(ioErr.Error())
125-
return driver.ErrBadConn
126-
}
127145
}
128146

129147
// send empty packet (termination)
130-
ioErr = mc.writePacket([]byte{
131-
0x00,
132-
0x00,
133-
0x00,
134-
mc.sequence,
135-
})
136-
if ioErr != nil {
137-
errLog.Print(ioErr.Error())
138-
return driver.ErrBadConn
148+
if data == nil {
149+
data = make([]byte, 4)
150+
}
151+
if ioErr := mc.writePacket(data[:4]); ioErr != nil {
152+
return ioErr
139153
}
140154

141155
// read OK packet

0 commit comments

Comments
 (0)