@@ -89,9 +89,10 @@ type Server struct {
89
89
sem * semaphore.Weighted
90
90
91
91
// goroutine management fields
92
- done chan struct {}
93
- checkNow chan struct {}
94
- closewg sync.WaitGroup
92
+ done chan struct {}
93
+ checkNow chan struct {}
94
+ disconnecting chan struct {}
95
+ closewg sync.WaitGroup
95
96
96
97
// description related fields
97
98
desc atomic.Value // holds a description.Server
@@ -139,8 +140,9 @@ func NewServer(addr address.Address, opts ...ServerOption) (*Server, error) {
139
140
140
141
sem : semaphore .NewWeighted (int64 (maxConns )),
141
142
142
- done : make (chan struct {}),
143
- checkNow : make (chan struct {}, 1 ),
143
+ done : make (chan struct {}),
144
+ checkNow : make (chan struct {}, 1 ),
145
+ disconnecting : make (chan struct {}),
144
146
145
147
subscribers : make (map [uint64 ]chan description.Server ),
146
148
}
@@ -193,7 +195,14 @@ func (s *Server) Disconnect(ctx context.Context) error {
193
195
194
196
// For every call to Connect there must be at least 1 goroutine that is
195
197
// waiting on the done channel.
196
- s .done <- struct {}{}
198
+ select {
199
+ case <- ctx .Done ():
200
+ // signal a disconnect and still wait for receiver of done
201
+ // to finish.
202
+ close (s .disconnecting )
203
+ s .done <- struct {}{}
204
+ case s .done <- struct {}{}:
205
+ }
197
206
err := s .pool .disconnect (ctx )
198
207
if err != nil {
199
208
return err
@@ -398,6 +407,13 @@ func (s *Server) update() {
398
407
conn .nc .Close ()
399
408
}
400
409
for {
410
+ select {
411
+ case <- done :
412
+ closeServer ()
413
+ return
414
+ default :
415
+ }
416
+
401
417
select {
402
418
case <- heartbeatTicker .C :
403
419
case <- checkNow :
@@ -463,7 +479,15 @@ func (s *Server) heartbeat(conn *connection) (description.Server, *connection) {
463
479
var desc description.Server
464
480
var set bool
465
481
var err error
466
- ctx := context .Background ()
482
+ ctx , cancel := context .WithCancel (context .Background ())
483
+ defer cancel ()
484
+ go func () {
485
+ select {
486
+ case <- ctx .Done ():
487
+ case <- s .disconnecting :
488
+ cancel ()
489
+ }
490
+ }()
467
491
468
492
for i := 1 ; i <= maxRetry ; i ++ {
469
493
var now time.Time
@@ -499,7 +523,7 @@ func (s *Server) heartbeat(conn *connection) (description.Server, *connection) {
499
523
500
524
conn .connect (ctx )
501
525
502
- err : = conn .wait ()
526
+ err = conn .wait ()
503
527
if err == nil {
504
528
descPtr = & conn .desc
505
529
}
0 commit comments