Skip to content

Commit e1f0ad7

Browse files
authored
net/udprelay: implement Server.SetStaticAddrPorts (tailscale#17909)
Only used in tests for now. Updates tailscale/corp#31489 Signed-off-by: Jordan Whited <[email protected]>
1 parent a96ef43 commit e1f0ad7

File tree

4 files changed

+64
-72
lines changed

4 files changed

+64
-72
lines changed

feature/relayserver/relayserver.go

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,10 @@ package relayserver
88
import (
99
"encoding/json"
1010
"fmt"
11-
"log"
1211
"net/http"
13-
"net/netip"
14-
"strings"
1512
"sync"
1613

1714
"tailscale.com/disco"
18-
"tailscale.com/envknob"
1915
"tailscale.com/feature"
2016
"tailscale.com/ipn"
2117
"tailscale.com/ipn/ipnext"
@@ -71,8 +67,8 @@ func servePeerRelayDebugSessions(h *localapi.Handler, w http.ResponseWriter, r *
7167
// imported.
7268
func newExtension(logf logger.Logf, sb ipnext.SafeBackend) (ipnext.Extension, error) {
7369
e := &extension{
74-
newServerFn: func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) {
75-
return udprelay.NewServer(logf, port, overrideAddrs)
70+
newServerFn: func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) {
71+
return udprelay.NewServer(logf, port, onlyStaticAddrPorts)
7672
},
7773
logf: logger.WithPrefix(logf, featureName+": "),
7874
}
@@ -94,7 +90,7 @@ type relayServer interface {
9490
// extension is an [ipnext.Extension] managing the relay server on platforms
9591
// that import this package.
9692
type extension struct {
97-
newServerFn func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) // swappable for tests
93+
newServerFn func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) // swappable for tests
9894
logf logger.Logf
9995
ec *eventbus.Client
10096
respPub *eventbus.Publisher[magicsock.UDPRelayAllocResp]
@@ -170,7 +166,7 @@ func (e *extension) onAllocReq(req magicsock.UDPRelayAllocReq) {
170166
}
171167

172168
func (e *extension) tryStartRelayServerLocked() {
173-
rs, err := e.newServerFn(e.logf, *e.port, overrideAddrs())
169+
rs, err := e.newServerFn(e.logf, *e.port, false)
174170
if err != nil {
175171
e.logf("error initializing server: %v", err)
176172
return
@@ -217,26 +213,6 @@ func (e *extension) profileStateChanged(_ ipn.LoginProfileView, prefs ipn.PrefsV
217213
e.handleRelayServerLifetimeLocked()
218214
}
219215

220-
// overrideAddrs returns TS_DEBUG_RELAY_SERVER_ADDRS as []netip.Addr, if set. It
221-
// can be between 0 and 3 comma-separated Addrs. TS_DEBUG_RELAY_SERVER_ADDRS is
222-
// not a stable interface, and is subject to change.
223-
var overrideAddrs = sync.OnceValue(func() (ret []netip.Addr) {
224-
all := envknob.String("TS_DEBUG_RELAY_SERVER_ADDRS")
225-
const max = 3
226-
remain := all
227-
for remain != "" && len(ret) < max {
228-
var s string
229-
s, remain, _ = strings.Cut(remain, ",")
230-
addr, err := netip.ParseAddr(s)
231-
if err != nil {
232-
log.Printf("ignoring invalid Addr %q in TS_DEBUG_RELAY_SERVER_ADDRS %q: %v", s, all, err)
233-
continue
234-
}
235-
ret = append(ret, addr)
236-
}
237-
return
238-
})
239-
240216
func (e *extension) stopRelayServerLocked() {
241217
if e.rs != nil {
242218
e.rs.Close()

feature/relayserver/relayserver_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ package relayserver
55

66
import (
77
"errors"
8-
"net/netip"
98
"reflect"
109
"testing"
1110

@@ -157,7 +156,7 @@ func Test_extension_profileStateChanged(t *testing.T) {
157156
t.Fatal(err)
158157
}
159158
e := ipne.(*extension)
160-
e.newServerFn = func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) {
159+
e.newServerFn = func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) {
161160
return &mockRelayServer{}, nil
162161
}
163162
e.port = tt.fields.port
@@ -289,7 +288,7 @@ func Test_extension_handleRelayServerLifetimeLocked(t *testing.T) {
289288
t.Fatal(err)
290289
}
291290
e := ipne.(*extension)
292-
e.newServerFn = func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) {
291+
e.newServerFn = func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) {
293292
return &mockRelayServer{}, nil
294293
}
295294
e.shutdown = tt.shutdown

net/udprelay/server.go

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import (
3636
"tailscale.com/types/key"
3737
"tailscale.com/types/logger"
3838
"tailscale.com/types/nettype"
39+
"tailscale.com/types/views"
3940
"tailscale.com/util/eventbus"
4041
"tailscale.com/util/set"
4142
)
@@ -72,15 +73,16 @@ type Server struct {
7273
closeCh chan struct{}
7374
netChecker *netcheck.Client
7475

75-
mu sync.Mutex // guards the following fields
76-
derpMap *tailcfg.DERPMap
77-
addrDiscoveryOnce bool // addrDiscovery completed once (successfully or unsuccessfully)
78-
addrPorts []netip.AddrPort // the ip:port pairs returned as candidate endpoints
79-
closed bool
80-
lamportID uint64
81-
nextVNI uint32
82-
byVNI map[uint32]*serverEndpoint
83-
byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint
76+
mu sync.Mutex // guards the following fields
77+
derpMap *tailcfg.DERPMap
78+
onlyStaticAddrPorts bool // no dynamic addr port discovery when set
79+
staticAddrPorts views.Slice[netip.AddrPort] // static ip:port pairs set with [Server.SetStaticAddrPorts]
80+
dynamicAddrPorts []netip.AddrPort // dynamically discovered ip:port pairs
81+
closed bool
82+
lamportID uint64
83+
nextVNI uint32
84+
byVNI map[uint32]*serverEndpoint
85+
byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint
8486
}
8587

8688
const (
@@ -278,15 +280,17 @@ func (e *serverEndpoint) isBound() bool {
278280

279281
// NewServer constructs a [Server] listening on port. If port is zero, then
280282
// port selection is left up to the host networking stack. If
281-
// len(overrideAddrs) > 0 these will be used in place of dynamic discovery,
282-
// which is useful to override in tests.
283-
func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, err error) {
283+
// onlyStaticAddrPorts is true, then dynamic addr:port discovery will be
284+
// disabled, and only addr:port's set via [Server.SetStaticAddrPorts] will be
285+
// used.
286+
func NewServer(logf logger.Logf, port int, onlyStaticAddrPorts bool) (s *Server, err error) {
284287
s = &Server{
285288
logf: logf,
286289
disco: key.NewDisco(),
287290
bindLifetime: defaultBindLifetime,
288291
steadyStateLifetime: defaultSteadyStateLifetime,
289292
closeCh: make(chan struct{}),
293+
onlyStaticAddrPorts: onlyStaticAddrPorts,
290294
byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint),
291295
nextVNI: minVNI,
292296
byVNI: make(map[uint32]*serverEndpoint),
@@ -321,19 +325,7 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve
321325
return nil, err
322326
}
323327

324-
if len(overrideAddrs) > 0 {
325-
addrPorts := make(set.Set[netip.AddrPort], len(overrideAddrs))
326-
for _, addr := range overrideAddrs {
327-
if addr.IsValid() {
328-
if addr.Is4() {
329-
addrPorts.Add(netip.AddrPortFrom(addr, s.uc4Port))
330-
} else if s.uc6 != nil {
331-
addrPorts.Add(netip.AddrPortFrom(addr, s.uc6Port))
332-
}
333-
}
334-
}
335-
s.addrPorts = addrPorts.Slice()
336-
} else {
328+
if !s.onlyStaticAddrPorts {
337329
s.wg.Add(1)
338330
go s.addrDiscoveryLoop()
339331
}
@@ -429,8 +421,7 @@ func (s *Server) addrDiscoveryLoop() {
429421
s.logf("error discovering IP:port candidates: %v", err)
430422
}
431423
s.mu.Lock()
432-
s.addrPorts = addrPorts
433-
s.addrDiscoveryOnce = true
424+
s.dynamicAddrPorts = addrPorts
434425
s.mu.Unlock()
435426
case <-s.closeCh:
436427
return
@@ -747,6 +738,15 @@ func (s *Server) getNextVNILocked() (uint32, error) {
747738
return 0, errors.New("VNI pool exhausted")
748739
}
749740

741+
// getAllAddrPortsCopyLocked returns a copy of the combined
742+
// [Server.staticAddrPorts] and [Server.dynamicAddrPorts] slices.
743+
func (s *Server) getAllAddrPortsCopyLocked() []netip.AddrPort {
744+
addrPorts := make([]netip.AddrPort, 0, len(s.dynamicAddrPorts)+s.staticAddrPorts.Len())
745+
addrPorts = append(addrPorts, s.staticAddrPorts.AsSlice()...)
746+
addrPorts = append(addrPorts, slices.Clone(s.dynamicAddrPorts)...)
747+
return addrPorts
748+
}
749+
750750
// AllocateEndpoint allocates an [endpoint.ServerEndpoint] for the provided pair
751751
// of [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB
752752
// it is returned without modification/reallocation. AllocateEndpoint returns
@@ -760,11 +760,8 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
760760
return endpoint.ServerEndpoint{}, ErrServerClosed
761761
}
762762

763-
if len(s.addrPorts) == 0 {
764-
if !s.addrDiscoveryOnce {
765-
return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: endpoint.ServerRetryAfter}
766-
}
767-
return endpoint.ServerEndpoint{}, errors.New("server addrPorts are not yet known")
763+
if s.staticAddrPorts.Len() == 0 && len(s.dynamicAddrPorts) == 0 {
764+
return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: endpoint.ServerRetryAfter}
768765
}
769766

770767
if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 {
@@ -787,7 +784,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
787784
// consider storing them (maybe interning) in the [*serverEndpoint]
788785
// at allocation time.
789786
ClientDisco: pair.Get(),
790-
AddrPorts: slices.Clone(s.addrPorts),
787+
AddrPorts: s.getAllAddrPortsCopyLocked(),
791788
VNI: e.vni,
792789
LamportID: e.lamportID,
793790
BindLifetime: tstime.GoDuration{Duration: s.bindLifetime},
@@ -817,7 +814,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
817814
return endpoint.ServerEndpoint{
818815
ServerDisco: s.discoPublic,
819816
ClientDisco: pair.Get(),
820-
AddrPorts: slices.Clone(s.addrPorts),
817+
AddrPorts: s.getAllAddrPortsCopyLocked(),
821818
VNI: e.vni,
822819
LamportID: e.lamportID,
823820
BindLifetime: tstime.GoDuration{Duration: s.bindLifetime},
@@ -880,3 +877,13 @@ func (s *Server) getDERPMap() *tailcfg.DERPMap {
880877
defer s.mu.Unlock()
881878
return s.derpMap
882879
}
880+
881+
// SetStaticAddrPorts sets addr:port pairs the [Server] will advertise
882+
// as candidates it is potentially reachable over, in combination with
883+
// dynamically discovered pairs. This replaces any previously-provided static
884+
// values.
885+
func (s *Server) SetStaticAddrPorts(addrPorts views.Slice[netip.AddrPort]) {
886+
s.mu.Lock()
887+
defer s.mu.Unlock()
888+
s.staticAddrPorts = addrPorts
889+
}

net/udprelay/server_test.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"tailscale.com/disco"
1818
"tailscale.com/net/packet"
1919
"tailscale.com/types/key"
20+
"tailscale.com/types/views"
2021
)
2122

2223
type testClient struct {
@@ -185,31 +186,40 @@ func TestServer(t *testing.T) {
185186

186187
cases := []struct {
187188
name string
188-
overrideAddrs []netip.Addr
189+
staticAddrs []netip.Addr
189190
forceClientsMixedAF bool
190191
}{
191192
{
192-
name: "over ipv4",
193-
overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
193+
name: "over ipv4",
194+
staticAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
194195
},
195196
{
196-
name: "over ipv6",
197-
overrideAddrs: []netip.Addr{netip.MustParseAddr("::1")},
197+
name: "over ipv6",
198+
staticAddrs: []netip.Addr{netip.MustParseAddr("::1")},
198199
},
199200
{
200201
name: "mixed address families",
201-
overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("::1")},
202+
staticAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("::1")},
202203
forceClientsMixedAF: true,
203204
},
204205
}
205206

206207
for _, tt := range cases {
207208
t.Run(tt.name, func(t *testing.T) {
208-
server, err := NewServer(t.Logf, 0, tt.overrideAddrs)
209+
server, err := NewServer(t.Logf, 0, true)
209210
if err != nil {
210211
t.Fatal(err)
211212
}
212213
defer server.Close()
214+
addrPorts := make([]netip.AddrPort, 0, len(tt.staticAddrs))
215+
for _, addr := range tt.staticAddrs {
216+
if addr.Is4() {
217+
addrPorts = append(addrPorts, netip.AddrPortFrom(addr, server.uc4Port))
218+
} else if server.uc6Port != 0 {
219+
addrPorts = append(addrPorts, netip.AddrPortFrom(addr, server.uc6Port))
220+
}
221+
}
222+
server.SetStaticAddrPorts(views.SliceOf(addrPorts))
213223

214224
endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
215225
if err != nil {

0 commit comments

Comments
 (0)