@@ -35,10 +35,11 @@ type Server struct {
35
35
36
36
channelHandlers map [string ]channelHandler
37
37
38
- mu sync.Mutex
39
- listeners map [net.Listener ]struct {}
40
- conns map [* gossh.ServerConn ]struct {}
41
- doneChan chan struct {}
38
+ listenerWg sync.WaitGroup
39
+ mu sync.Mutex
40
+ listeners map [net.Listener ]struct {}
41
+ conns map [* gossh.ServerConn ]struct {}
42
+ doneChan chan struct {}
42
43
}
43
44
44
45
// internal for now
@@ -110,15 +111,6 @@ func (srv *Server) Close() error {
110
111
return err
111
112
}
112
113
113
- // shutdownPollInterval is how often we poll for quiescence
114
- // during Server.Shutdown. This is lower during tests, to
115
- // speed up tests.
116
- // Ideally we could find a solution that doesn't involve polling,
117
- // but which also doesn't have a high runtime cost (and doesn't
118
- // involve any contentious mutexes), but that is left as an
119
- // exercise for the reader.
120
- var shutdownPollInterval = 500 * time .Millisecond
121
-
122
114
// Shutdown gracefully shuts down the server without interrupting any
123
115
// active connections. Shutdown works by first closing all open
124
116
// listeners, and then waiting indefinitely for connections to close.
@@ -129,22 +121,19 @@ func (srv *Server) Shutdown(ctx context.Context) error {
129
121
lnerr := srv .closeListenersLocked ()
130
122
srv .closeDoneChanLocked ()
131
123
srv .mu .Unlock ()
132
- ticker := time .NewTicker (shutdownPollInterval )
133
- defer ticker .Stop ()
134
- for {
135
- srv .mu .Lock ()
136
- conns := len (srv .conns )
137
- srv .mu .Unlock ()
138
- if conns == 0 {
139
- return lnerr
140
- }
141
- select {
142
- case <- ctx .Done ():
143
- return ctx .Err ()
144
- case <- ticker .C :
145
- }
146
- }
147
124
125
+ listenerWgChan := make (chan struct {}, 1 )
126
+ go func () {
127
+ srv .listenerWg .Wait ()
128
+ listenerWgChan <- struct {}{}
129
+ }()
130
+
131
+ select {
132
+ case <- ctx .Done ():
133
+ return ctx .Err ()
134
+ case <- listenerWgChan :
135
+ return lnerr
136
+ }
148
137
}
149
138
150
139
// Serve accepts incoming connections on the Listener l, creating a new
@@ -315,8 +304,10 @@ func (srv *Server) trackListener(ln net.Listener, add bool) {
315
304
srv .doneChan = nil
316
305
}
317
306
srv .listeners [ln ] = struct {}{}
307
+ srv .listenerWg .Add (1 )
318
308
} else {
319
309
delete (srv .listeners , ln )
310
+ srv .listenerWg .Done ()
320
311
}
321
312
}
322
313
0 commit comments