@@ -49,6 +49,10 @@ type SessionManager struct {
4949 connections map [uint32 ]* mysql.Conn
5050 lastPid uint64
5151 ctxFactory sql.ContextFactory
52+ // Implements WaitForClosedConnections(), which is only used
53+ // at server shutdown to allow the integrator to ensure that
54+ // no connections are being handled by handlers.
55+ wg sync.WaitGroup
5256}
5357
5458// NewSessionManager creates a SessionManager with the given SessionBuilder.
@@ -82,6 +86,13 @@ func (s *SessionManager) nextPid() uint64 {
8286 return s .lastPid
8387}
8488
89+ // Block the calling thread until all known connections are closed. It
90+ // is an error to call this concurrently while the server might still
91+ // be accepting new connections.
92+ func (s * SessionManager ) WaitForClosedConnections () {
93+ s .wg .Wait ()
94+ }
95+
8596// AddConn adds a connection to be tracked by the SessionManager. Should be called as
8697// soon as possible after the server has accepted the connection. Results in
8798// the connection being tracked by ProcessList and being available through
@@ -93,6 +104,7 @@ func (s *SessionManager) AddConn(conn *mysql.Conn) {
93104 defer s .mu .Unlock ()
94105 s .connections [conn .ConnectionID ] = conn
95106 s .processlist .AddConnection (conn .ConnectionID , conn .RemoteAddr ().String ())
107+ s .wg .Add (1 )
96108}
97109
98110// NewSession creates a Session for the given connection and saves it to the session pool.
@@ -270,6 +282,7 @@ func (s *SessionManager) KillConnection(connID uint32) error {
270282func (s * SessionManager ) RemoveConn (conn * mysql.Conn ) {
271283 s .mu .Lock ()
272284 defer s .mu .Unlock ()
285+ s .wg .Done ()
273286 if cur , ok := s .sessions [conn .ConnectionID ]; ok {
274287 sql .SessionEnd (cur )
275288 }
0 commit comments