Skip to content

Commit 9584246

Browse files
authored
feat: Implement Custom TCP Dialers (#3166)
1 parent fde0e3a commit 9584246

File tree

3 files changed

+158
-0
lines changed

3 files changed

+158
-0
lines changed

libp2p_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,3 +784,31 @@ func TestSharedTCPAddr(t *testing.T) {
784784
)
785785
require.ErrorContains(t, err, "cannot use shared TCP listener with PSK")
786786
}
787+
788+
func TestCustomTCPDialer(t *testing.T) {
789+
expectedErr := errors.New("custom dialer called, but not implemented")
790+
customDialer := func(raddr ma.Multiaddr) (tcp.ContextDialer, error) {
791+
// Normally a user would implement this by returning a custom dialer
792+
// Here, we just test that this is called.
793+
return nil, expectedErr
794+
}
795+
796+
h, err := New(
797+
Transport(tcp.NewTCPTransport, tcp.WithDialerForAddr(customDialer)),
798+
)
799+
require.NoError(t, err)
800+
defer h.Close()
801+
802+
var randID peer.ID
803+
priv, _, err := crypto.GenerateKeyPair(crypto.Ed25519, 256)
804+
require.NoError(t, err)
805+
randID, err = peer.IDFromPrivateKey(priv)
806+
require.NoError(t, err)
807+
808+
err = h.Connect(context.Background(), peer.AddrInfo{
809+
ID: randID,
810+
// This won't actually be dialed since we return an error above
811+
Addrs: []ma.Multiaddr{ma.StringCast("/ip4/1.2.3.4/tcp/4")},
812+
})
813+
require.ErrorContains(t, err, expectedErr.Error())
814+
}

p2p/transport/tcp/tcp.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package tcp
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"net"
78
"os"
89
"runtime"
@@ -117,12 +118,35 @@ func WithMetrics() Option {
117118
}
118119
}
119120

