44 "context"
55 "database/sql"
66 "database/sql/driver"
7+ "errors"
78 "net"
9+ "strings"
810 "time"
911
1012 "github.com/lib/pq"
@@ -20,14 +22,82 @@ func (d proxyDriver) Open(name string) (driver.Conn, error) {
2022}
2123
2224func (d proxyDriver ) Dial (network , address string ) (net.Conn , error ) {
23- dialer := proxy .FromEnvironment ()
24- return dialer .Dial (network , address )
25+ return d .DialTimeout (network , address , 0 )
2526}
2627
2728func (d proxyDriver ) DialTimeout (network , address string , timeout time.Duration ) (net.Conn , error ) {
28- ctx , cancel := context .WithTimeout (context .TODO (), timeout )
29- defer cancel ()
30- return proxy .Dial (ctx , network , address )
29+ var ctx context.Context
30+ var cancel context.CancelFunc
31+ if timeout > 0 {
32+ ctx , cancel = context .WithTimeout (context .Background (), timeout )
33+ defer cancel ()
34+ } else {
35+ ctx = context .Background ()
36+ }
37+
38+ // Only handle TCP networks for multi-host splitting
39+ if ! strings .HasPrefix (network , "tcp" ) {
40+ return proxy .Dial (ctx , network , address )
41+ }
42+
43+ hosts , port , err := parseAddress (address )
44+ if err != nil {
45+ // If parsing fails, fall back to trying the original address
46+ return proxy .Dial (ctx , network , address )
47+ }
48+
49+ var lastErr error
50+ for _ , host := range hosts {
51+ addr := net .JoinHostPort (host , port )
52+ conn , err := proxy .Dial (ctx , network , addr )
53+ if err == nil {
54+ return conn , nil
55+ }
56+ lastErr = err
57+
58+ // Check if context expired
59+ select {
60+ case <- ctx .Done ():
61+ return nil , ctx .Err ()
62+ default :
63+ }
64+ }
65+ if lastErr != nil {
66+ return nil , lastErr
67+ }
68+ return nil , errors .New ("no hosts available" )
69+ }
70+
71+ func parseAddress (address string ) ([]string , string , error ) {
72+ host , port , err := net .SplitHostPort (address )
73+ if err == nil {
74+ if strings .Contains (host , "," ) {
75+ return strings .Split (host , "," ), port , nil
76+ }
77+ return []string {host }, port , nil
78+ }
79+
80+ // Fallback for when net.SplitHostPort fails (e.g. mixed bracketed and unbracketed hosts)
81+ lastColon := strings .LastIndex (address , ":" )
82+ if lastColon == - 1 {
83+ return nil , "" , err
84+ }
85+
86+ port = address [lastColon + 1 :]
87+ hostPart := address [:lastColon ]
88+
89+ if strings .Contains (hostPart , "," ) {
90+ hosts := strings .Split (hostPart , "," )
91+ // Clean up brackets if present so net.JoinHostPort doesn't double them
92+ for i , h := range hosts {
93+ if len (h ) > 2 && h [0 ] == '[' && h [len (h )- 1 ] == ']' {
94+ hosts [i ] = h [1 : len (h )- 1 ]
95+ }
96+ }
97+ return hosts , port , nil
98+ }
99+
100+ return nil , "" , err
31101}
32102
33103func init () {
0 commit comments