Skip to content

Commit cfbb600

Browse files
edanielsDivjot Arora
authored andcommitted
GODRIVER-1345 - Respect context.Context cancellation in Disconnect (#199)
1 parent 2815d49 commit cfbb600

File tree

1 file changed

+32
-8
lines changed

1 file changed

+32
-8
lines changed

x/mongo/driver/topology/server.go

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,10 @@ type Server struct {
8989
sem *semaphore.Weighted
9090

9191
// goroutine management fields
92-
done chan struct{}
93-
checkNow chan struct{}
94-
closewg sync.WaitGroup
92+
done chan struct{}
93+
checkNow chan struct{}
94+
disconnecting chan struct{}
95+
closewg sync.WaitGroup
9596

9697
// description related fields
9798
desc atomic.Value // holds a description.Server
@@ -139,8 +140,9 @@ func NewServer(addr address.Address, opts ...ServerOption) (*Server, error) {
139140

140141
sem: semaphore.NewWeighted(int64(maxConns)),
141142

142-
done: make(chan struct{}),
143-
checkNow: make(chan struct{}, 1),
143+
done: make(chan struct{}),
144+
checkNow: make(chan struct{}, 1),
145+
disconnecting: make(chan struct{}),
144146

145147
subscribers: make(map[uint64]chan description.Server),
146148
}
@@ -193,7 +195,14 @@ func (s *Server) Disconnect(ctx context.Context) error {
193195

194196
// For every call to Connect there must be at least 1 goroutine that is
195197
// waiting on the done channel.
196-
s.done <- struct{}{}
198+
select {
199+
case <-ctx.Done():
200+
// signal a disconnect and still wait for receiver of done
201+
// to finish.
202+
close(s.disconnecting)
203+
s.done <- struct{}{}
204+
case s.done <- struct{}{}:
205+
}
197206
err := s.pool.disconnect(ctx)
198207
if err != nil {
199208
return err
@@ -398,6 +407,13 @@ func (s *Server) update() {
398407
conn.nc.Close()
399408
}
400409
for {
410+
select {
411+
case <-done:
412+
closeServer()
413+
return
414+
default:
415+
}
416+
401417
select {
402418
case <-heartbeatTicker.C:
403419
case <-checkNow:
@@ -463,7 +479,15 @@ func (s *Server) heartbeat(conn *connection) (description.Server, *connection) {
463479
var desc description.Server
464480
var set bool
465481
var err error
466-
ctx := context.Background()
482+
ctx, cancel := context.WithCancel(context.Background())
483+
defer cancel()
484+
go func() {
485+
select {
486+
case <-ctx.Done():
487+
case <-s.disconnecting:
488+
cancel()
489+
}
490+
}()
467491

468492
for i := 1; i <= maxRetry; i++ {
469493
var now time.Time
@@ -499,7 +523,7 @@ func (s *Server) heartbeat(conn *connection) (description.Server, *connection) {
499523

500524
conn.connect(ctx)
501525

502-
err := conn.wait()
526+
err = conn.wait()
503527
if err == nil {
504528
descPtr = &conn.desc
505529
}

0 commit comments

Comments
 (0)