Skip to content

Commit f9db021

Browse files
Merge pull request #16 from libp2p/shutdown
implement a clean shutdown of the probe method
2 parents a3c4995 + f91e31c commit f9db021

File tree

2 files changed

+60
-33
lines changed

2 files changed

+60
-33
lines changed

server.go

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package zeroconf
22

33
import (
4-
"errors"
54
"fmt"
65
"log"
76
"math/rand"
@@ -75,8 +74,7 @@ func Register(instance, service, domain string, port int, text []string, ifaces
7574
}
7675

7776
s.service = entry
78-
go s.mainloop()
79-
go s.probe()
77+
s.start()
8078

8179
return s, nil
8280
}
@@ -132,8 +130,7 @@ func RegisterProxy(instance, service, domain string, port int, host string, ips
132130
}
133131

134132
s.service = entry
135-
go s.mainloop()
136-
go s.probe()
133+
s.start()
137134

138135
return s, nil
139136
}
@@ -151,7 +148,7 @@ type Server struct {
151148

152149
shouldShutdown chan struct{}
153150
shutdownLock sync.Mutex
154-
shutdownEnd sync.WaitGroup
151+
refCount sync.WaitGroup
155152
isShutdown bool
156153
ttl uint32
157154
}
@@ -182,19 +179,17 @@ func newServer(ifaces []net.Interface) (*Server, error) {
182179
return s, nil
183180
}
184181

