Skip to content

Commit 4a7e228

Browse files
committed
dnsx/mdns: propogate link changes to existing mdns
1 parent d1dbfe7 commit 4a7e228

File tree

4 files changed

+57
-22
lines changed

4 files changed

+57
-22
lines changed

intra/dns.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ func newDNSCryptTransport(ctx context.Context, px ipn.ProxyProvider, bdg Bridge)
113113
return
114114
}
115115

116-
func newMDNSTransport(ctx context.Context, protos string, px ipn.ProxyProvider) (d dnsx.Transport) {
116+
func newMDNSTransport(ctx context.Context, protos string, px ipn.ProxyProvider) (d dnsx.MDNSTransport) {
117117
return dns53.NewMDNSTransport(ctx, protos, px)
118118
}
119119

intra/dns53/mdns.go

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ type dnssd struct {
4747
dialer ipn.Proxy
4848
id string // ID of this transport
4949
ipport string // IP:Port queries are sent to (v4)
50-
use4 bool // Use IPv4
51-
use6 bool // Use IPv6
50+
use4 atomic.Bool // Use IPv4
51+
use6 atomic.Bool // Use IPv6
5252
status *core.Volatile[int] // Status of this transport
5353
est core.P2QuantileEstimator
5454
}
5555

56-
var _ dnsx.Transport = (*dnssd)(nil)
56+
var _ dnsx.MDNSTransport = (*dnssd)(nil)
5757

5858
// NewMDNSTransport returns a DNS transport that sends all DNS queries to mDNS endpoint.
5959
func NewMDNSTransport(pctx context.Context, protos string, pxr ipn.ProxyProvider) *dnssd {
@@ -72,12 +72,12 @@ func NewMDNSTransport(pctx context.Context, protos string, pxr ipn.ProxyProvider
7272
done: done,
7373
id: dnsx.Local,
7474
dialer: exit,
75-
use4: use4(protos),
76-
use6: use6(protos),
7775
ipport: xdns.MDNSAddr4.String(), // ip6: ff02::fb:5353
7876
status: core.NewVolatile(dnsx.Start),
7977
est: core.NewP50Estimator(ctx),
8078
}
79+
t.use4.Store(use4(protos))
80+
t.use6.Store(use6(protos))
8181
log.I("mdns: setup: %s", protos)
8282
return t
8383
}
@@ -100,6 +100,14 @@ func use6(l3 string) bool {
100100
}
101101
}
102102

