Skip to content

Commit dc02949

Browse files
committed
New DSN parser
+ Set right default addr for net=unix Go 1.2RC1 BenchmarkParseDSN_new 200000 10545 ns/op 4039 B/op 42 allocs/op BenchmarkParseDSN_old 10000 233313 ns/op 7588 B/op 91 allocs/op Go 1.1 BenchmarkParseDSN_new 200000 7940 ns/op 4204 B/op 42 allocs/op BenchmarkParseDSN_old 10000 264115 ns/op 8083 B/op 91 allocs/op
1 parent 0792be4 commit dc02949

File tree

2 files changed

+185
-107
lines changed

2 files changed

+185
-107
lines changed

utils.go

Lines changed: 141 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,24 @@ import (
1313
"crypto/tls"
1414
"database/sql/driver"
1515
"encoding/binary"
16+
"errors"
1617
"fmt"
1718
"io"
1819
"log"
1920
"os"
20-
"regexp"
2121
"strings"
2222
"time"
2323
)
2424

2525
var (
2626
errLog *log.Logger // Error Logger
27-
dsnPattern *regexp.Regexp // Data Source Name Parser
2827
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
28+
29+
errInvalidDSN = errors.New("Invalid DSN")
2930
)
3031

3132
func init() {
3233
errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
33-
34-
dsnPattern = regexp.MustCompile(
35-
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
36-
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
37-
`\/(?P<dbname>.*?)` + // /dbname
38-
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
39-
4034
tlsConfigRegister = make(map[string]*tls.Config)
4135
}
4236

@@ -79,96 +73,69 @@ func DeregisterTLSConfig(key string) {
7973

8074
func parseDSN(dsn string) (cfg *config, err error) {
8175
cfg = new(config)
82-
cfg.params = make(map[string]string)
83-
84-
matches := dsnPattern.FindStringSubmatch(dsn)
85-
names := dsnPattern.SubexpNames()
86-
87-
for i, match := range matches {
88-
switch names[i] {
89-
case "user":
90-
cfg.user = match
91-
case "passwd":
92-
cfg.passwd = match
93-
case "net":
94-
cfg.net = match
95-
case "addr":
96-
cfg.addr = match
97-
case "dbname":
98-
cfg.dbname = match
99-
case "params":
100-
for _, v := range strings.Split(match, "&") {
101-
param := strings.SplitN(v, "=", 2)
102-
if len(param) != 2 {
103-
continue
104-
}
105-
106-
// cfg params
107-
switch value := param[1]; param[0] {
10876

109-
// Disable INFILE whitelist / enable all files
110-
case "allowAllFiles":
111-
var isBool bool
112-
cfg.allowAllFiles, isBool = readBool(value)
113-
if !isBool {
114-
err = fmt.Errorf("Invalid Bool value: %s", value)
115-
return
116-
}
77+
// TODO: use strings.IndexByte when we can depend on Go 1.2
78+
79+
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
80+
// Find the last '/'
81+
for i := len(dsn) - 1; i >= 0; i-- {
82+
if dsn[i] == '/' {
83+
var j int
84+
85+
// left part is empty if i <= 0
86+
if i > 0 {
87+
// [username[:password]@][protocol[(address)]]
88+
// Find the last '@' in dsn[:i]
89+
for j = i; j >= 0; j-- {
90+
if dsn[j] == '@' {
91+
// username[:password]
92+
// Find the first ':' in dsn[:j]
93+
var k int
94+
for k = 0; k < j; k++ {
95+
if dsn[k] == ':' {
96+
cfg.passwd = dsn[k+1 : j]
97+
break
98+
}
99+
}
100+
cfg.user = dsn[:k]
101+
102+
// [protocol[(address)]]
103+
// Find the first '(' in dsn[j+1:i]
104+
for k = j + 1; k < i; k++ {
105+
if dsn[k] == '(' {
106+
// dsn[i-1] must be == ')' if an adress is specified
107+
if dsn[i-1] != ')' {
108+
return nil, errInvalidDSN
109+
}
110+
cfg.addr = dsn[k+1 : i-1]
111+
break
112+
}
113+
}
114+
cfg.net = dsn[j+1 : k]
117115

118-
// Switch "rowsAffected" mode
119-
case "clientFoundRows":
120-
var isBool bool
121-
cfg.clientFoundRows, isBool = readBool(value)
122-
if !isBool {
123-
err = fmt.Errorf("Invalid Bool value: %s", value)
124-
return
125-
}
126-
127-
// Use old authentication mode (pre MySQL 4.1)
128-
case "allowOldPasswords":
129-
var isBool bool
130-
cfg.allowOldPasswords, isBool = readBool(value)
131-
if !isBool {
132-
err = fmt.Errorf("Invalid Bool value: %s", value)
133-
return
116+
break
134117
}
118+
}
135119

136-
// Time Location
137-
case "loc":
138-
cfg.loc, err = time.LoadLocation(value)
139-
if err != nil {
140-
return
141-
}
120+
// non-empty left part must contain an '@'
121+
if j < 0 {
122+
return nil, errInvalidDSN
123+
}
124+
}
142125

143-
// Dial Timeout
144-
case "timeout":
145-
cfg.timeout, err = time.ParseDuration(value)
146-
if err != nil {
126+
// dbname[?param1=value1&...&paramN=valueN]
127+
// Find the first '?' in dsn[i+1:]
128+
for j = i + 1; j < len(dsn); j++ {
129+
if dsn[j] == '?' {
130+
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
147131
return
148132
}
149-
150-
// TLS-Encryption
151-
case "tls":
152-
boolValue, isBool := readBool(value)
153-
if isBool {
154-
if boolValue {
155-
cfg.tls = &tls.Config{}
156-
}
157-
} else {
158-
if strings.ToLower(value) == "skip-verify" {
159-
cfg.tls = &tls.Config{InsecureSkipVerify: true}
160-
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
161-
cfg.tls = tlsConfig
162-
} else {
163-
err = fmt.Errorf("Invalid value / unknown config name: %s", value)
164-
return
165-
}
166-
}
167-
168-
default:
169-
cfg.params[param[0]] = value
133+
break
170134
}
171135
}
136+
cfg.dbname = dsn[i+1 : j]
137+
138+
break
172139
}
173140
}
174141

@@ -179,7 +146,15 @@ func parseDSN(dsn string) (cfg *config, err error) {
179146

180147
// Set default adress if empty
181148
if cfg.addr == "" {
182-
cfg.addr = "127.0.0.1:3306"
149+
switch cfg.net {
150+
case "tcp":
151+
cfg.addr = "127.0.0.1:3306"
152+
case "unix":
153+
cfg.addr = "/tmp/mysql.sock"
154+
default:
155+
return nil, errors.New("Default addr for network '" + cfg.net + "' unknown")
156+
}
157+
183158
}
184159

185160
// Set default location if not set
@@ -190,6 +165,81 @@ func parseDSN(dsn string) (cfg *config, err error) {
190165
return
191166
}
192167

168+
func parseDSNParams(cfg *config, params string) (err error) {
169+
cfg.params = make(map[string]string)
170+
171+
for _, v := range strings.Split(params, "&") {
172+
param := strings.SplitN(v, "=", 2)
173+
if len(param) != 2 {
174+
continue
175+
}
176+
177+
// cfg params
178+
switch value := param[1]; param[0] {
179+
180+
// Disable INFILE whitelist / enable all files
181+
case "allowAllFiles":
182+
var isBool bool
183+
cfg.allowAllFiles, isBool = readBool(value)
184+
if !isBool {
185+
return fmt.Errorf("Invalid Bool value: %s", value)
186+
}
187+
188+
// Switch "rowsAffected" mode
189+
case "clientFoundRows":
190+
var isBool bool
191+
cfg.clientFoundRows, isBool = readBool(value)
192+
if !isBool {
193+
return fmt.Errorf("Invalid Bool value: %s", value)
194+
}
195+
196+
// Use old authentication mode (pre MySQL 4.1)
197+
case "allowOldPasswords":
198+
var isBool bool
199+
cfg.allowOldPasswords, isBool = readBool(value)
200+
if !isBool {
201+
return fmt.Errorf("Invalid Bool value: %s", value)
202+
}
203+
204+
// Time Location
205+
case "loc":
206+
cfg.loc, err = time.LoadLocation(value)
207+
if err != nil {
208+
return
209+
}
210+
211+
// Dial Timeout
212+
case "timeout":
213+
cfg.timeout, err = time.ParseDuration(value)
214+
if err != nil {
215+
return
216+
}
217+
218+
// TLS-Encryption
219+
case "tls":
220+
boolValue, isBool := readBool(value)
221+
if isBool {
222+
if boolValue {
223+
cfg.tls = &tls.Config{}
224+
}
225+
} else {
226+
if strings.ToLower(value) == "skip-verify" {
227+
cfg.tls = &tls.Config{InsecureSkipVerify: true}
228+
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
229+
cfg.tls = tlsConfig
230+
} else {
231+
return fmt.Errorf("Invalid value / unknown config name: %s", value)
232+
}
233+
}
234+
235+
default:
236+
cfg.params[param[0]] = value
237+
}
238+
}
239+
240+
return
241+
}
242+
193243
// Returns the bool value of the input.
194244
// The 2nd return value indicates if the input was a valid bool value
195245
func readBool(input string) (value bool, valid bool) {

utils_test.go

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,26 @@ import (
1414
"time"
1515
)
1616

17-
func TestDSNParser(t *testing.T) {
18-
var testDSNs = []struct {
19-
in string
20-
out string
21-
loc *time.Location
22-
}{
23-
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
24-
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
25-
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
26-
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
27-
{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC},
28-
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local},
29-
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
30-
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
31-
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
32-
}
17+
var testDSNs = []struct {
18+
in string
19+
out string
20+
loc *time.Location
21+
}{
22+
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
23+
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
24+
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
25+
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
26+
{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC},
27+
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local},
28+
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
29+
{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
30+
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
31+
{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
32+
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
33+
{"@unix/", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
34+
}
3335

36+
func TestDSNParser(t *testing.T) {
3437
var cfg *config
3538
var err error
3639
var res string
@@ -51,6 +54,31 @@ func TestDSNParser(t *testing.T) {
5154
}
5255
}
5356

57+
func TestDSNParserInvalid(t *testing.T) {
58+
var invalidDSNs = []string{
59+
"asdf/dbname",
60+
//"/dbname?arg=/some/unescaped/path",
61+
}
62+
63+
for i, tst := range invalidDSNs {
64+
if _, err := parseDSN(tst); err == nil {
65+
t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst)
66+
}
67+
}
68+
}
69+
70+
func BenchmarkParseDSN(b *testing.B) {
71+
b.ReportAllocs()
72+
73+
for i := 0; i < b.N; i++ {
74+
for _, tst := range testDSNs {
75+
if _, err := parseDSN(tst.in); err != nil {
76+
b.Error(err.Error())
77+
}
78+
}
79+
}
80+
}
81+
5482
func TestScanNullTime(t *testing.T) {
5583
var scanTests = []struct {
5684
in interface{}

0 commit comments

Comments
 (0)