@@ -13,30 +13,24 @@ import (
13
13
"crypto/tls"
14
14
"database/sql/driver"
15
15
"encoding/binary"
16
+ "errors"
16
17
"fmt"
17
18
"io"
18
19
"log"
19
20
"os"
20
- "regexp"
21
21
"strings"
22
22
"time"
23
23
)
24
24
25
25
var (
26
26
errLog * log.Logger // Error Logger
27
- dsnPattern * regexp.Regexp // Data Source Name Parser
28
27
tlsConfigRegister map [string ]* tls.Config // Register for custom tls.Configs
28
+
29
+ errInvalidDSN = errors .New ("Invalid DSN" )
29
30
)
30
31
31
32
func init () {
32
33
errLog = log .New (os .Stderr , "[MySQL] " , log .Ldate | log .Ltime | log .Lshortfile )
33
-
34
- dsnPattern = regexp .MustCompile (
35
- `^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
36
- `(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
37
- `\/(?P<dbname>.*?)` + // /dbname
38
- `(?:\?(?P<params>[^\?]*))?$` ) // [?param1=value1¶mN=valueN]
39
-
40
34
tlsConfigRegister = make (map [string ]* tls.Config )
41
35
}
42
36
@@ -79,96 +73,69 @@ func DeregisterTLSConfig(key string) {
79
73
80
74
func parseDSN (dsn string ) (cfg * config , err error ) {
81
75
cfg = new (config )
82
- cfg .params = make (map [string ]string )
83
-
84
- matches := dsnPattern .FindStringSubmatch (dsn )
85
- names := dsnPattern .SubexpNames ()
86
-
87
- for i , match := range matches {
88
- switch names [i ] {
89
- case "user" :
90
- cfg .user = match
91
- case "passwd" :
92
- cfg .passwd = match
93
- case "net" :
94
- cfg .net = match
95
- case "addr" :
96
- cfg .addr = match
97
- case "dbname" :
98
- cfg .dbname = match
99
- case "params" :
100
- for _ , v := range strings .Split (match , "&" ) {
101
- param := strings .SplitN (v , "=" , 2 )
102
- if len (param ) != 2 {
103
- continue
104
- }
105
-
106
- // cfg params
107
- switch value := param [1 ]; param [0 ] {
108
76
109
- // Disable INFILE whitelist / enable all files
110
- case "allowAllFiles" :
111
- var isBool bool
112
- cfg .allowAllFiles , isBool = readBool (value )
113
- if ! isBool {
114
- err = fmt .Errorf ("Invalid Bool value: %s" , value )
115
- return
116
- }
77
+ // TODO: use strings.IndexByte when we can depend on Go 1.2
78
+
79
+ // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN]
80
+ // Find the last '/'
81
+ for i := len (dsn ) - 1 ; i >= 0 ; i -- {
82
+ if dsn [i ] == '/' {
83
+ var j int
84
+
85
+ // left part is empty if i <= 0
86
+ if i > 0 {
87
+ // [username[:password]@][protocol[(address)]]
88
+ // Find the last '@' in dsn[:i]
89
+ for j = i ; j >= 0 ; j -- {
90
+ if dsn [j ] == '@' {
91
+ // username[:password]
92
+ // Find the first ':' in dsn[:j]
93
+ var k int
94
+ for k = 0 ; k < j ; k ++ {
95
+ if dsn [k ] == ':' {
96
+ cfg .passwd = dsn [k + 1 : j ]
97
+ break
98
+ }
99
+ }
100
+ cfg .user = dsn [:k ]
101
+
102
+ // [protocol[(address)]]
103
+ // Find the first '(' in dsn[j+1:i]
104
+ for k = j + 1 ; k < i ; k ++ {
105
+ if dsn [k ] == '(' {
106
+ // dsn[i-1] must be == ')' if an adress is specified
107
+ if dsn [i - 1 ] != ')' {
108
+ return nil , errInvalidDSN
109
+ }
110
+ cfg .addr = dsn [k + 1 : i - 1 ]
111
+ break
112
+ }
113
+ }
114
+ cfg .net = dsn [j + 1 : k ]
117
115
118
- // Switch "rowsAffected" mode
119
- case "clientFoundRows" :
120
- var isBool bool
121
- cfg .clientFoundRows , isBool = readBool (value )
122
- if ! isBool {
123
- err = fmt .Errorf ("Invalid Bool value: %s" , value )
124
- return
125
- }
126
-
127
- // Use old authentication mode (pre MySQL 4.1)
128
- case "allowOldPasswords" :
129
- var isBool bool
130
- cfg .allowOldPasswords , isBool = readBool (value )
131
- if ! isBool {
132
- err = fmt .Errorf ("Invalid Bool value: %s" , value )
133
- return
116
+ break
134
117
}
118
+ }
135
119
136
- // Time Location
137
- case "loc" :
138
- cfg .loc , err = time .LoadLocation (value )
139
- if err != nil {
140
- return
141
- }
120
+ // non-empty left part must contain an '@'
121
+ if j < 0 {
122
+ return nil , errInvalidDSN
123
+ }
124
+ }
142
125
143
- // Dial Timeout
144
- case "timeout" :
145
- cfg .timeout , err = time .ParseDuration (value )
146
- if err != nil {
126
+ // dbname[?param1=value1&...¶mN=valueN]
127
+ // Find the first '?' in dsn[i+1:]
128
+ for j = i + 1 ; j < len (dsn ); j ++ {
129
+ if dsn [j ] == '?' {
130
+ if err = parseDSNParams (cfg , dsn [j + 1 :]); err != nil {
147
131
return
148
132
}
149
-
150
- // TLS-Encryption
151
- case "tls" :
152
- boolValue , isBool := readBool (value )
153
- if isBool {
154
- if boolValue {
155
- cfg .tls = & tls.Config {}
156
- }
157
- } else {
158
- if strings .ToLower (value ) == "skip-verify" {
159
- cfg .tls = & tls.Config {InsecureSkipVerify : true }
160
- } else if tlsConfig , ok := tlsConfigRegister [value ]; ok {
161
- cfg .tls = tlsConfig
162
- } else {
163
- err = fmt .Errorf ("Invalid value / unknown config name: %s" , value )
164
- return
165
- }
166
- }
167
-
168
- default :
169
- cfg .params [param [0 ]] = value
133
+ break
170
134
}
171
135
}
136
+ cfg .dbname = dsn [i + 1 : j ]
137
+
138
+ break
172
139
}
173
140
}
174
141
@@ -179,7 +146,15 @@ func parseDSN(dsn string) (cfg *config, err error) {
179
146
180
147
// Set default adress if empty
181
148
if cfg .addr == "" {
182
- cfg .addr = "127.0.0.1:3306"
149
+ switch cfg .net {
150
+ case "tcp" :
151
+ cfg .addr = "127.0.0.1:3306"
152
+ case "unix" :
153
+ cfg .addr = "/tmp/mysql.sock"
154
+ default :
155
+ return nil , errors .New ("Default addr for network '" + cfg .net + "' unknown" )
156
+ }
157
+
183
158
}
184
159
185
160
// Set default location if not set
@@ -190,6 +165,81 @@ func parseDSN(dsn string) (cfg *config, err error) {
190
165
return
191
166
}
192
167
168
+ func parseDSNParams (cfg * config , params string ) (err error ) {
169
+ cfg .params = make (map [string ]string )
170
+
171
+ for _ , v := range strings .Split (params , "&" ) {
172
+ param := strings .SplitN (v , "=" , 2 )
173
+ if len (param ) != 2 {
174
+ continue
175
+ }
176
+
177
+ // cfg params
178
+ switch value := param [1 ]; param [0 ] {
179
+
180
+ // Disable INFILE whitelist / enable all files
181
+ case "allowAllFiles" :
182
+ var isBool bool
183
+ cfg .allowAllFiles , isBool = readBool (value )
184
+ if ! isBool {
185
+ return fmt .Errorf ("Invalid Bool value: %s" , value )
186
+ }
187
+
188
+ // Switch "rowsAffected" mode
189
+ case "clientFoundRows" :
190
+ var isBool bool
191
+ cfg .clientFoundRows , isBool = readBool (value )
192
+ if ! isBool {
193
+ return fmt .Errorf ("Invalid Bool value: %s" , value )
194
+ }
195
+
196
+ // Use old authentication mode (pre MySQL 4.1)
197
+ case "allowOldPasswords" :
198
+ var isBool bool
199
+ cfg .allowOldPasswords , isBool = readBool (value )
200
+ if ! isBool {
201
+ return fmt .Errorf ("Invalid Bool value: %s" , value )
202
+ }
203
+
204
+ // Time Location
205
+ case "loc" :
206
+ cfg .loc , err = time .LoadLocation (value )
207
+ if err != nil {
208
+ return
209
+ }
210
+
211
+ // Dial Timeout
212
+ case "timeout" :
213
+ cfg .timeout , err = time .ParseDuration (value )
214
+ if err != nil {
215
+ return
216
+ }
217
+
218
+ // TLS-Encryption
219
+ case "tls" :
220
+ boolValue , isBool := readBool (value )
221
+ if isBool {
222
+ if boolValue {
223
+ cfg .tls = & tls.Config {}
224
+ }
225
+ } else {
226
+ if strings .ToLower (value ) == "skip-verify" {
227
+ cfg .tls = & tls.Config {InsecureSkipVerify : true }
228
+ } else if tlsConfig , ok := tlsConfigRegister [value ]; ok {
229
+ cfg .tls = tlsConfig
230
+ } else {
231
+ return fmt .Errorf ("Invalid value / unknown config name: %s" , value )
232
+ }
233
+ }
234
+
235
+ default :
236
+ cfg .params [param [0 ]] = value
237
+ }
238
+ }
239
+
240
+ return
241
+ }
242
+
193
243
// Returns the bool value of the input.
194
244
// The 2nd return value indicates if the input was a valid bool value
195
245
func readBool (input string ) (value bool , valid bool ) {
0 commit comments