@@ -11,7 +11,6 @@ import (
1111	"os" 
1212	"strings" 
1313	"sync" 
14- 	"sync/atomic" 
1514	"syscall" 
1615	"time" 
1716
@@ -30,12 +29,15 @@ type ServeFunction = func(net.Listener) error
3029
3130// Server represents our graceful server 
3231type  Server  struct  {
33- 	network               string 
34- 	address               string 
35- 	listener              net.Listener 
36- 	wg                    sync.WaitGroup 
37- 	state                 state 
38- 	lock                  * sync.RWMutex 
32+ 	network   string 
33+ 	address   string 
34+ 	listener  net.Listener 
35+ 
36+ 	lock           sync.RWMutex 
37+ 	state          state 
38+ 	connCounter    int64 
39+ 	connEmptyCond  * sync.Cond 
40+ 
3941	BeforeBegin           func (network , address  string )
4042	OnShutdown            func ()
4143	PerWriteTimeout       time.Duration 
@@ -50,14 +52,13 @@ func NewServer(network, address, name string) *Server {
5052		log .Info ("Starting new %s server: %s:%s on PID: %d" , name , network , address , os .Getpid ())
5153	}
5254	srv  :=  & Server {
53- 		wg :                   sync.WaitGroup {},
5455		state :                stateInit ,
55- 		lock :                 & sync.RWMutex {},
5656		network :              network ,
5757		address :              address ,
5858		PerWriteTimeout :      setting .PerWriteTimeout ,
5959		PerWritePerKbTimeout : setting .PerWritePerKbTimeout ,
6060	}
61+ 	srv .connEmptyCond  =  sync .NewCond (& srv .lock )
6162
6263	srv .BeforeBegin  =  func (network , addr  string ) {
6364		log .Debug ("Starting server on %s:%s (PID: %d)" , network , addr , syscall .Getpid ())
@@ -154,7 +155,7 @@ func (srv *Server) Serve(serve ServeFunction) error {
154155	GetManager ().RegisterServer ()
155156	err  :=  serve (srv .listener )
156157	log .Debug ("Waiting for connections to finish... (PID: %d)" , syscall .Getpid ())
157- 	srv .wg . Wait ()
158+ 	srv .waitForActiveConnections ()
158159	srv .setState (stateTerminate )
159160	GetManager ().ServerDone ()
160161	// use of closed means that the listeners are closed - i.e. we should be shutting down - return nil 
@@ -178,63 +179,87 @@ func (srv *Server) setState(st state) {
178179	srv .state  =  st 
179180}
180181
182+ func  (srv  * Server ) waitForActiveConnections () {
183+ 	srv .lock .Lock ()
184+ 	for  srv .connCounter  >  0  {
185+ 		srv .connEmptyCond .Wait ()
186+ 	}
187+ 	srv .lock .Unlock ()
188+ }
189+ 
190+ func  (srv  * Server ) wrapConnection (c  net.Conn ) (net.Conn , error ) {
191+ 	srv .lock .Lock ()
192+ 	defer  srv .lock .Unlock ()
193+ 
194+ 	if  srv .state  !=  stateRunning  {
195+ 		_  =  c .Close ()
196+ 		return  nil , syscall .EINVAL  // same as AcceptTCP 
197+ 	}
198+ 
199+ 	srv .connCounter ++ 
200+ 	return  & wrappedConn {Conn : c , server : srv }, nil 
201+ }
202+ 
203+ func  (srv  * Server ) removeConnection (_  * wrappedConn ) {
204+ 	srv .lock .Lock ()
205+ 	defer  srv .lock .Unlock ()
206+ 
207+ 	srv .connCounter -- 
208+ 	if  srv .connCounter  <=  0  {
209+ 		srv .connEmptyCond .Broadcast ()
210+ 	}
211+ }
212+ 
213+ // closeAllConnections forcefully closes all active connections 
214+ func  (srv  * Server ) closeAllConnections () {
215+ 	srv .lock .Lock ()
216+ 	if  srv .connCounter  >  0  {
217+ 		log .Warn ("After graceful shutdown period, %d connections are still active. Forcefully close." , srv .connCounter )
218+ 		srv .connCounter  =  0  // OS will close all the connections after the process exits, so we just assume there is no active connection now 
219+ 	}
220+ 	srv .lock .Unlock ()
221+ 	srv .connEmptyCond .Broadcast ()
222+ }
223+ 
181224type  filer  interface  {
182225	File () (* os.File , error )
183226}
184227
185228type  wrappedListener  struct  {
186229	net.Listener 
187- 	stopped  bool 
188- 	server   * Server 
230+ 	server  * Server 
189231}
190232
233+ var  (
234+ 	_  net.Listener  =  (* wrappedListener )(nil )
235+ 	_  filer         =  (* wrappedListener )(nil )
236+ )
237+ 
191238func  newWrappedListener (l  net.Listener , srv  * Server ) * wrappedListener  {
192239	return  & wrappedListener {
193240		Listener : l ,
194241		server :   srv ,
195242	}
196243}
197244
198- func  (wl  * wrappedListener ) Accept () (net.Conn , error ) {
199- 	var  c  net.Conn 
200- 	// Set keepalive on TCPListeners connections. 
245+ func  (wl  * wrappedListener ) Accept () (c  net.Conn , err  error ) {
201246	if  tcl , ok  :=  wl .Listener .(* net.TCPListener ); ok  {
247+ 		// Set keepalive on TCPListeners connections if possible, see http.tcpKeepAliveListener 
202248		tc , err  :=  tcl .AcceptTCP ()
203249		if  err  !=  nil  {
204250			return  nil , err 
205251		}
206- 		_  =  tc .SetKeepAlive (true )                   // see http.tcpKeepAliveListener 
207- 		_  =  tc .SetKeepAlivePeriod (3  *  time .Minute )  // see http.tcpKeepAliveListener 
252+ 		_  =  tc .SetKeepAlive (true )
253+ 		_  =  tc .SetKeepAlivePeriod (3  *  time .Minute )
208254		c  =  tc 
209255	} else  {
210- 		var  err  error 
211256		c , err  =  wl .Listener .Accept ()
212257		if  err  !=  nil  {
213258			return  nil , err 
214259		}
215260	}
216261
217- 	closed  :=  int32 (0 )
218- 
219- 	c  =  & wrappedConn {
220- 		Conn :                 c ,
221- 		server :               wl .server ,
222- 		closed :               & closed ,
223- 		perWriteTimeout :      wl .server .PerWriteTimeout ,
224- 		perWritePerKbTimeout : wl .server .PerWritePerKbTimeout ,
225- 	}
226- 
227- 	wl .server .wg .Add (1 )
228- 	return  c , nil 
229- }
230- 
231- func  (wl  * wrappedListener ) Close () error  {
232- 	if  wl .stopped  {
233- 		return  syscall .EINVAL 
234- 	}
235- 
236- 	wl .stopped  =  true 
237- 	return  wl .Listener .Close ()
262+ 	return  wl .server .wrapConnection (c )
238263}
239264
240265func  (wl  * wrappedListener ) File () (* os.File , error ) {
@@ -244,17 +269,14 @@ func (wl *wrappedListener) File() (*os.File, error) {
244269
245270type  wrappedConn  struct  {
246271	net.Conn 
247- 	server                * Server 
248- 	closed                * int32 
249- 	deadline              time.Time 
250- 	perWriteTimeout       time.Duration 
251- 	perWritePerKbTimeout  time.Duration 
272+ 	server    * Server 
273+ 	deadline  time.Time 
252274}
253275
254276func  (w  * wrappedConn ) Write (p  []byte ) (n  int , err  error ) {
255- 	if  w .perWriteTimeout  >  0  {
256- 		minTimeout  :=  time .Duration (len (p )/ 1024 ) *  w .perWritePerKbTimeout 
257- 		minDeadline  :=  time .Now ().Add (minTimeout ).Add (w .perWriteTimeout )
277+ 	if  w .server . PerWriteTimeout  >  0  {
278+ 		minTimeout  :=  time .Duration (len (p )/ 1024 ) *  w .server . PerWritePerKbTimeout 
279+ 		minDeadline  :=  time .Now ().Add (minTimeout ).Add (w .server . PerWriteTimeout )
258280
259281		w .deadline  =  w .deadline .Add (minTimeout )
260282		if  minDeadline .After (w .deadline ) {
@@ -266,19 +288,6 @@ func (w *wrappedConn) Write(p []byte) (n int, err error) {
266288}
267289
268290func  (w  * wrappedConn ) Close () error  {
269- 	if  atomic .CompareAndSwapInt32 (w .closed , 0 , 1 ) {
270- 		defer  func () {
271- 			if  err  :=  recover (); err  !=  nil  {
272- 				select  {
273- 				case  <- GetManager ().IsHammer ():
274- 					// Likely deadlocked request released at hammertime 
275- 					log .Warn ("Panic during connection close! %v. Likely there has been a deadlocked request which has been released by forced shutdown." , err )
276- 				default :
277- 					log .Error ("Panic during connection close! %v" , err )
278- 				}
279- 			}
280- 		}()
281- 		w .server .wg .Done ()
282- 	}
291+ 	w .server .removeConnection (w )
283292	return  w .Conn .Close ()
284293}
0 commit comments