Skip to content

Commit 8e79357

Browse files
committed
feat(courierhttp/client): added support for custom host aliases and set default tls config
1 parent 2ced01d commit 8e79357

File tree

3 files changed

+246
-7
lines changed

3 files changed

+246
-7
lines changed
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
package client
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"iter"
8+
"maps"
9+
"net"
10+
"strings"
11+
12+
randv2 "math/rand/v2"
13+
)
14+
15+
var ErrInvalidHostAlias = errors.New("invalid host alias")
16+
17+
func ParseHostAlias(x string) (*HostAlias, error) {
18+
if x == "" {
19+
return nil, fmt.Errorf("empty host alias: %w", ErrInvalidHostAlias)
20+
}
21+
22+
ha := &HostAlias{}
23+
24+
if strings.IndexByte(x, '[') == 0 {
25+
// ipv6
26+
end := strings.IndexByte(x, ']')
27+
if end == -1 {
28+
return nil, fmt.Errorf("ipv6 addr should warpped []: %w", ErrInvalidHostAlias)
29+
}
30+
31+
ha.IP = net.ParseIP(x[1:end])
32+
33+
x = x[end:]
34+
35+
end = strings.IndexByte(x, ':')
36+
if end == -1 {
37+
return nil, fmt.Errorf(" ip should end with ':': %w", ErrInvalidHostAlias)
38+
}
39+
x = x[end+1:]
40+
} else {
41+
end := strings.IndexByte(x, ':')
42+
if end == -1 {
43+
return nil, fmt.Errorf(" ip should end with ':': %w", ErrInvalidHostAlias)
44+
}
45+
ha.IP = net.ParseIP(x[0:end])
46+
x = x[end+1:]
47+
}
48+
49+
if x == "" {
50+
return nil, fmt.Errorf("invalid host alias")
51+
}
52+
53+
ha.Hostnames = strings.Split(x, ",")
54+
55+
return ha, nil
56+
}
57+
58+
type HostAlias struct {
59+
IP net.IP
60+
Hostnames []string
61+
}
62+
63+
func (x HostAlias) IsZero() bool {
64+
return len(x.Hostnames) == 0 || len(x.IP) == 0
65+
}
66+
67+
func (x *HostAlias) UnmarshalText(raw []byte) error {
68+
ha, err := ParseHostAlias(string(raw))
69+
if err != nil {
70+
return err
71+
}
72+
*x = *ha
73+
return nil
74+
}
75+
76+
func (x HostAlias) MarshalText() ([]byte, error) {
77+
return []byte(x.String()), nil
78+
}
79+
80+
func (x HostAlias) String() string {
81+
s := strings.Builder{}
82+
83+
if ip := x.IP.To4(); ip != nil {
84+
s.WriteString(ip.String())
85+
} else {
86+
s.WriteString("[")
87+
s.WriteString(x.IP.String())
88+
s.WriteString("]")
89+
}
90+
91+
s.WriteString(":")
92+
93+
for i, hostname := range x.Hostnames {
94+
if i > 0 {
95+
s.WriteString(",")
96+
}
97+
s.WriteString(hostname)
98+
}
99+
100+
return s.String()
101+
}
102+
103+
type Hosts map[string]map[string]struct{}
104+
105+
func (hosts Hosts) WrapDialContext(dialContext func(ctx context.Context, network string, address string) (net.Conn, error)) func(ctx context.Context, network string, addr string) (net.Conn, error) {
106+
return func(ctx context.Context, network string, addr string) (net.Conn, error) {
107+
if len(hosts) == 0 {
108+
return dialContext(ctx, network, addr)
109+
}
110+
111+
host, port, err := net.SplitHostPort(addr)
112+
if err != nil {
113+
host = addr
114+
port = "80"
115+
}
116+
117+
if ips, ok := hosts[host]; ok && len(ips) > 0 {
118+
resolved := net.JoinHostPort(hosts.selectIP(maps.Keys(ips), len(ips)), port)
119+
return dialContext(ctx, network, resolved)
120+
}
121+
122+
return dialContext(ctx, network, addr)
123+
}
124+
}
125+
126+
func (hosts Hosts) selectIP(ips iter.Seq[string], n int) (ip string) {
127+
i := 0
128+
idx := 0
129+
if n > 1 {
130+
idx = randv2.IntN(n) - 1
131+
}
132+
133+
for x := range ips {
134+
if i == idx {
135+
ip = x
136+
break
137+
}
138+
i++
139+
}
140+
141+
return
142+
}
143+
144+
func (hosts Hosts) AddHostAlias(alias HostAlias) {
145+
if alias.IsZero() {
146+
return
147+
}
148+
149+
for _, hostname := range alias.Hostnames {
150+
if hosts[hostname] == nil {
151+
hosts[hostname] = make(map[string]struct{})
152+
}
153+
hosts[hostname][alias.IP.String()] = struct{}{}
154+
}
155+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package client
2+
3+
import (
4+
"net"
5+
"testing"
6+
7+
"github.com/octohelm/x/testing/bdd"
8+
)
9+
10+
func TestHostAliases(t *testing.T) {
11+
bdd.FromT(t).Given("host alias for ipv4", func(b bdd.T) {
12+
ha := &HostAlias{
13+
IP: net.ParseIP("127.0.0.1"),
14+
Hostnames: []string{
15+
"localhost",
16+
"localhost1",
17+
},
18+
}
19+
20+
txt := bdd.Must(ha.MarshalText())
21+
22+
b.Then("match results",
23+
bdd.Equal("127.0.0.1:localhost,localhost1", string(txt)),
24+
)
25+
26+
b.When("unmarshal", func(b bdd.T) {
27+
ha1 := &HostAlias{}
28+
29+
b.Then("success",
30+
bdd.NoError(ha1.UnmarshalText(txt)),
31+
bdd.Equal(ha, ha1),
32+
)
33+
})
34+
})
35+
36+
bdd.FromT(t).Given("host alias for ipv6", func(b bdd.T) {
37+
ha := &HostAlias{
38+
IP: net.ParseIP("::1"),
39+
Hostnames: []string{
40+
"localhost",
41+
"localhost1",
42+
},
43+
}
44+
45+
txt := bdd.Must(ha.MarshalText())
46+
47+
b.Then("match results",
48+
bdd.Equal("[::1]:localhost,localhost1", string(txt)),
49+
)
50+
51+
b.When("unmarshal", func(b bdd.T) {
52+
ha1 := &HostAlias{}
53+
54+
b.Then("success",
55+
bdd.NoError(ha1.UnmarshalText(txt)),
56+
bdd.Equal(ha, ha1),
57+
)
58+
})
59+
})
60+
}

pkg/courierhttp/client/http_default.go

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@ package client
22

33
import (
44
"context"
5+
"crypto/tls"
56
"net"
67
"net/http"
78
"time"
89
)
910

10-
var reasonableRoundTripper http.RoundTripper = &http.Transport{
11+
var reasonableRoundTripper = &http.Transport{
1112
Proxy: http.ProxyFromEnvironment,
1213

13-
DialContext: (&net.Dialer{
14+
DialContext: defaultHosts.WrapDialContext((&net.Dialer{
1415
Timeout: 30 * time.Second,
1516
KeepAlive: 30 * time.Second,
16-
}).DialContext,
17+
}).DialContext),
1718

1819
MaxIdleConns: 100,
1920
MaxIdleConnsPerHost: 10,
@@ -27,8 +28,25 @@ var reasonableRoundTripper http.RoundTripper = &http.Transport{
2728
ForceAttemptHTTP2: true,
2829
}
2930

31+
var defaultTlsConfig = &tls.Config{}
32+
33+
var defaultHosts = Hosts{}
34+
35+
func AddHostAlias(hostAliases ...HostAlias) {
36+
for _, hostAlias := range hostAliases {
37+
defaultHosts.AddHostAlias(hostAlias)
38+
}
39+
}
40+
41+
func SetDefaultTLSClientConfig(tlsConfig *tls.Config) {
42+
if tlsConfig != nil {
43+
defaultTlsConfig = tlsConfig.Clone()
44+
reasonableRoundTripper.TLSClientConfig = tlsConfig.Clone()
45+
}
46+
}
47+
3048
func GetReasonableClientContext(ctx context.Context, httpTransports ...HttpTransport) *http.Client {
31-
t := reasonableRoundTripper
49+
t := http.RoundTripper(reasonableRoundTripper)
3250

3351
tc, ok := RoundTripperCreatorFromContext(ctx)
3452
if ok {
@@ -39,20 +57,26 @@ func GetReasonableClientContext(ctx context.Context, httpTransports ...HttpTrans
3957
}
4058

4159
func newRoundTripperWithoutKeepAlive() http.RoundTripper {
42-
return &http.Transport{
60+
t := &http.Transport{
4361
Proxy: http.ProxyFromEnvironment,
4462

45-
DialContext: (&net.Dialer{
63+
DialContext: defaultHosts.WrapDialContext((&net.Dialer{
4664
Timeout: 30 * time.Second,
4765
KeepAlive: 0,
48-
}).DialContext,
66+
}).DialContext),
4967

5068
DisableKeepAlives: true,
5169

5270
TLSHandshakeTimeout: 10 * time.Second,
5371
ExpectContinueTimeout: 1 * time.Second,
5472
ResponseHeaderTimeout: 60 * time.Second,
5573
}
74+
75+
if defaultTlsConfig != nil {
76+
t.TLSClientConfig = defaultTlsConfig.Clone()
77+
}
78+
79+
return t
5680
}
5781

5882
func GetShortConnClientContext(ctx context.Context, httpTransports ...HttpTransport) *http.Client {

0 commit comments

Comments
 (0)