Skip to content

Commit 58ac805

Browse files
committed
Merge pull request #130 from go-sql-driver/dsn_parser
New DSN parser
2 parents 0792be4 + 6d1a06d commit 58ac805

File tree

5 files changed

+207
-111
lines changed

5 files changed

+207
-111
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Changes:
1313
- Refactored the driver tests
1414
- Added more benchmarks and moved all to a separate file
1515
- Other small refactoring
16+
- DSN parameter values must now be url.QueryEscape'ed. This allows text values to contain special characters, such as '&'.
1617

1718
New Features:
1819

README.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ A DSN in its fullest form:
7878
username:password@protocol(address)/dbname?param=value
7979
```
8080

81-
Except of the databasename, all values are optional. So the minimal DSN is:
81+
Except for the databasename, all values are optional. So the minimal DSN is:
8282
```
8383
/dbname
8484
```
@@ -110,7 +110,7 @@ Possible Parameters are:
110110
* `allowOldPasswords`: `allowOldPasswords=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords).
111111
* `charset`: Sets the charset used for client-server interaction ("SET NAMES `value`"). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset fails. This enables support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
112112
* `clientFoundRows`: `clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed.
113-
* `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
113+
* `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details. Please keep in mind that param values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively, you can manually replace the `/` with `%2F`. For example, `US/Pacific` would be `US%2FPacific`.
114114
* `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
115115
* `strict`: Enable strict mode. MySQL warnings are treated as errors.
116116
* `timeout`: **Driver** side connection timeout. The value must be a string of decimal numbers, each with optional fraction and a unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout).
@@ -122,6 +122,8 @@ All other parameters are interpreted as system variables:
122122
* `tx_isolation`: *"SET [tx_isolation](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation)=`value`"*
123123
* `param`: *"SET `param`=`value`"*
124124

125+
***The values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed!***
126+
125127
#### Examples
126128
```
127129
user@unix(/path/to/socket)/dbname
@@ -132,7 +134,7 @@ user:password@tcp(localhost:5555)/dbname?autocommit=true
132134
```
133135

134136
```
135-
user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?tls=skip-verify&charset=utf8mb4,utf8
137+
user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?tls=skip-verify&charset=utf8mb4,utf8&sys_var=withSlash%2FandAt%40
136138
```
137139

