Skip to content

Commit 3cce9a7

Browse files
use url parse to handle urls that may not have ports
1 parent 6c9434e commit 3cce9a7

File tree

1 file changed

+11
-16
lines changed

1 file changed

+11
-16
lines changed

sql/mysql_db/mysql_db.go

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"encoding/json"
2222
"fmt"
2323
"net"
24+
"net/url"
2425
"sort"
2526
"strings"
2627
"sync"
@@ -767,11 +768,11 @@ func (db *MySQLDb) AuthMethod(user, addr string) (string, error) {
767768
if addr == "@" || addr == "" {
768769
host = "localhost"
769770
} else {
770-
splitHost, _, err := net.SplitHostPort(addr)
771+
addrUrl, err := url.Parse(addr)
771772
if err != nil {
772773
return "", err
773774
}
774-
host = splitHost
775+
host = addrUrl.Hostname()
775776
}
776777

777778
rd := db.Reader()
@@ -795,17 +796,14 @@ func (db *MySQLDb) Salt() ([]byte, error) {
795796
// ValidateHash implements the interface mysql.AuthServer. This is called when the method used is "mysql_native_password".
796797
func (db *MySQLDb) ValidateHash(salt []byte, user string, authResponse []byte, addr net.Addr) (mysql.Getter, error) {
797798
var host string
798-
var err error
799-
switch addr.Network() {
800-
case "unix":
799+
if addr.Network() == "unix" {
801800
host = "localhost"
802-
case "tcp", "udp":
803-
host, _, err = net.SplitHostPort(addr.String())
801+
} else {
802+
addrUrl, err := url.Parse(addr.String())
804803
if err != nil {
805804
return nil, err
806805
}
807-
default:
808-
host = addr.String()
806+
host = addrUrl.Hostname()
809807
}
810808

811809
rd := db.Reader()
@@ -834,17 +832,14 @@ func (db *MySQLDb) ValidateHash(salt []byte, user string, authResponse []byte, a
834832
// Negotiate implements the interface mysql.AuthServer. This is called when the method used is not "mysql_native_password".
835833
func (db *MySQLDb) Negotiate(c *mysql.Conn, user string, addr net.Addr) (mysql.Getter, error) {
836834
var host string
837-
var err error
838-
switch addr.Network() {
839-
case "unix":
835+
if addr.Network() == "unix" {
840836
host = "localhost"
841-
case "tcp", "udp":
842-
host, _, err = net.SplitHostPort(addr.String())
837+
} else {
838+
addrUrl, err := url.Parse(addr.String())
843839
if err != nil {
844840
return nil, err
845841
}
846-
default:
847-
host = addr.String()
842+
host = addrUrl.Hostname()
848843
}
849844

850845
rd := db.Reader()

0 commit comments

Comments
 (0)