121+
// WithDialerForAddr sets a custom dialer for the given address.
122+
// If set, it will be the *ONLY* dialer used.
123+
func WithDialerForAddr(d DialerForAddr) Option {
124+
return func(tr *TcpTransport) error {
125+
tr.overrideDialerForAddr = d
126+
return nil
127+
}
128+
}
129+
130+
type ContextDialer interface {
131+
DialContext(ctx context.Context, network, address string) (net.Conn, error)
132+
}
133+
134+
// DialerForAddr is a function that returns a dialer for a given address.
135+
// Implementations must return either a ContextDialer or an error. It is
136+
// invalid to return nil, nil.
137+
type DialerForAddr func(raddr ma.Multiaddr) (ContextDialer, error)
138+
120139
// TcpTransport is the TCP transport.
121140
type TcpTransport struct {
122141
// Connection upgrader for upgrading insecure stream connections to
123142
// secure multiplex connections.
124143
upgrader transport.Upgrader
125144

145+
// optional custom dialer to use for dialing. If set, it will be the *ONLY* dialer
146+
// used. The transport will not attempt to reuse the listen port to
147+
// dial or the shared TCP transport for dialing.
148+
overrideDialerForAddr DialerForAddr
149+
126150
disableReuseport bool // Explicitly disable reuseport.
127151
enableMetrics bool
128152

@@ -170,6 +194,35 @@ func (t *TcpTransport) CanDial(addr ma.Multiaddr) bool {
170194
return dialMatcher.Matches(addr)
171195
}
172196

197+
func (t *TcpTransport) customDial(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) {
198+
// get the net.Dial friendly arguments from the remote addr
199+
rnet, rnaddr, err := manet.DialArgs(raddr)
200+
if err != nil {
201+
return nil, err
202+
}
203+
dialer, err := t.overrideDialerForAddr(raddr)
204+
if err != nil {
205+
return nil, err
206+
}
207+
if dialer == nil {
208+
return nil, fmt.Errorf("dialer for address %s is nil", raddr)
209+
}
210+
211+
// ok, Dial!
212+
var nconn net.Conn
213+
switch rnet {
214+
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6", "unix":
215+
nconn, err = dialer.DialContext(ctx, rnet, rnaddr)
216+
if err != nil {
217+
return nil, err
218+
}
219+
default:
220+
return nil, fmt.Errorf("unrecognized network: %s", rnet)
221+
}
222+
223+
return manet.WrapNetConn(nconn)
224+
}
225+
173226
func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) {
174227
// Apply the deadline iff applicable
175228
if t.connectTimeout > 0 {
@@ -178,6 +231,10 @@ func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Co
178231
defer cancel()
179232
}
180233

234+
if t.overrideDialerForAddr != nil {
235+
return t.customDial(ctx, raddr)
236+
}
237+
181238
if t.sharedTcp != nil {
182239
return t.sharedTcp.DialContext(ctx, raddr)
183240
}

p2p/transport/tcp/tcp_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package tcp
33
import (
44
"context"
55
"errors"
6+
"net"
67
"testing"
78

89
"github.com/libp2p/go-libp2p/core/crypto"
@@ -205,3 +206,75 @@ func makeInsecureMuxer(t *testing.T) (peer.ID, []sec.SecureTransport) {
205206
require.NoError(t, err)
206207
return id, []sec.SecureTransport{insecure.NewWithIdentity(insecure.ID, id, priv)}
207208
}
209+
210+
type errDialer struct {
211+
err error
212+
}
213+
214+
func (d errDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
215+
return nil, d.err
216+
}
217+
218+
func TestCustomOverrideTCPDialer(t *testing.T) {
219+
t.Run("success", func(t *testing.T) {
220+
peerA, ia := makeInsecureMuxer(t)
221+
ua, err := tptu.New(ia, muxers, nil, nil, nil)
222+
require.NoError(t, err)
223+
ta, err := NewTCPTransport(ua, nil, nil)
224+
require.NoError(t, err)
225+
ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))
226+
require.NoError(t, err)
227+
defer ln.Close()
228+
229+
_, ib := makeInsecureMuxer(t)
230+
ub, err := tptu.New(ib, muxers, nil, nil, nil)
231+
require.NoError(t, err)
232+
called := false
233+
customDialer := func(raddr ma.Multiaddr) (ContextDialer, error) {
234+
called = true
235+
return &net.Dialer{}, nil
236+
}
237+
tb, err := NewTCPTransport(ub, nil, nil, WithDialerForAddr(customDialer))
238+
require.NoError(t, err)
239+
240+
conn, err := tb.Dial(context.Background(), ln.Multiaddr(), peerA)
241+
require.NoError(t, err)
242+
require.NotNil(t, conn)
243+
require.True(t, called, "custom dialer should have been called")
244+
conn.Close()
245+
})
246+
247+
t.Run("errors", func(t *testing.T) {
248+
peerA, ia := makeInsecureMuxer(t)
249+
ua, err := tptu.New(ia, muxers, nil, nil, nil)
250+
require.NoError(t, err)
251+
ta, err := NewTCPTransport(ua, nil, nil)
252+
require.NoError(t, err)
253+
ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))
254+
require.NoError(t, err)
255+
defer ln.Close()
256+
257+
for _, test := range []string{"error in factory", "error in custom dialer"} {
258+
t.Run(test, func(t *testing.T) {
259+
_, ib := makeInsecureMuxer(t)
260+
ub, err := tptu.New(ib, muxers, nil, nil, nil)
261+
require.NoError(t, err)
262+
customErr := errors.New("custom dialer error")
263+
customDialer := func(raddr ma.Multiaddr) (ContextDialer, error) {
264+
if test == "error in factory" {
265+
return nil, customErr
266+
} else {
267+
return errDialer{err: customErr}, nil
268+
}
269+
}
270+
tb, err := NewTCPTransport(ub, nil, nil, WithDialerForAddr(customDialer))
271+
require.NoError(t, err)
272+
273+
conn, err := tb.Dial(context.Background(), ln.Multiaddr(), peerA)
274+
require.Error(t, err)
275+
require.ErrorContains(t, err, customErr.Error())
276+
require.Nil(t, conn)
277+
})
278+
}
279+
})
280+
}

0 commit comments

Comments
 (0)