Skip to content

Commit 43d8609

Browse files
committed
Support multiple hosts in provider configuration
1 parent 450ce85 commit 43d8609

File tree

2 files changed

+158
-5
lines changed

2 files changed

+158
-5
lines changed

postgresql/proxy_driver.go

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ import (
44
"context"
55
"database/sql"
66
"database/sql/driver"
7+
"errors"
78
"net"
9+
"strings"
810
"time"
911

1012
"github.com/lib/pq"
@@ -20,14 +22,82 @@ func (d proxyDriver) Open(name string) (driver.Conn, error) {
2022
}
2123

2224
func (d proxyDriver) Dial(network, address string) (net.Conn, error) {
23-
dialer := proxy.FromEnvironment()
24-
return dialer.Dial(network, address)
25+
return d.DialTimeout(network, address, 0)
2526
}
2627

2728
func (d proxyDriver) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
28-
ctx, cancel := context.WithTimeout(context.TODO(), timeout)
29-
defer cancel()
30-
return proxy.Dial(ctx, network, address)
29+
var ctx context.Context
30+
var cancel context.CancelFunc
31+
if timeout > 0 {
32+
ctx, cancel = context.WithTimeout(context.Background(), timeout)
33+
defer cancel()
34+
} else {
35+
ctx = context.Background()
36+
}
37+
38+
// Only handle TCP networks for multi-host splitting
39+
if !strings.HasPrefix(network, "tcp") {
40+
return proxy.Dial(ctx, network, address)
41+
}
42+
43+
hosts, port, err := parseAddress(address)
44+
if err != nil {
45+
// If parsing fails, fall back to trying the original address
46+
return proxy.Dial(ctx, network, address)
47+
}
48+
49+
var lastErr error
50+
for _, host := range hosts {
51+
addr := net.JoinHostPort(host, port)
52+
conn, err := proxy.Dial(ctx, network, addr)
53+
if err == nil {
54+
return conn, nil
55+
}
56+
lastErr = err
57+
58+
// Check if context expired
59+
select {
60+
case <-ctx.Done():
61+
return nil, ctx.Err()
62+
default:
63+
}
64+
}
65+
if lastErr != nil {
66+
return nil, lastErr
67+
}
68+
return nil, errors.New("no hosts available")
69+
}
70+
71+
func parseAddress(address string) ([]string, string, error) {
72+
host, port, err := net.SplitHostPort(address)
73+
if err == nil {
74+
if strings.Contains(host, ",") {
75+
return strings.Split(host, ","), port, nil
76+
}
77+
return []string{host}, port, nil
78+
}
79+
80+
// Fallback for when net.SplitHostPort fails (e.g. mixed bracketed and unbracketed hosts)
81+
lastColon := strings.LastIndex(address, ":")
82+
if lastColon == -1 {
83+
return nil, "", err
84+
}
85+
86+
port = address[lastColon+1:]
87+
hostPart := address[:lastColon]
88+
89+
if strings.Contains(hostPart, ",") {
90+
hosts := strings.Split(hostPart, ",")
91+
// Clean up brackets if present so net.JoinHostPort doesn't double them
92+
for i, h := range hosts {
93+
if len(h) > 2 && h[0] == '[' && h[len(h)-1] == ']' {
94+
hosts[i] = h[1 : len(h)-1]
95+
}
96+
}
97+
return hosts, port, nil
98+
}
99+
100+
return nil, "", err
31101
}
32102

33103
func init() {

postgresql/proxy_driver_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package postgresql
2+
3+
import (
4+
"net"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestParseAddress(t *testing.T) {
11+
tests := []struct {
12+
input string
13+
expectHosts []string
14+
expectPort string
15+
expectErr bool
16+
}{
17+
{
18+
input: "host1:5432",
19+
expectHosts: []string{"host1"},
20+
expectPort: "5432",
21+
expectErr: false,
22+
},
23+
{
24+
input: "host1,host2:5432",
25+
expectHosts: []string{"host1", "host2"},
26+
expectPort: "5432",
27+
expectErr: false,
28+
},
29+
{
30+
input: "[::1]:5432",
31+
expectHosts: []string{"::1"}, // net.SplitHostPort strips brackets
32+
expectPort: "5432",
33+
expectErr: false,
34+
},
35+
{
36+
input: "[::1],localhost:5432",
37+
expectHosts: []string{"::1", "localhost"}, // manual split strips brackets
38+
expectPort: "5432",
39+
expectErr: false,
40+
},
41+
{
42+
input: "host1,[::1]:5432",
43+
expectHosts: []string{"host1", "::1"},
44+
expectPort: "5432",
45+
expectErr: false,
46+
},
47+
}
48+
49+
for _, tt := range tests {
50+
t.Run(tt.input, func(t *testing.T) {
51+
hosts, port, err := parseAddress(tt.input)
52+
if tt.expectErr {
53+
assert.Error(t, err)
54+
} else {
55+
assert.NoError(t, err)
56+
assert.Equal(t, tt.expectHosts, hosts)
57+
assert.Equal(t, tt.expectPort, port)
58+
}
59+
})
60+
}
61+
}
62+
63+
func TestReconstruction(t *testing.T) {
64+
// Verify that net.JoinHostPort reconstructs correctly from what parseAddress returns
65+
tests := []string{
66+
"host1:5432",
67+
"host1,host2:5432",
68+
"[::1]:5432",
69+
"[::1],localhost:5432",
70+
}
71+
72+
for _, input := range tests {
73+
hosts, port, err := parseAddress(input)
74+
assert.NoError(t, err)
75+
for _, h := range hosts {
76+
addr := net.JoinHostPort(h, port)
77+
// Sanity check on address format
78+
_, _, err := net.SplitHostPort(addr)
79+
assert.NoError(t, err, "JoinHostPort produced invalid address: %s", addr)
80+
}
81+
}
82+
}
83+

0 commit comments

Comments
 (0)