Skip to content

Commit ef66069

Browse files
belakprogrium
authored andcommitted
Update shutdown to use a WaitGroup rather than sleeping (#74)
1 parent 66f55c8 commit ef66069

File tree

1 file changed

+19
-28
lines changed

1 file changed

+19
-28
lines changed

server.go

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@ type Server struct {
3535

3636
channelHandlers map[string]channelHandler
3737

38-
mu sync.Mutex
39-
listeners map[net.Listener]struct{}
40-
conns map[*gossh.ServerConn]struct{}
41-
doneChan chan struct{}
38+
listenerWg sync.WaitGroup
39+
mu sync.Mutex
40+
listeners map[net.Listener]struct{}
41+
conns map[*gossh.ServerConn]struct{}
42+
doneChan chan struct{}
4243
}
4344

4445
// internal for now
@@ -110,15 +111,6 @@ func (srv *Server) Close() error {
110111
return err
111112
}
112113

113-
// shutdownPollInterval is how often we poll for quiescence
114-
// during Server.Shutdown. This is lower during tests, to
115-
// speed up tests.
116-
// Ideally we could find a solution that doesn't involve polling,
117-
// but which also doesn't have a high runtime cost (and doesn't
118-
// involve any contentious mutexes), but that is left as an
119-
// exercise for the reader.
120-
var shutdownPollInterval = 500 * time.Millisecond
121-
122114
// Shutdown gracefully shuts down the server without interrupting any
123115
// active connections. Shutdown works by first closing all open
124116
// listeners, and then waiting indefinitely for connections to close.
@@ -129,22 +121,19 @@ func (srv *Server) Shutdown(ctx context.Context) error {
129121
lnerr := srv.closeListenersLocked()
130122
srv.closeDoneChanLocked()
131123
srv.mu.Unlock()
132-
ticker := time.NewTicker(shutdownPollInterval)
133-
defer ticker.Stop()
134-
for {
135-
srv.mu.Lock()
136-
conns := len(srv.conns)
137-
srv.mu.Unlock()
138-
if conns == 0 {
139-
return lnerr
140-
}
141-
select {
142-
case <-ctx.Done():
143-
return ctx.Err()
144-
case <-ticker.C:
145-
}
146-
}
147124

125+
listenerWgChan := make(chan struct{}, 1)
126+
go func() {
127+
srv.listenerWg.Wait()
128+
listenerWgChan <- struct{}{}
129+
}()
130+
131+
select {
132+
case <-ctx.Done():
133+
return ctx.Err()
134+
case <-listenerWgChan:
135+
return lnerr
136+
}
148137
}
149138

150139
// Serve accepts incoming connections on the Listener l, creating a new
@@ -315,8 +304,10 @@ func (srv *Server) trackListener(ln net.Listener, add bool) {
315304
srv.doneChan = nil
316305
}
317306
srv.listeners[ln] = struct{}{}
307+
srv.listenerWg.Add(1)
318308
} else {
319309
delete(srv.listeners, ln)
310+
srv.listenerWg.Done()
320311
}
321312
}
322313

0 commit comments

Comments
 (0)