Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ type Config struct {
UserFxOptions []fx.Option

ShareTCPListener bool
// AllowSharedTCPReachability indicates whether shared TCP listeners should be considered
// for reachability detection. When true, addresses using shared TCP listeners will be
// included in reachability checks. Defaults to true when ShareTCPListener is enabled.
AllowSharedTCPReachability bool
}

func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) {
Expand Down Expand Up @@ -747,3 +751,26 @@ func (cfg *Config) Apply(opts ...Option) error {
}
return nil
}

// NewConfig creates a new config with the given options.
func NewConfig(opts ...Option) (*Config, error) {
cfg := &Config{
UserAgent: "github.com/libp2p/go-libp2p",
ProtocolVersion: identify.DefaultProtocolVersion,
QUICReuse: []fx.Option{},
Transports: []fx.Option{},
Muxers: []tptu.StreamMuxer{},
SecurityTransports: []Security{},
RelayCustom: false,
Relay: true,
ListenAddrs: []ma.Multiaddr{},
SwarmOpts: []swarm.Option{},
UserFxOptions: []fx.Option{},
EnableAutoNATv2: true,
AllowSharedTCPReachability: true, // Enable by default when ShareTCPListener is used
}
if err := cfg.Apply(opts...); err != nil {
return nil, err
}
return cfg, nil
}
15 changes: 15 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,21 @@ func WithFxOption(opts ...fx.Option) Option {
func ShareTCPListener() Option {
return func(cfg *Config) error {
cfg.ShareTCPListener = true
// When using shared TCP listener, force reachability to public
// This is because the shared listener setup can interfere with AutoNAT's
// ability to properly detect reachability
public := network.ReachabilityPublic
cfg.AutoNATConfig.ForceReachability = &public
Comment on lines +655 to +659
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to fix the interference not disable autonat.

return nil
}
}

// AllowSharedTCPReachability configures whether shared TCP listeners should be considered
// for reachability detection. When enabled, addresses using shared TCP listeners will be
// included in reachability checks.
func AllowSharedTCPReachability(allow bool) Option {
return func(cfg *Config) error {
cfg.AllowSharedTCPReachability = allow
return nil
}
}
143 changes: 105 additions & 38 deletions p2p/host/autonat/svc.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package autonat
import (
"context"
"errors"
"fmt"
"math/rand"
"sync"
"time"
Expand All @@ -15,6 +16,9 @@ import (
"github.com/libp2p/go-msgio/pbio"

ma "github.com/multiformats/go-multiaddr"
"github.com/libp2p/go-libp2p/p2p/net/manet"
"google.golang.org/protobuf/proto"
"github.com/libp2p/go-msgio/protoio"
)

var streamTimeout = 60 * time.Second
Expand All @@ -23,6 +27,9 @@ const (
ServiceName = "libp2p.autonat"

maxMsgSize = 4096

maxAddresses = 100
dialTimeout = 30 * time.Second
)

// AutoNATService provides NAT autodetection services to other peers
Expand Down Expand Up @@ -50,59 +57,119 @@ func newAutoNATService(c *config) (*autoNATService, error) {
}, nil
}

func (as *autoNATService) handleStream(s network.Stream) {
if err := s.Scope().SetService(ServiceName); err != nil {
log.Debugf("error attaching stream to autonat service: %s", err)
s.Reset()
return
func (s *autoNATService) handleStream(stream network.Stream) {
defer stream.Close()

if err := s.handleStreamRequest(stream); err != nil {
log.Debugf("error handling autonat request: %s", err)
resp := &pb.Message{
Type: pb.Message_DIAL_RESPONSE.Enum(),
Status: pb.Message_E_DIAL_ERROR.Enum(),
Version: proto.Int32(autoNATProtocolVersion),
}
s.writeResponse(stream, resp)
}
}

if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil {
log.Debugf("error reserving memory for autonat stream: %s", err)
s.Reset()
return
func (s *autoNATService) handleStreamRequest(stream network.Stream) error {
rd := protoio.NewDelimitedReader(stream, network.MessageSizeMax)
wr := protoio.NewDelimitedWriter(stream)

req := &pb.Message{}
if err := rd.ReadMsg(req); err != nil {
return err
}
defer s.Scope().ReleaseMemory(maxMsgSize)

s.SetDeadline(time.Now().Add(streamTimeout))
defer s.Close()
if req.GetType() != pb.Message_DIAL {
return fmt.Errorf("unknown message type: %d", req.GetType())
}

pid := s.Conn().RemotePeer()
log.Debugf("New stream from %s", pid)
if req.GetVersion() != autoNATProtocolVersion {
return fmt.Errorf("unknown protocol version: %d", req.GetVersion())
}

r := pbio.NewDelimitedReader(s, maxMsgSize)
w := pbio.NewDelimitedWriter(s)
p := stream.Conn().RemotePeer()
pi := peer.AddrInfo{ID: p}
addrs := req.GetAddresses()
if len(addrs) == 0 {
return fmt.Errorf("no addresses provided")
}

var req pb.Message
var res pb.Message
// Limit the number of addresses to check
if len(addrs) > maxAddresses {
addrs = addrs[:maxAddresses]
}

err := r.ReadMsg(&req)
if err != nil {
log.Debugf("Error reading message from %s: %s", pid, err.Error())
s.Reset()
return
// Convert addresses to multiaddrs
maddrs := make([]ma.Multiaddr, 0, len(addrs))
indices := make([]int32, 0, len(addrs))
for _, a := range addrs {
addr, err := ma.NewMultiaddrBytes(a.GetAddress())
if err != nil {
continue
}

// Check if this is a shared TCP listener address
protos := addr.Protocols()
isSharedTCP := false
if len(protos) > 0 && protos[len(protos)-1].Code == ma.P_TCP {
for _, comp := range addr.Components() {
if comp.Protocol().Code == ma.P_TCP {
isSharedTCP = true
break
}
}
}

// Include the address for dialing if it's public or using a shared TCP listener
if manet.IsPublicAddr(addr) || isSharedTCP {
maddrs = append(maddrs, addr)
indices = append(indices, a.GetIdx())
}
}

t := req.GetType()
if t != pb.Message_DIAL {
log.Debugf("Unexpected message from %s: %s (%d)", pid, t.String(), t)
s.Reset()
return
if len(maddrs) == 0 {
return fmt.Errorf("no public addresses provided")
}

dr := as.handleDial(pid, s.Conn().RemoteMultiaddr(), req.GetDial().GetPeer())
res.Type = pb.Message_DIAL_RESPONSE.Enum()
res.DialResponse = dr
pi.Addrs = maddrs
ctx, cancel := context.WithTimeout(context.Background(), dialTimeout)
defer cancel()

err = w.WriteMsg(&res)
if err != nil {
log.Debugf("Error writing response to %s: %s", pid, err.Error())
s.Reset()
return
// Try to connect to the peer
if err := s.host.Connect(ctx, pi); err != nil {
var status pb.Message_ResponseStatus
if isDialRefused(err) {
status = pb.Message_E_DIAL_REFUSED
} else {
status = pb.Message_E_DIAL_ERROR
}
resp := &pb.Message{
Type: pb.Message_DIAL_RESPONSE.Enum(),
Status: status.Enum(),
Version: proto.Int32(autoNATProtocolVersion),
}
return s.writeResponse(stream, resp)
}
if as.config.metricsTracer != nil {
as.config.metricsTracer.OutgoingDialResponse(res.GetDialResponse().GetStatus())

// Successfully connected
addr := maddrs[0]
idx := indices[0]
resp := &pb.Message{
Type: pb.Message_DIAL_RESPONSE.Enum(),
Status: pb.Message_OK.Enum(),
Version: proto.Int32(autoNATProtocolVersion),
Address: &pb.Message_Address{
Address: addr.Bytes(),
Idx: proto.Int32(idx),
},
}
return s.writeResponse(stream, resp)
}

func (s *autoNATService) writeResponse(stream network.Stream, resp *pb.Message) error {
wr := protoio.NewDelimitedWriter(stream)
return wr.WriteMsg(resp)
}

func (as *autoNATService) handleDial(p peer.ID, obsaddr ma.Multiaddr, mpi *pb.Message_PeerInfo) *pb.Message_DialResponse {
Expand Down
80 changes: 79 additions & 1 deletion p2p/protocol/autonatv2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (ac *client) Close() {

// GetReachability verifies address reachability with a AutoNAT v2 server p.
func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request) (Result, error) {
result, err := ac.getReachability(ctx, p, reqs)
result, err := ac.getReachabilityInternal(ctx, p, reqs)

// Track metrics
if ac.metricsTracer != nil {
Expand All @@ -70,6 +70,84 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request
return result, err
}

func (ac *client) getReachabilityInternal(ctx context.Context, p peer.ID, reqs []Request) (Result, error) {
// Check if we have a connection already
if ac.host.Network().Connectedness(p) != network.Connected {
if err := ac.host.Connect(ctx, peer.AddrInfo{ID: p}); err != nil {
return Result{}, err
}
}

// Create a new stream
s, err := ac.host.NewStream(ctx, p, DialProtocol)
if err != nil {
return Result{}, err
}
defer s.Close()

// Write the dial request
req := pb.Message{
Type: pb.Message_DIAL.Enum(),
Version: proto.Int32(AutoNATProtocolVersion),
}
for _, r := range reqs {
// For TCP addresses with shared listeners, we need to ensure proper handling
protos := r.Addr.Protocols()
isSharedTCP := false
if len(protos) > 0 && protos[len(protos)-1].Code == ma.P_TCP {
// Check if this is a shared TCP listener
for _, comp := range r.Addr.Components() {
if comp.Protocol().Code == ma.P_TCP {
isSharedTCP = true
break
}
}
}

// Include the address in the request
req.Addresses = append(req.Addresses, &pb.Message_Address{
Address: r.Addr.Bytes(),
Idx: proto.Int32(int32(r.Idx)),
})
}

w := protoio.NewDelimitedWriter(s)
if err := w.WriteMsg(&req); err != nil {
return Result{}, err
}

r := protoio.NewDelimitedReader(s, network.MessageSizeMax)
resp := &pb.Message{}
if err := r.ReadMsg(resp); err != nil {
return Result{}, err
}

if resp.GetType() != pb.Message_DIAL_RESPONSE {
return Result{}, fmt.Errorf("unexpected message type: %s", resp.GetType())
}

// Process the response
status := resp.GetStatus()
switch status {
case pb.Message_OK:
addr, err := ma.NewMultiaddrBytes(resp.GetAddress().GetAddress())
if err != nil {
return Result{}, err
}
return Result{
Reachable: true,
Addr: addr,
Idx: int(resp.GetAddress().GetIdx()),
}, nil
case pb.Message_E_DIAL_ERROR:
return Result{}, ErrDialError
case pb.Message_E_DIAL_REFUSED:
return Result{}, ErrDialRefused
default:
return Result{}, fmt.Errorf("unknown status: %d", status)
}
}

func (ac *client) getReachability(ctx context.Context, p peer.ID, reqs []Request) (Result, error) {
ctx, cancel := context.WithTimeout(ctx, streamTimeout)
defer cancel()
Expand Down