9
9
package mysql
10
10
11
11
import (
12
- "database/sql/driver"
13
12
"fmt"
14
13
"io"
15
14
"os"
21
20
readerRegister map [string ]func () io.Reader
22
21
)
23
22
24
- func init () {
25
- fileRegister = make (map [string ]bool )
26
- readerRegister = make (map [string ]func () io.Reader )
27
- }
28
-
29
23
// RegisterLocalFile adds the given file to the file whitelist,
30
24
// so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
31
25
// Alternatively you can allow the use of all local files with
@@ -38,6 +32,11 @@ func init() {
38
32
// ...
39
33
//
40
34
func RegisterLocalFile (filePath string ) {
35
+ // lazy map init
36
+ if fileRegister == nil {
37
+ fileRegister = make (map [string ]bool )
38
+ }
39
+
41
40
fileRegister [strings .Trim (filePath , `"` )] = true
42
41
}
43
42
@@ -62,6 +61,11 @@ func DeregisterLocalFile(filePath string) {
62
61
// ...
63
62
//
64
63
func RegisterReaderHandler (name string , handler func () io.Reader ) {
64
+ // lazy map init
65
+ if readerRegister == nil {
66
+ readerRegister = make (map [string ]func () io.Reader )
67
+ }
68
+
65
69
readerRegister [name ] = handler
66
70
}
67
71
@@ -71,71 +75,81 @@ func DeregisterReaderHandler(name string) {
71
75
delete (readerRegister , name )
72
76
}
73
77
78
+ func deferredClose (err * error , closer io.Closer ) {
79
+ closeErr := closer .Close ()
80
+ if * err == nil {
81
+ * err = closeErr
82
+ }
83
+ }
84
+
74
85
func (mc * mysqlConn ) handleInFileRequest (name string ) (err error ) {
75
86
var rdr io.Reader
76
- data := make ( []byte , 4 + mc . maxWriteSize )
87
+ var data []byte
77
88
78
89
if strings .HasPrefix (name , "Reader::" ) { // io.Reader
79
90
name = name [8 :]
80
- handler , inMap := readerRegister [name ]
81
- if handler != nil {
91
+ if handler , inMap := readerRegister [name ]; inMap {
82
92
rdr = handler ()
83
- }
84
- if rdr == nil {
85
- if ! inMap {
86
- err = fmt .Errorf ("Reader '%s' is not registered" , name )
93
+ if rdr != nil {
94
+ data = make ([]byte , 4 + mc .maxWriteSize )
95
+
96
+ if cl , ok := rdr .(io.Closer ); ok {
97
+ defer deferredClose (& err , cl )
98
+ }
87
99
} else {
88
100
err = fmt .Errorf ("Reader '%s' is <nil>" , name )
89
101
}
102
+ } else {
103
+ err = fmt .Errorf ("Reader '%s' is not registered" , name )
90
104
}
91
105
} else { // File
92
106
name = strings .Trim (name , `"` )
93
107
if mc .cfg .allowAllFiles || fileRegister [name ] {
94
- rdr , err = os .Open (name )
108
+ var file * os.File
109
+ var fi os.FileInfo
110
+
111
+ if file , err = os .Open (name ); err == nil {
112
+ defer deferredClose (& err , file )
113
+
114
+ // get file size
115
+ if fi , err = file .Stat (); err == nil {
116
+ rdr = file
117
+ if fileSize := int (fi .Size ()); fileSize <= mc .maxWriteSize {
118
+ data = make ([]byte , 4 + fileSize )
119
+ } else if fileSize <= mc .maxPacketAllowed {
120
+ data = make ([]byte , 4 + mc .maxWriteSize )
121
+ } else {
122
+ err = fmt .Errorf ("Local File '%s' too large: Size: %d, Max: %d" , name , fileSize , mc .maxPacketAllowed )
123
+ }
124
+ }
125
+ }
95
126
} else {
96
127
err = fmt .Errorf ("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files" , name )
97
128
}
98
129
}
99
130
100
- if rdc , ok := rdr .(io.ReadCloser ); ok {
101
- defer func () {
102
- if err == nil {
103
- err = rdc .Close ()
104
- } else {
105
- rdc .Close ()
106
- }
107
- }()
108
- }
109
-
110
131
// send content packets
111
- var ioErr error
112
132
if err == nil {
113
133
var n int
114
- for err == nil && ioErr == nil {
134
+ for err == nil {
115
135
n , err = rdr .Read (data [4 :])
116
136
if n > 0 {
117
- ioErr = mc .writePacket (data [:4 + n ])
137
+ if ioErr := mc .writePacket (data [:4 + n ]); ioErr != nil {
138
+ return ioErr
139
+ }
118
140
}
119
141
}
120
142
if err == io .EOF {
121
143
err = nil
122
144
}
123
- if ioErr != nil {
124
- errLog .Print (ioErr .Error ())
125
- return driver .ErrBadConn
126
- }
127
145
}
128
146
129
147
// send empty packet (termination)
130
- ioErr = mc .writePacket ([]byte {
131
- 0x00 ,
132
- 0x00 ,
133
- 0x00 ,
134
- mc .sequence ,
135
- })
136
- if ioErr != nil {
137
- errLog .Print (ioErr .Error ())
138
- return driver .ErrBadConn
148
+ if data == nil {
149
+ data = make ([]byte , 4 )
150
+ }
151
+ if ioErr := mc .writePacket (data [:4 ]); ioErr != nil {
152
+ return ioErr
139
153
}
140
154
141
155
// read OK packet
0 commit comments