185-
// Start listeners and waits for the shutdown signal from exit channel
186-
func (s *Server) mainloop() {
182+
func (s *Server) start() {
187183
if s.ipv4conn != nil {
184+
s.refCount.Add(1)
188185
go s.recv4(s.ipv4conn)
189186
}
190187
if s.ipv6conn != nil {
188+
s.refCount.Add(1)
191189
go s.recv6(s.ipv6conn)
192190
}
193-
}
194-
195-
// Shutdown closes all udp connections and unregisters the service
196-
func (s *Server) Shutdown() {
197-
s.shutdown()
191+
s.refCount.Add(1)
192+
go s.probe()
198193
}
199194

200195
// SetText updates and announces the TXT records
@@ -208,15 +203,17 @@ func (s *Server) TTL(ttl uint32) {
208203
s.ttl = ttl
209204
}
210205

211-
// Shutdown server will close currently open connections & channel
212-
func (s *Server) shutdown() error {
206+
// Shutdown closes all udp connections and unregisters the service
207+
func (s *Server) Shutdown() {
213208
s.shutdownLock.Lock()
214209
defer s.shutdownLock.Unlock()
215210
if s.isShutdown {
216-
return errors.New("server is already shutdown")
211+
return
217212
}
218213

219-
err := s.unregister()
214+
if err := s.unregister(); err != nil {
215+
log.Printf("failed to unregister: %s", err)
216+
}
220217

221218
close(s.shouldShutdown)
222219

@@ -228,20 +225,17 @@ func (s *Server) shutdown() error {
228225
}
229226

230227
// Wait for connection and routines to be closed
231-
s.shutdownEnd.Wait()
228+
s.refCount.Wait()
232229
s.isShutdown = true
233-
234-
return err
235230
}
236231

237-
// recv is a long running routine to receive packets from an interface
232+
// recv4 is a long running routine to receive packets from an interface
238233
func (s *Server) recv4(c *ipv4.PacketConn) {
234+
defer s.refCount.Done()
239235
if c == nil {
240236
return
241237
}
242238
buf := make([]byte, 65536)
243-
s.shutdownEnd.Add(1)
244-
defer s.shutdownEnd.Done()
245239
for {
246240
select {
247241
case <-s.shouldShutdown:
@@ -260,14 +254,13 @@ func (s *Server) recv4(c *ipv4.PacketConn) {
260254
}
261255
}
262256

263-
// recv is a long running routine to receive packets from an interface
257+
// recv6 is a long running routine to receive packets from an interface
264258
func (s *Server) recv6(c *ipv6.PacketConn) {
259+
defer s.refCount.Done()
265260
if c == nil {
266261
return
267262
}
268263
buf := make([]byte, 65536)
269-
s.shutdownEnd.Add(1)
270-
defer s.shutdownEnd.Done()
271264
for {
272265
select {
273266
case <-s.shouldShutdown:
@@ -528,6 +521,8 @@ func (s *Server) serviceTypeName(resp *dns.Msg, ttl uint32) {
528521
// Perform probing & announcement
529522
//TODO: implement a proper probing & conflict resolution
530523
func (s *Server) probe() {
524+
defer s.refCount.Done()
525+
531526
q := new(dns.Msg)
532527
q.SetQuestion(s.service.ServiceInstanceName(), dns.TypePTR)
533528
q.RecursionDesired = false
@@ -555,16 +550,25 @@ func (s *Server) probe() {
555550
}
556551
q.Ns = []dns.RR{srv, txt}
557552

558-
randomizer := rand.New(rand.NewSource(time.Now().UnixNano()))
559-
560553
// Wait for a random duration uniformly distributed between 0 and 250 ms
561554
// before sending the first probe packet.
562-
time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond)
555+
timer := time.NewTimer(time.Duration(rand.Intn(250)) * time.Millisecond)
556+
defer timer.Stop()
557+
select {
558+
case <-timer.C:
559+
case <-s.shouldShutdown:
560+
return
561+
}
563562
for i := 0; i < 3; i++ {
564563
if err := s.multicastResponse(q, 0); err != nil {
565564
log.Println("[ERR] zeroconf: failed to send probe:", err.Error())
566565
}
567-
time.Sleep(250 * time.Millisecond)
566+
timer.Reset(250 * time.Millisecond)
567+
select {
568+
case <-timer.C:
569+
case <-s.shouldShutdown:
570+
return
571+
}
568572
}
569573

570574
// From RFC6762
@@ -573,7 +577,7 @@ func (s *Server) probe() {
573577
// packet loss, a responder MAY send up to eight unsolicited responses,
574578
// provided that the interval between unsolicited responses increases by
575579
// at least a factor of two with every response sent.
576-
timeout := 1 * time.Second
580+
timeout := time.Second
577581
for i := 0; i < multicastRepetitions; i++ {
578582
for _, intf := range s.ifaces {
579583
resp := new(dns.Msg)
@@ -587,7 +591,12 @@ func (s *Server) probe() {
587591
log.Println("[ERR] zeroconf: failed to send announcement:", err.Error())
588592
}
589593
}
590-
time.Sleep(timeout)
594+
timer.Reset(timeout)
595+
select {
596+
case <-timer.C:
597+
case <-s.shouldShutdown:
598+
return
599+
}
591600
timeout *= 2
592601
}
593602
}
@@ -719,7 +728,7 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro
719728
}
720729
}
721730

722-
// multicastResponse us used to send a multicast response packet
731+
// multicastResponse is used to send a multicast response packet
723732
func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error {
724733
buf, err := msg.Pack()
725734
if err != nil {

service_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,24 @@ func startMDNS(t *testing.T, port int, name, service, domain string) {
2525
log.Printf("Published service: %s, type: %s, domain: %s", name, service, domain)
2626
}
2727

28+
func TestQuickShutdown(t *testing.T) {
29+
server, err := Register(mdnsName, mdnsService, mdnsDomain, mdnsPort, []string{"txtv=0", "lo=1", "la=2"}, nil)
30+
if err != nil {
31+
t.Fatal(err)
32+
}
33+
34+
done := make(chan struct{})
35+
go func() {
36+
defer close(done)
37+
server.Shutdown()
38+
}()
39+
select {
40+
case <-done:
41+
case <-time.After(500 * time.Millisecond):
42+
t.Fatal("shutdown took longer than 500ms")
43+
}
44+
}
45+
2846
func TestBasic(t *testing.T) {
2947
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
3048
defer cancel()

0 commit comments

Comments
 (0)