Skip to content

Commit c5ce227

Browse files
fix port-less host parsing
1 parent 0beb854 commit c5ce227

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

sql/mysql_db/mysql_db.go

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
"encoding/json"
2222
"fmt"
2323
"net"
24-
"net/url"
2524
"sort"
2625
"strings"
2726
"sync"
@@ -768,11 +767,15 @@ func (db *MySQLDb) AuthMethod(user, addr string) (string, error) {
768767
if addr == "@" || addr == "" {
769768
host = "localhost"
770769
} else {
771-
addrUrl, err := url.Parse(addr)
770+
splitHost, _, err := net.SplitHostPort(addr)
772771
if err != nil {
773-
return "", err
772+
if err.(*net.AddrError).Err == "missing port in address" {
773+
host = addr
774+
} else {
775+
return "", err
776+
}
777+
host = splitHost
774778
}
775-
host = addrUrl.Hostname()
776779
}
777780

778781
rd := db.Reader()
@@ -796,14 +799,18 @@ func (db *MySQLDb) Salt() ([]byte, error) {
796799
// ValidateHash implements the interface mysql.AuthServer. This is called when the method used is "mysql_native_password".
797800
func (db *MySQLDb) ValidateHash(salt []byte, user string, authResponse []byte, addr net.Addr) (mysql.Getter, error) {
798801
var host string
802+
var err error
799803
if addr.Network() == "unix" {
800804
host = "localhost"
801805
} else {
802-
addrUrl, err := url.Parse(addr.String())
806+
host, _, err = net.SplitHostPort(addr.String())
803807
if err != nil {
804-
return nil, err
808+
if err.(*net.AddrError).Err == "missing port in address" {
809+
host = addr.String()
810+
} else {
811+
return nil, err
812+
}
805813
}
806-
host = addrUrl.Hostname()
807814
}
808815

809816
rd := db.Reader()
@@ -832,14 +839,18 @@ func (db *MySQLDb) ValidateHash(salt []byte, user string, authResponse []byte, a
832839
// Negotiate implements the interface mysql.AuthServer. This is called when the method used is not "mysql_native_password".
833840
func (db *MySQLDb) Negotiate(c *mysql.Conn, user string, addr net.Addr) (mysql.Getter, error) {
834841
var host string
842+
var err error
835843
if addr.Network() == "unix" {
836844
host = "localhost"
837845
} else {
838-
addrUrl, err := url.Parse(addr.String())
846+
host, _, err = net.SplitHostPort(addr.String())
839847
if err != nil {
840-
return nil, err
848+
if err.(*net.AddrError).Err == "missing port in address" {
849+
host = addr.String()
850+
} else {
851+
return nil, err
852+
}
841853
}
842-
host = addrUrl.Hostname()
843854
}
844855

845856
rd := db.Reader()

0 commit comments

Comments
 (0)