103+
func (t *dnssd) RefreshProto(protos string) {
104+
n4 := use4(protos)
105+
n6 := use6(protos)
106+
o4 := t.use4.Swap(n4)
107+
o6 := t.use6.Swap(n6)
108+
log.I("mdns: proto change: %s; 4(%s => %s) 6(%s => %s)", protos, o4, n4, o6, n6)
109+
}
110+
103111
func (t *dnssd) oneshotQuery(msg *dns.Msg) (*dns.Msg, *dnsx.QueryError) {
104112
if qerr := dnsx.WillErr(t); qerr != nil {
105113
return nil, qerr
@@ -291,15 +299,17 @@ func (c *client) String() string {
291299

292300
// newClient creates a new mdns unicast and multicast client
293301
func (t *dnssd) newClient(oneshot bool) (*client, error) {
294-
if !t.use4 && !t.use6 {
302+
use4 := t.use4.Load()
303+
use6 := t.use6.Load()
304+
if !use4 && !use6 {
295305
return nil, errNoProtos
296306
}
297307

298308
var uconn4, uconn6 net.PacketConn // bind to higher port for unicast
299309
var mconn4, mconn6 net.PacketConn // bind to port 5353 for multicast
300310
var err error
301311

302-
if t.use4 {
312+
if use4 {
303313
uconn4, err = t.dialer.Announce("udp4", "0.0.0.0:0")
304314
if err != nil {
305315
log.E("mdns: new-client: unicast4 bind fail: %v", err)
@@ -314,7 +324,7 @@ func (t *dnssd) newClient(oneshot bool) (*client, error) {
314324
}
315325
}
316326

317-
if t.use6 {
327+
if use6 {
318328
uconn6, err = t.dialer.Announce("udp6", "[::]:0")
319329
if err != nil {
320330
log.E("mdns: new-client: unicast6 bind fail: %v", err)
@@ -328,16 +338,16 @@ func (t *dnssd) newClient(oneshot bool) (*client, error) {
328338
}
329339
}
330340

331-
has4 := t.use4 && uconn4 != nil && (oneshot || mconn4 != nil)
332-
has6 := t.use6 && uconn6 != nil && (oneshot || mconn6 != nil)
341+
has4 := use4 && uconn4 != nil && (oneshot || mconn4 != nil)
342+
has6 := use6 && uconn6 != nil && (oneshot || mconn6 != nil)
333343
if !has4 && !has6 {
334344
log.E("mdns: new-client: oneshot? %t with no4? %t / no6? %t", oneshot, has4, has6)
335345
return nil, errBindFail
336346
}
337347

338348
c := &client{
339-
use4: t.use4,
340-
use6: t.use6,
349+
use4: use4,
350+
use6: use6,
341351
multicast4: mconn4, // nil if oneshot
342352
multicast6: mconn6, // nil if oneshot
343353
unicast4: uconn4,

intra/dnsx/transport.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,16 @@ var (
107107
errBlockFreeTransport = errors.New("dns: block free transport")
108108
errNoRdns = errors.New("dns: no rdns")
109109
errTransportNotMult = errors.New("dns: not a multi-transport")
110+
errTransportNotMDNS = errors.New("dns: not an mdns transport")
110111
errMissingQueryName = errors.New("dns: no query name")
111112
errResolverClosed = errors.New("dns: closed for business")
112113
)
113114

115+
type MDNSTransport interface {
116+
Transport
117+
RefreshProto(protos string)
118+
}
119+
114120
// Transport represents a DNS query transport. This interface is exported by gobind,
115121
// so it has to be very simple.
116122
type Transport interface {
@@ -150,6 +156,8 @@ type TransportProviderInternal interface {
150156

151157
// Gateway implements a DNS ALG transport
152158
Gateway() Gateway
159+
// MDNS returns the mdns transport, if available; error otherwise.
160+
MDNS() (MDNSTransport, error)
153161
}
154162

155163
type Resolver interface {
@@ -260,6 +268,18 @@ func (r *resolver) Gateway() Gateway {
260268
return r.gateway
261269
}
262270

271+
func (r *resolver) MDNS() (MDNSTransport, error) {
272+
r.RLock()
273+
defer r.RUnlock()
274+
if t, ok := r.transports[Local]; ok {
275+
if mdnst, ok := t.(MDNSTransport); ok {
276+
return mdnst, nil
277+
}
278+
return nil, errTransportNotMDNS
279+
}
280+
return nil, errNoSuchTransport
281+
}
282+
263283
func (r *resolver) Translate(b bool) {
264284
r.gateway.translate(b)
265285
}

intra/tunnel.go

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -305,15 +305,18 @@ func (t *rtunnel) SetLinkAndRoutes2(fd, tunmtu, linkmtu, engine int) error {
305305

306306
err := tunnel.SetLinkAndRoutes(fd, tunmtu, engine) // route is always dual-stack
307307

308+
if l3diff {
309+
if mdns, err := t.resolver.MDNS(); err == nil {
310+
mdns.RefreshProto(l3)
311+
}
312+
}
313+
308314
// TODO: skip refresh on err?
309315
core.Gx("i.setLinkAndRoutesRefresh", func() {
310316
if l3diff || mtudiff {
311-
// dialers.IPProtos must always preced calls to other refreshes
317+
// dialers.IPProtos must always precede calls to other refreshes
312318
// as it carries the global state for dialers and ipn/multihost
313-
go t.proxies.RefreshProto(l3, linkmtu, false /*force*/)
314-
}
315-
if l3diff {
316-
t.resolver.Add(newMDNSTransport(t.ctx, l3, t.proxies))
319+
t.proxies.RefreshProto(l3, linkmtu, false /*force*/)
317320
}
318321
})
319322

@@ -365,12 +368,14 @@ func (t *rtunnel) Restart(fd, linkmtu, tunmtu, engine int) error {
365368

366369
log.I("tun: <<< restart >>>; for: %d (linkmtu: %d / tunmtu: %d), netstack ok; rev err? %v", fd, linkmtu, tunmtu, rerr)
367370

371+
if l3diff {
372+
if mdns, err := t.resolver.MDNS(); err == nil {
373+
mdns.RefreshProto(l3)
374+
}
375+
}
368376
core.Gx("i.RestartRefresh", func() {
369377
// Refresh proxies to update to the new reverser
370-
go t.proxies.RefreshProto(l3, linkmtu, true /*force; reverser changed*/) // also updates reverser
371-
if l3diff {
372-
t.resolver.Add(newMDNSTransport(t.ctx, l3, t.proxies))
373-
}
378+
t.proxies.RefreshProto(l3, linkmtu, true /*force; reverser changed*/) // also updates reverser
374379
})
375380

376381
return err

0 commit comments

Comments
 (0)