11package config
22
33import (
4+ "net"
45 "regexp"
6+ "strconv"
57 "strings"
68 "unicode"
79)
@@ -85,9 +87,75 @@ func (d *DSN) MaskPassword() (s string) {
8587 return s
8688}
8789
88- // Parse parses a data source name string.
90+ // Host the database server host.
91+ func (d * DSN ) Host () string {
92+ if d .Driver == SQLite3 {
93+ return ""
94+ }
95+
96+ host , _ := d .splitHostPort ()
97+ return host
98+ }
99+
100+ // Port the database server port.
101+ func (d * DSN ) Port () int {
102+ switch d .Driver {
103+ case SQLite3 :
104+ return 0
105+ }
106+
107+ defaultPort := 0
108+
109+ switch d .Driver {
110+ case MySQL , MariaDB :
111+ defaultPort = 3306
112+ case Postgres :
113+ defaultPort = 5432
114+ }
115+
116+ if d .Server == "" {
117+ return 0
118+ }
119+
120+ _ , portValue := d .splitHostPort ()
121+
122+ if portValue == "" {
123+ return defaultPort
124+ }
125+
126+ port , err := strconv .Atoi (portValue )
127+ if err != nil || port < 1 || port > 65535 {
128+ return defaultPort
129+ }
130+
131+ return port
132+ }
133+
134+ // splitHostPort splits the DSN server field into host and port components.
135+ func (d * DSN ) splitHostPort () (host , port string ) {
136+ server := strings .TrimSpace (d .Server )
137+
138+ if server == "" {
139+ return "" , ""
140+ }
141+
142+ var err error
143+
144+ host , port , err = net .SplitHostPort (server )
145+
146+ if err != nil {
147+ return server , ""
148+ }
149+
150+ return host , port
151+ }
152+
153+ // parse parses a data source name string.
89154func (d * DSN ) parse () {
90- // Assume a regular DSN, and if parsing fails, treat it as a PostgreSQL-style DSN.
155+ if d .parsePostgres () {
156+ return
157+ }
158+
91159 if matches := dsnPattern .FindStringSubmatch (d .DSN ); len (matches ) > 0 {
92160 names := dsnPattern .SubexpNames ()
93161
@@ -114,14 +182,18 @@ func (d *DSN) parse() {
114182 d .Server = d .Net
115183 d .Net = ""
116184 }
117- } else {
118- // Parse PostgreSQL-style DSN
119- d .parsePostgres ()
120185 }
186+
187+ d .detectDriver ()
121188}
122189
123- // parsePostgres extracts connection settings from PostgreSQL key/value style DSNs.
190+ // parsePostgres extracts connection settings from PostgreSQL key/value style DSNs and
191+ // returns true on success.
124192func (d * DSN ) parsePostgres () bool {
193+ if ! strings .Contains (d .DSN , "password=" ) || ! strings .Contains (d .DSN , "user=" ) {
194+ return false
195+ }
196+
125197 fields , ok := d .splitKeyValue (d .DSN )
126198
127199 if ! ok {
@@ -276,6 +348,54 @@ func (d *DSN) splitKeyValue(input string) ([]string, bool) {
276348 return tokens , true
277349}
278350
351+ // detectDriver infers the driver name from DSN contents when it is not explicitly specified.
352+ func (d * DSN ) detectDriver () {
353+ driver := strings .ToLower (d .Driver )
354+
355+ switch driver {
356+ case "postgres" , "postgresql" :
357+ d .Driver = Postgres
358+ return
359+ case "mysql" , "mariadb" :
360+ d .Driver = MySQL
361+ return
362+ case "sqlite" , "sqlite3" , "file" :
363+ d .Driver = SQLite3
364+ return
365+ }
366+
367+ if driver != "" {
368+ d .Driver = driver
369+ return
370+ }
371+
372+ lower := strings .ToLower (d .DSN )
373+
374+ if strings .Contains (lower , "postgres://" ) || strings .Contains (lower , "postgresql://" ) {
375+ d .Driver = Postgres
376+ return
377+ }
378+
379+ if d .Net == "tcp" || d .Net == "unix" || strings .Contains (lower , "@tcp(" ) || strings .Contains (lower , "@unix(" ) {
380+ d .Driver = MySQL
381+ return
382+ }
383+
384+ if strings .HasPrefix (lower , "file:" ) || strings .HasSuffix (lower , ".db" ) || strings .HasSuffix (strings .ToLower (d .Name ), ".db" ) {
385+ d .Driver = SQLite3
386+ return
387+ }
388+
389+ if strings .Contains (lower , " host=" ) && strings .Contains (lower , " dbname=" ) {
390+ d .Driver = Postgres
391+ return
392+ }
393+
394+ if d .Server != "" && (strings .Contains (d .Server , ":" ) || d .Net != "" ) && d .Driver == "" {
395+ d .Driver = MySQL
396+ }
397+ }
398+
279399// MaskDatabaseDSN hides the password portion of a DSN while leaving the rest untouched for logging/reporting.
280400func MaskDatabaseDSN (dsn string ) string {
281401 if dsn == "" {
0 commit comments