Skip to content

Commit bcb0751

Browse files
Merge pull request #415 from depot/feat/racing-dialer
feat: use racing, retrying dialer when attempting to connect
2 parents 0d767e6 + fd2bda4 commit bcb0751

File tree

3 files changed

+216
-2
lines changed

3 files changed

+216
-2
lines changed

pkg/api/dialer.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package api
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net"
7+
"os"
8+
"strings"
9+
"testing"
10+
"time"
11+
)
12+
13+
// DialContextFunc is a function that dials address on the given network, taking a context as its first argument.
14+
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error) // Same signature as (*net.Dialer).DialContext.
15+
16+
// RetryDialUntilSuccess will retry every `retryTimeout` until it succeeds.
17+
func RetryDialUntilSuccess(retryTimeout time.Duration) DialContextFunc {
18+
return func(ctx context.Context, network, address string) (net.Conn, error) {
19+
for {
20+
dialer := &net.Dialer{
21+
Timeout: retryTimeout,
22+
KeepAlive: 30 * time.Second, // Similar to the default HTTP dialer.
23+
}
24+
c, err := dialer.DialContext(ctx, network, address)
25+
if err != nil {
26+
if errors.Is(err, context.DeadlineExceeded) {
27+
continue
28+
}
29+
if errors.Is(err, os.ErrDeadlineExceeded) {
30+
continue
31+
}
32+
// Testing hook.
33+
if testing.Testing() && strings.Contains(err.Error(), "connection refused") {
34+
continue
35+
}
36+
}
37+
return c, err
38+
}
39+
}
40+
}
41+
42+
// DialNoTimeout will block with no timeout or until the context is canceled.
43+
func DialNoTimeout() DialContextFunc {
44+
dialer := &net.Dialer{}
45+
return dialer.DialContext
46+
}
47+
48+
// DefaultHTTPDialer has the same options as the default HTTP dialer.
49+
func DefaultHTTPDialer() DialContextFunc {
50+
dialer := &net.Dialer{
51+
Timeout: 30 * time.Second,
52+
KeepAlive: 30 * time.Second,
53+
}
54+
return dialer.DialContext
55+
}
56+
57+
// RacingDialer is a custom dialer that attempts to connect to a given address.
58+
//
59+
// It uses two different dialers.
60+
// The dialer connects first is returned, and the other is canceled.
61+
//
62+
// The first has a short timeout (200 ms) and continues to retry until it succeeds.
63+
// The second dialer has no timeout and will block until it either succeeds or fails.
64+
//
65+
// We are doing this because we see connection timeouts perhaps caused by some competing network routes.
66+
// Our workaround is to use a short timeout dialer that will retry until it succeeds.
67+
func RacingDialer(dialers ...DialContextFunc) DialContextFunc {
68+
if len(dialers) == 0 {
69+
return DialNoTimeout()
70+
}
71+
72+
return func(ctx context.Context, network, address string) (net.Conn, error) {
73+
ctx, cancel := context.WithCancel(ctx)
74+
defer cancel()
75+
76+
type dialResult struct {
77+
conn net.Conn
78+
err error
79+
}
80+
resultCh := make(chan dialResult, len(dialers))
81+
for _, dialer := range dialers {
82+
go func(d DialContextFunc) {
83+
c, err := d(ctx, network, address)
84+
resultCh <- dialResult{conn: c, err: err}
85+
}(dialer)
86+
}
87+
88+
var connError error
89+
for range len(dialers) {
90+
res := <-resultCh
91+
if res.err == nil {
92+
cancel()
93+
return res.conn, nil
94+
} else {
95+
connError = res.err
96+
}
97+
}
98+
return nil, connError
99+
}
100+
}

pkg/api/dialer_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package api
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net"
7+
"sync"
8+
"testing"
9+
"time"
10+
)
11+
12+
func TestRetryDialUntilSuccess(t *testing.T) {
13+
// "reserve" a port.
14+
ln, err := net.Listen("tcp", ":0")
15+
if err != nil {
16+
t.Fatalf("failed to listen on random port: %v", err)
17+
}
18+
port := ln.Addr().(*net.TCPAddr).Port
19+
_ = ln.Close()
20+
21+
ctx, cancel := context.WithCancel(t.Context())
22+
defer cancel()
23+
24+
var wg sync.WaitGroup
25+
wg.Add(1)
26+
go func() {
27+
defer wg.Done()
28+
// Wait a a bit so that the dialer will retry a few times.
29+
time.Sleep(50 * time.Millisecond)
30+
31+
// Start a server to listen on the reserved port.
32+
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
33+
if err != nil {
34+
cancel()
35+
return
36+
}
37+
t.Log("listener", listener.Addr())
38+
defer listener.Close()
39+
conn, err := listener.Accept()
40+
if err != nil {
41+
cancel()
42+
return
43+
}
44+
defer conn.Close()
45+
}()
46+
47+
dialer := RetryDialUntilSuccess(10 * time.Millisecond)
48+
conn, err := dialer(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
49+
if err != nil {
50+
t.Fatalf("failed to dial: %v", err)
51+
}
52+
conn.Close()
53+
wg.Wait()
54+
}
55+
56+
func TestRacingDialer(t *testing.T) {
57+
// "reserve" a port.
58+
ln, err := net.Listen("tcp", ":0")
59+
if err != nil {
60+
t.Fatalf("failed to listen on random port: %v", err)
61+
}
62+
port := ln.Addr().(*net.TCPAddr).Port
63+
_ = ln.Close()
64+
65+
ctx, cancel := context.WithCancel(t.Context())
66+
defer cancel()
67+
68+
var wg sync.WaitGroup
69+
wg.Add(1)
70+
go func() {
71+
defer wg.Done()
72+
// Wait a a bit so that the dialer will retry a few times.
73+
time.Sleep(50 * time.Millisecond)
74+
75+
// Start a server to listen on the reserved port.
76+
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
77+
if err != nil {
78+
cancel()
79+
return
80+
}
81+
t.Log("listener", listener.Addr())
82+
defer listener.Close()
83+
conn, err := listener.Accept()
84+
if err != nil {
85+
cancel()
86+
return
87+
}
88+
defer conn.Close()
89+
}()
90+
91+
dialer := RacingDialer(DialNoTimeout(), RetryDialUntilSuccess(10*time.Millisecond))
92+
conn, err := dialer(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
93+
if err != nil {
94+
t.Fatalf("failed to dial: %v", err)
95+
}
96+
conn.Close()
97+
wg.Wait()
98+
}

pkg/api/rpc.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"net/http"
77
"os"
88
"strings"
9+
"time"
910

1011
"buf.build/gen/go/depot/api/connectrpc/go/depot/core/v1/corev1connect"
1112
"connectrpc.com/connect"
@@ -84,8 +85,23 @@ func getHTTPClient(baseURL string) *http.Client {
8485
},
8586
}
8687
}
87-
// Use default client for HTTPS connections
88-
return http.DefaultClient
88+
89+
t, ok := http.DefaultTransport.(*http.Transport)
90+
if !ok {
91+
return http.DefaultClient
92+
}
93+
94+
transport := t.Clone()
95+
transport.DialContext = RacingDialer(
96+
RetryDialUntilSuccess(500*time.Millisecond),
97+
DefaultHTTPDialer(),
98+
)
99+
100+
racingClient := &http.Client{
101+
Transport: transport,
102+
}
103+
104+
return racingClient
89105
}
90106

91107
func getBaseURL() string {

0 commit comments

Comments
 (0)