Skip to content

Commit 5fcce02

Browse files
committed
add tests for present behavior
1 parent fd76100 commit 5fcce02

File tree

3 files changed

+251
-78
lines changed

3 files changed

+251
-78
lines changed

p2p/host/basic/address_service.go

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99

1010
"github.com/libp2p/go-libp2p/core/event"
1111
"github.com/libp2p/go-libp2p/core/network"
12-
"github.com/libp2p/go-libp2p/core/record"
1312
"github.com/libp2p/go-libp2p/core/transport"
1413
"github.com/libp2p/go-libp2p/p2p/host/basic/internal/backoff"
1514
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
@@ -19,8 +18,6 @@ import (
1918
manet "github.com/multiformats/go-multiaddr/net"
2019
)
2120

22-
type peerRecordFunc func([]ma.Multiaddr) (*record.Envelope, error)
23-
2421
type observedAddrsService interface {
2522
OwnObservedAddrs() []ma.Multiaddr
2623
ObservedAddrsFor(local ma.Multiaddr) []ma.Multiaddr
@@ -34,12 +31,13 @@ type addressService struct {
3431
addrsChangeChan chan struct{}
3532
addrsUpdated chan struct{}
3633
autoRelayAddrsSub event.Subscription
37-
autoRelayAddrs func() []ma.Multiaddr
38-
reachability func() network.Reachability
39-
ifaceAddrs *interfaceAddrsCache
40-
wg sync.WaitGroup
41-
ctx context.Context
42-
ctxCancel context.CancelFunc
34+
// There are wrapped in to functions for mocking
35+
autoRelayAddrs func() []ma.Multiaddr
36+
reachability func() network.Reachability
37+
ifaceAddrs *interfaceAddrsCache
38+
wg sync.WaitGroup
39+
ctx context.Context
40+
ctxCancel context.CancelFunc
4341
}
4442

4543
func NewAddressService(h *BasicHost, natmgr func(network.Network) NATManager,
@@ -177,19 +175,22 @@ func (a *addressService) AllAddrs() []ma.Multiaddr {
177175

178176
finalAddrs := make([]ma.Multiaddr, 0, 8)
179177
finalAddrs = a.appendInterfaceAddrs(finalAddrs, listenAddrs)
180-
181-
// use nat mappings if we have them
182178
finalAddrs = a.appendNATAddrs(finalAddrs, listenAddrs)
183179
finalAddrs = ma.Unique(finalAddrs)
184180

185-
// Remove /p2p-circuit addresses from the list.
186-
// The p2p-circuit transport listener reports its address as just /p2p-circuit
187-
// This is useless for dialing. Users need to manage their circuit addresses themselves,
181+
// Remove "/p2p-circuit" addresses from the list.
182+
// The p2p-circuit listener reports its address as just /p2p-circuit. This is
183+
// useless for dialing. Users need to manage their circuit addresses themselves,
188184
// or use AutoRelay.
189185
finalAddrs = slices.DeleteFunc(finalAddrs, func(a ma.Multiaddr) bool {
190186
return a.Equal(p2pCircuitAddr)
191187
})
192188

189+
// Remove any unspecified address from the list
190+
finalAddrs = slices.DeleteFunc(finalAddrs, func(a ma.Multiaddr) bool {
191+
return manet.IsIPUnspecified(a)
192+
})
193+
193194
// Add certhashes for /webrtc-direct, /webtransport, etc addresses discovered
194195
// using identify.
195196
finalAddrs = a.addCertHashes(finalAddrs)
@@ -208,19 +209,23 @@ func (a *addressService) appendInterfaceAddrs(result []ma.Multiaddr, listenAddrs
208209
return result
209210
}
210211

212+
// appendNATAddrs appends the NAT-ed addrs for the listenAddrs. For unspecified listen addrs it appends the
213+
// public address for all the interfaces.
214+
// This automatically infers addresses from other transport addresses. For example, it'll infer a webtransport
215+
// address from a quic observed address.
216+
//
217+
// TODO: Merge the natmgr and identify.ObservedAddrManager in to one NatMapper module.
211218
func (a *addressService) appendNATAddrs(result []ma.Multiaddr, listenAddrs []ma.Multiaddr) []ma.Multiaddr {
212219
ifaceAddrs := a.ifaceAddrs.All()
213-
// use nat mappings if we have them
214-
if a.natmgr != nil && a.natmgr.HasDiscoveredNAT() {
215-
// we have a NAT device
216-
for _, listen := range listenAddrs {
217-
extMaddr := a.natmgr.GetMapping(listen)
218-
result = appendNATAddrsForListenAddrs(result, listen, extMaddr, a.observedAddrsService.ObservedAddrsFor, ifaceAddrs)
219-
}
220-
} else {
220+
if a.natmgr == nil || !a.natmgr.HasDiscoveredNAT() {
221221
if a.observedAddrsService != nil {
222222
result = append(result, a.observedAddrsService.OwnObservedAddrs()...)
223223
}
224+
return result
225+
}
226+
for _, listen := range listenAddrs {
227+
extMaddr := a.natmgr.GetMapping(listen)
228+
result = appendNATAddrsForListenAddrs(result, listen, extMaddr, a.observedAddrsService.ObservedAddrsFor, ifaceAddrs)
224229
}
225230
return result
226231
}
@@ -241,11 +246,6 @@ func (a *addressService) addCertHashes(addrs []ma.Multiaddr) []ma.Multiaddr {
241246
return addrs
242247
}
243248

244-
// Copy addrs slice since we'll be modifying it.
245-
addrsOld := addrs
246-
addrs = make([]ma.Multiaddr, len(addrsOld))
247-
copy(addrs, addrsOld)
248-
249249
for i, addr := range addrs {
250250
wtOK, wtN := libp2pwebtransport.IsWebtransportMultiaddr(addr)
251251
webrtcOK, webrtcN := libp2pwebrtc.IsWebRTCDirectMultiaddr(addr)
@@ -411,6 +411,8 @@ func (i *interfaceAddrsCache) updateUnlocked() {
411411
}
412412
}
413413

414+
// getAllPossibleLocalAddrs gives all the possible address returned for `conn.LocalAddr` correspoinding
415+
// to the `listenAddr`
414416
func getAllPossibleLocalAddrs(listenAddr ma.Multiaddr, ifaceAddrs []ma.Multiaddr) []ma.Multiaddr {
415417
// If the nat mapping fails, use the observed addrs
416418
resolved, err := manet.ResolveUnspecifiedAddress(listenAddr, ifaceAddrs)

p2p/host/basic/address_service_test.go

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ package basichost
22

33
import (
44
"testing"
5+
"time"
56

7+
"github.com/libp2p/go-libp2p/core/network"
8+
swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing"
69
ma "github.com/multiformats/go-multiaddr"
710
manet "github.com/multiformats/go-multiaddr/net"
811
"github.com/stretchr/testify/require"
@@ -96,3 +99,174 @@ func TestAppendNATAddrs(t *testing.T) {
9699
})
97100
}
98101
}
102+
103+
type mockNatManager struct {
104+
GetMappingFunc func(addr ma.Multiaddr) ma.Multiaddr
105+
HasDiscoveredNATFunc func() bool
106+
}
107+
108+
func (m *mockNatManager) Close() error {
109+
return nil
110+
}
111+
112+
func (m *mockNatManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr {
113+
return m.GetMappingFunc(addr)
114+
}
115+
116+
func (m *mockNatManager) HasDiscoveredNAT() bool {
117+
return m.HasDiscoveredNATFunc()
118+
}
119+
120+
var _ NATManager = &mockNatManager{}
121+
122+
type mockObservedAddrs struct {
123+
OwnObservedAddrsFunc func() []ma.Multiaddr
124+
ObservedAddrsForFunc func(ma.Multiaddr) []ma.Multiaddr
125+
}
126+
127+
func (m *mockObservedAddrs) OwnObservedAddrs() []ma.Multiaddr {
128+
return m.OwnObservedAddrsFunc()
129+
}
130+
131+
func (m *mockObservedAddrs) ObservedAddrsFor(local ma.Multiaddr) []ma.Multiaddr {
132+
return m.ObservedAddrsForFunc(local)
133+
}
134+
135+
func TestAddressService(t *testing.T) {
136+
getAddrService := func() *addressService {
137+
h, err := NewHost(swarmt.GenSwarm(t), &HostOpts{DisableIdentifyAddressDiscovery: true})
138+
require.NoError(t, err)
139+
t.Cleanup(func() { h.Close() })
140+
141+
as := h.addressService
142+
return as
143+
}
144+
145+
t.Run("NAT Address", func(t *testing.T) {
146+
as := getAddrService()
147+
as.natmgr = &mockNatManager{
148+
HasDiscoveredNATFunc: func() bool { return true },
149+
GetMappingFunc: func(addr ma.Multiaddr) ma.Multiaddr {
150+
if _, err := addr.ValueForProtocol(ma.P_UDP); err == nil {
151+
return ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1")
152+
}
153+
return nil
154+
},
155+
}
156+
require.Contains(t, as.Addrs(), ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1"))
157+
})
158+
159+
t.Run("NAT And Observed Address", func(t *testing.T) {
160+
as := getAddrService()
161+
as.natmgr = &mockNatManager{
162+
HasDiscoveredNATFunc: func() bool { return true },
163+
GetMappingFunc: func(addr ma.Multiaddr) ma.Multiaddr {
164+
if _, err := addr.ValueForProtocol(ma.P_UDP); err == nil {
165+
return ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1")
166+
}
167+
return nil
168+
},
169+
}
170+
as.observedAddrsService = &mockObservedAddrs{
171+
ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr {
172+
if _, err := addr.ValueForProtocol(ma.P_TCP); err == nil {
173+
return []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/tcp/1")}
174+
}
175+
return nil
176+
},
177+
}
178+
require.Contains(t, as.Addrs(), ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1"))
179+
require.Contains(t, as.Addrs(), ma.StringCast("/ip4/2.2.2.2/tcp/1"))
180+
})
181+
t.Run("Only Observed Address", func(t *testing.T) {
182+
as := getAddrService()
183+
as.natmgr = nil
184+
as.observedAddrsService = &mockObservedAddrs{
185+
ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr {
186+
if _, err := addr.ValueForProtocol(ma.P_TCP); err == nil {
187+
return []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/tcp/1")}
188+
}
189+
return nil
190+
},
191+
OwnObservedAddrsFunc: func() []ma.Multiaddr {
192+
return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")}
193+
},
194+
}
195+
require.NotContains(t, as.Addrs(), ma.StringCast("/ip4/2.2.2.2/tcp/1"))
196+
require.Contains(t, as.Addrs(), ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1"))
197+
})
198+
t.Run("Public Addrs Removed When Private", func(t *testing.T) {
199+
as := getAddrService()
200+
as.natmgr = nil
201+
as.observedAddrsService = &mockObservedAddrs{
202+
OwnObservedAddrsFunc: func() []ma.Multiaddr {
203+
return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")}
204+
},
205+
}
206+
as.reachability = func() network.Reachability {
207+
return network.ReachabilityPrivate
208+
}
209+
relayAddr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/p2p/QmdXGaeGiVA745XorV1jr11RHxB9z4fqykm6xCUPX1aTJo/p2p-circuit")
210+
as.autoRelayAddrs = func() []ma.Multiaddr {
211+
return []ma.Multiaddr{relayAddr}
212+
}
213+
require.NotContains(t, as.Addrs(), ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1"))
214+
require.Contains(t, as.Addrs(), relayAddr)
215+
require.Contains(t, as.AllAddrs(), ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1"))
216+
})
217+
218+
t.Run("AddressFactory gets relay addresses", func(t *testing.T) {
219+
as := getAddrService()
220+
as.natmgr = nil
221+
as.observedAddrsService = &mockObservedAddrs{
222+
OwnObservedAddrsFunc: func() []ma.Multiaddr {
223+
return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")}
224+
},
225+
}
226+
as.reachability = func() network.Reachability {
227+
return network.ReachabilityPrivate
228+
}
229+
relayAddr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/p2p/QmdXGaeGiVA745XorV1jr11RHxB9z4fqykm6xCUPX1aTJo/p2p-circuit")
230+
as.autoRelayAddrs = func() []ma.Multiaddr {
231+
return []ma.Multiaddr{relayAddr}
232+
}
233+
as.addrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr {
234+
for _, a := range addrs {
235+
if a.Equal(relayAddr) {
236+
return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")}
237+
}
238+
}
239+
return nil
240+
}
241+
require.Contains(t, as.Addrs(), ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1"))
242+
require.NotContains(t, as.Addrs(), relayAddr)
243+
})
244+
245+
t.Run("updates addresses on signaling", func(t *testing.T) {
246+
as := getAddrService()
247+
as.natmgr = nil
248+
updateChan := make(chan struct{})
249+
a1 := ma.StringCast("/ip4/1.1.1.1/udp/1/quic-v1")
250+
a2 := ma.StringCast("/ip4/1.1.1.1/tcp/1")
251+
as.addrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr {
252+
select {
253+
case <-updateChan:
254+
return []ma.Multiaddr{a2}
255+
default:
256+
return []ma.Multiaddr{a1}
257+
}
258+
}
259+
as.Start()
260+
require.Contains(t, as.Addrs(), a1)
261+
require.NotContains(t, as.Addrs(), a2)
262+
close(updateChan)
263+
as.SignalAddressChange()
264+
select {
265+
case <-as.AddrsUpdated():
266+
require.Contains(t, as.Addrs(), a2)
267+
require.NotContains(t, as.Addrs(), a1)
268+
case <-time.After(2 * time.Second):
269+
t.Fatal("expected addrs to be updated")
270+
}
271+
})
272+
}

0 commit comments

Comments
 (0)