138140
```

driver_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"io"
1616
"io/ioutil"
1717
"net"
18+
"net/url"
1819
"os"
1920
"strings"
2021
"testing"
@@ -206,7 +207,7 @@ func TestTimezoneConversion(t *testing.T) {
206207
}
207208

208209
for _, tz := range zones {
209-
runTests(t, dsn+"&parseTime=true&loc="+tz, tzTest)
210+
runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest)
210211
}
211212
}
212213

utils.go

Lines changed: 151 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,26 @@ import (
1313
"crypto/tls"
1414
"database/sql/driver"
1515
"encoding/binary"
16+
"errors"
1617
"fmt"
1718
"io"
1819
"log"
20+
"net/url"
1921
"os"
20-
"regexp"
2122
"strings"
2223
"time"
2324
)
2425

2526
var (
2627
errLog *log.Logger // Error Logger
27-
dsnPattern *regexp.Regexp // Data Source Name Parser
2828
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
29+
30+
errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
31+
errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
2932
)
3033

3134
func init() {
3235
errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
33-
34-
dsnPattern = regexp.MustCompile(
35-
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
36-
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
37-
`\/(?P<dbname>.*?)` + // /dbname
38-
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
39-
4036
tlsConfigRegister = make(map[string]*tls.Config)
4137
}
4238

@@ -77,98 +73,69 @@ func DeregisterTLSConfig(key string) {
7773
delete(tlsConfigRegister, key)
7874
}
7975

76+
// parseDSN parses the DSN string to a config
8077
func parseDSN(dsn string) (cfg *config, err error) {
8178
cfg = new(config)
82-
cfg.params = make(map[string]string)
83-
84-
matches := dsnPattern.FindStringSubmatch(dsn)
85-
names := dsnPattern.SubexpNames()
86-
87-
for i, match := range matches {
88-
switch names[i] {
89-
case "user":
90-
cfg.user = match
91-
case "passwd":
92-
cfg.passwd = match
93-
case "net":
94-
cfg.net = match
95-
case "addr":
96-
cfg.addr = match
97-
case "dbname":
98-
cfg.dbname = match
99-
case "params":
100-
for _, v := range strings.Split(match, "&") {
101-
param := strings.SplitN(v, "=", 2)
102-
if len(param) != 2 {
103-
continue
104-
}
105-
106-
// cfg params
107-
switch value := param[1]; param[0] {
108-
109-
// Disable INFILE whitelist / enable all files
110-
case "allowAllFiles":
111-
var isBool bool
112-
cfg.allowAllFiles, isBool = readBool(value)
113-
if !isBool {
114-
err = fmt.Errorf("Invalid Bool value: %s", value)
115-
return
116-
}
11779

118-
// Switch "rowsAffected" mode
119-
case "clientFoundRows":
120-
var isBool bool
121-
cfg.clientFoundRows, isBool = readBool(value)
122-
if !isBool {
123-
err = fmt.Errorf("Invalid Bool value: %s", value)
124-
return
125-
}
80+
// TODO: use strings.IndexByte when we can depend on Go 1.2
81+
82+
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
83+
// Find the last '/' (since the password or the net addr might contain a '/')
84+
for i := len(dsn) - 1; i >= 0; i-- {
85+
if dsn[i] == '/' {
86+
var j, k int
87+
88+
// left part is empty if i <= 0
89+
if i > 0 {
90+
// [username[:password]@][protocol[(address)]]
91+
// Find the last '@' in dsn[:i]
92+
for j = i; j >= 0; j-- {
93+
if dsn[j] == '@' {
94+
// username[:password]
95+
// Find the first ':' in dsn[:j]
96+
for k = 0; k < j; k++ {
97+
if dsn[k] == ':' {
98+
cfg.passwd = dsn[k+1 : j]
99+
break
100+
}
101+
}
102+
cfg.user = dsn[:k]
126103

127-
// Use old authentication mode (pre MySQL 4.1)
128-
case "allowOldPasswords":
129-
var isBool bool
130-
cfg.allowOldPasswords, isBool = readBool(value)
131-
if !isBool {
132-
err = fmt.Errorf("Invalid Bool value: %s", value)
133-
return
104+
break
134105
}
106+
}
135107

136-
// Time Location
137-
case "loc":
138-
cfg.loc, err = time.LoadLocation(value)
139-
if err != nil {
140-
return
108+
// [protocol[(address)]]
109+
// Find the first '(' in dsn[j+1:i]
110+
for k = j + 1; k < i; k++ {
111+
if dsn[k] == '(' {
112+
// dsn[i-1] must be == ')' if an address is specified
113+
if dsn[i-1] != ')' {
114+
if strings.ContainsRune(dsn[k+1:i], ')') {
115+
return nil, errInvalidDSNUnescaped
116+
}
117+
return nil, errInvalidDSNAddr
118+
}
119+
cfg.addr = dsn[k+1 : i-1]
120+
break
141121
}
122+
}
123+
cfg.net = dsn[j+1 : k]
124+
}
142125

143-
// Dial Timeout
144-
case "timeout":
145-
cfg.timeout, err = time.ParseDuration(value)
146-
if err != nil {
126+
// dbname[?param1=value1&...&paramN=valueN]
127+
// Find the first '?' in dsn[i+1:]
128+
for j = i + 1; j < len(dsn); j++ {
129+
if dsn[j] == '?' {
130+
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
147131
return
148132
}
149-
150-
// TLS-Encryption
151-
case "tls":
152-
boolValue, isBool := readBool(value)
153-
if isBool {
154-
if boolValue {
155-
cfg.tls = &tls.Config{}
156-
}
157-
} else {
158-
if strings.ToLower(value) == "skip-verify" {
159-
cfg.tls = &tls.Config{InsecureSkipVerify: true}
160-
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
161-
cfg.tls = tlsConfig
162-
} else {
163-
err = fmt.Errorf("Invalid value / unknown config name: %s", value)
164-
return
165-
}
166-
}
167-
168-
default:
169-
cfg.params[param[0]] = value
133+
break
170134
}
171135
}
136+
cfg.dbname = dsn[i+1 : j]
137+
138+
break
172139
}
173140
}
174141

@@ -179,17 +146,110 @@ func parseDSN(dsn string) (cfg *config, err error) {
179146

180147
// Set default address if empty
181148
if cfg.addr == "" {
182-
cfg.addr = "127.0.0.1:3306"
149+
switch cfg.net {
150+
case "tcp":
151+
cfg.addr = "127.0.0.1:3306"
152+
case "unix":
153+
cfg.addr = "/tmp/mysql.sock"
154+
default:
155+
return nil, errors.New("Default addr for network '" + cfg.net + "' unknown")
156+
}
157+
183158
}
184159

185-
// Set default location if not set
160+
// Set default location if empty
186161
if cfg.loc == nil {
187162
cfg.loc = time.UTC
188163
}
189164

190165
return
191166
}
192167

168+
// parseDSNParams parses the DSN "query string"
169+
// Values must be url.QueryEscape'ed
170+
func parseDSNParams(cfg *config, params string) (err error) {
171+
for _, v := range strings.Split(params, "&") {
172+
param := strings.SplitN(v, "=", 2)
173+
if len(param) != 2 {
174+
continue
175+
}
176+
177+
// cfg params
178+
switch value := param[1]; param[0] {
179+
180+
// Disable INFILE whitelist / enable all files
181+
case "allowAllFiles":
182+
var isBool bool
183+
cfg.allowAllFiles, isBool = readBool(value)
184+
if !isBool {
185+
return fmt.Errorf("Invalid Bool value: %s", value)
186+
}
187+
188+
// Switch "rowsAffected" mode
189+
case "clientFoundRows":
190+
var isBool bool
191+
cfg.clientFoundRows, isBool = readBool(value)
192+
if !isBool {
193+
return fmt.Errorf("Invalid Bool value: %s", value)
194+
}
195+
196+
// Use old authentication mode (pre MySQL 4.1)
197+
case "allowOldPasswords":
198+
var isBool bool
199+
cfg.allowOldPasswords, isBool = readBool(value)
200+
if !isBool {
201+
return fmt.Errorf("Invalid Bool value: %s", value)
202+
}
203+
204+
// Time Location
205+
case "loc":
206+
if value, err = url.QueryUnescape(value); err != nil {
207+
return
208+
}
209+
cfg.loc, err = time.LoadLocation(value)
210+
if err != nil {
211+
return
212+
}
213+
214+
// Dial Timeout
215+
case "timeout":
216+
cfg.timeout, err = time.ParseDuration(value)
217+
if err != nil {
218+
return
219+
}
220+
221+
// TLS-Encryption
222+
case "tls":
223+
boolValue, isBool := readBool(value)
224+
if isBool {
225+
if boolValue {
226+
cfg.tls = &tls.Config{}
227+
}
228+
} else {
229+
if strings.ToLower(value) == "skip-verify" {
230+
cfg.tls = &tls.Config{InsecureSkipVerify: true}
231+
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
232+
cfg.tls = tlsConfig
233+
} else {
234+
return fmt.Errorf("Invalid value / unknown config name: %s", value)
235+
}
236+
}
237+
238+
default:
239+
// lazy init
240+
if cfg.params == nil {
241+
cfg.params = make(map[string]string)
242+
}
243+
244+
if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil {
245+
return
246+
}
247+
}
248+
}
249+
250+
return
251+
}
252+
193253
// Returns the bool value of the input.
194254
// The 2nd return value indicates if the input was a valid bool value
195255
func readBool(input string) (value bool, valid bool) {

0 commit comments

Comments
 (0)