@@ -21,6 +21,7 @@ import (
2121 "encoding/json"
2222 "fmt"
2323 "net"
24+ "net/url"
2425 "sort"
2526 "strings"
2627 "sync"
@@ -767,11 +768,11 @@ func (db *MySQLDb) AuthMethod(user, addr string) (string, error) {
767768 if addr == "@" || addr == "" {
768769 host = "localhost"
769770 } else {
770- splitHost , _ , err := net . SplitHostPort (addr )
771+ addrUrl , err := url . Parse (addr )
771772 if err != nil {
772773 return "" , err
773774 }
774- host = splitHost
775+ host = addrUrl . Hostname ()
775776 }
776777
777778 rd := db .Reader ()
@@ -795,17 +796,14 @@ func (db *MySQLDb) Salt() ([]byte, error) {
795796// ValidateHash implements the interface mysql.AuthServer. This is called when the method used is "mysql_native_password".
796797func (db * MySQLDb ) ValidateHash (salt []byte , user string , authResponse []byte , addr net.Addr ) (mysql.Getter , error ) {
797798 var host string
798- var err error
799- switch addr .Network () {
800- case "unix" :
799+ if addr .Network () == "unix" {
801800 host = "localhost"
802- case "tcp" , "udp" :
803- host , _ , err = net . SplitHostPort (addr .String ())
801+ } else {
802+ addrUrl , err := url . Parse (addr .String ())
804803 if err != nil {
805804 return nil , err
806805 }
807- default :
808- host = addr .String ()
806+ host = addrUrl .Hostname ()
809807 }
810808
811809 rd := db .Reader ()
@@ -834,17 +832,14 @@ func (db *MySQLDb) ValidateHash(salt []byte, user string, authResponse []byte, a
834832// Negotiate implements the interface mysql.AuthServer. This is called when the method used is not "mysql_native_password".
835833func (db * MySQLDb ) Negotiate (c * mysql.Conn , user string , addr net.Addr ) (mysql.Getter , error ) {
836834 var host string
837- var err error
838- switch addr .Network () {
839- case "unix" :
835+ if addr .Network () == "unix" {
840836 host = "localhost"
841- case "tcp" , "udp" :
842- host , _ , err = net . SplitHostPort (addr .String ())
837+ } else {
838+ addrUrl , err := url . Parse (addr .String ())
843839 if err != nil {
844840 return nil , err
845841 }
846- default :
847- host = addr .String ()
842+ host = addrUrl .Hostname ()
848843 }
849844
850845 rd := db .Reader ()
0 commit comments