Skip to content

Commit 3fa606c

Browse files
committed
clientconn: Wait for all goroutines on close
Three goroutines could outlive a call to ClientConn.close(). Add mechanics to cancel them and wait for them to complete when closing a client connection. RELEASE NOTES: - Closing a client connection will cancel all pending goroutines and block until they complete. Signed-off-by: Tom Wieczorek <[email protected]>
1 parent 50c6321 commit 3fa606c

File tree

6 files changed

+116
-38
lines changed

6 files changed

+116
-38
lines changed

balancer_wrapper.go

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@ type acBalancerWrapper struct {
282282
// dropped or updated. This is required as closures can't be compared for
283283
// equality.
284284
healthData *healthData
285+
286+
shutdownMu sync.Mutex
287+
shutdownCh chan struct{}
288+
activeGofuncs sync.WaitGroup
285289
}
286290

287291
// healthData holds data related to health state reporting.
@@ -347,16 +351,43 @@ func (acbw *acBalancerWrapper) String() string {
347351
}
348352

349353
func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {
350-
acbw.ac.updateAddrs(addrs)
354+
acbw.ac.updateAddrs(acbw, addrs)
351355
}
352356

353357
func (acbw *acBalancerWrapper) Connect() {
354-
go acbw.ac.connect()
358+
acbw.goFunc(acbw.ac.connect)
359+
}
360+
361+
func (acbw *acBalancerWrapper) goFunc(fn func(shutdown <-chan struct{})) {
362+
acbw.shutdownMu.Lock()
363+
defer acbw.shutdownMu.Unlock()
364+
365+
shutdown := acbw.shutdownCh
366+
if shutdown == nil {
367+
shutdown = make(chan struct{})
368+
acbw.shutdownCh = shutdown
369+
}
370+
371+
acbw.activeGofuncs.Add(1)
372+
go func() {
373+
defer acbw.activeGofuncs.Done()
374+
fn(shutdown)
375+
}()
355376
}
356377

357378
func (acbw *acBalancerWrapper) Shutdown() {
358379
acbw.closeProducers()
359380
acbw.ccb.cc.removeAddrConn(acbw.ac, errConnDrain)
381+
382+
acbw.shutdownMu.Lock()
383+
defer acbw.shutdownMu.Unlock()
384+
385+
shutdown := acbw.shutdownCh
386+
acbw.shutdownCh = nil
387+
if shutdown != nil {
388+
close(shutdown)
389+
acbw.activeGofuncs.Wait()
390+
}
360391
}
361392

362393
// NewStream begins a streaming RPC on the addrConn. If the addrConn is not

clientconn.go

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -925,25 +925,24 @@ func (cc *ClientConn) incrCallsFailed() {
925925
// connect starts creating a transport.
926926
// It does nothing if the ac is not IDLE.
927927
// TODO(bar) Move this to the addrConn section.
928-
func (ac *addrConn) connect() error {
928+
func (ac *addrConn) connect(abort <-chan struct{}) {
929929
ac.mu.Lock()
930930
if ac.state == connectivity.Shutdown {
931931
if logger.V(2) {
932932
logger.Infof("connect called on shutdown addrConn; ignoring.")
933933
}
934934
ac.mu.Unlock()
935-
return errConnClosing
935+
return
936936
}
937937
if ac.state != connectivity.Idle {
938938
if logger.V(2) {
939939
logger.Infof("connect called on addrConn in non-idle state (%v); ignoring.", ac.state)
940940
}
941941
ac.mu.Unlock()
942-
return nil
942+
return
943943
}
944944

945-
ac.resetTransportAndUnlock()
946-
return nil
945+
ac.resetTransportAndUnlock(abort)
947946
}
948947

949948
// equalAddressIgnoringBalAttributes returns true if a and b are considered equal.
@@ -962,7 +961,7 @@ func equalAddressesIgnoringBalAttributes(a, b []resolver.Address) bool {
962961

963962
// updateAddrs updates ac.addrs with the new addresses list and handles active
964963
// connections or connection attempts.
965-
func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
964+
func (ac *addrConn) updateAddrs(acbw *acBalancerWrapper, addrs []resolver.Address) {
966965
addrs = copyAddresses(addrs)
967966
limit := len(addrs)
968967
if limit > 5 {
@@ -1018,7 +1017,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
10181017

10191018
// Since we were connecting/connected, we should start a new connection
10201019
// attempt.
1021-
go ac.resetTransportAndUnlock()
1020+
acbw.goFunc(ac.resetTransportAndUnlock)
10221021
}
10231022

10241023
// getServerName determines the serverName to be used in the connection
@@ -1249,9 +1248,17 @@ func (ac *addrConn) adjustParams(r transport.GoAwayReason) {
12491248
// resetTransportAndUnlock unconditionally connects the addrConn.
12501249
//
12511250
// ac.mu must be held by the caller, and this function will guarantee it is released.
1252-
func (ac *addrConn) resetTransportAndUnlock() {
1253-
acCtx := ac.ctx
1254-
if acCtx.Err() != nil {
1251+
func (ac *addrConn) resetTransportAndUnlock(abort <-chan struct{}) {
1252+
ctx, cancel := context.WithCancel(ac.ctx)
1253+
go func() {
1254+
select {
1255+
case <-abort:
1256+
cancel()
1257+
case <-ctx.Done():
1258+
}
1259+
}()
1260+
1261+
if ctx.Err() != nil {
12551262
ac.mu.Unlock()
12561263
return
12571264
}
@@ -1279,12 +1286,12 @@ func (ac *addrConn) resetTransportAndUnlock() {
12791286
ac.updateConnectivityState(connectivity.Connecting, nil)
12801287
ac.mu.Unlock()
12811288

1282-
if err := ac.tryAllAddrs(acCtx, addrs, connectDeadline); err != nil {
1289+
if err := ac.tryAllAddrs(ctx, addrs, connectDeadline); err != nil {
12831290
// TODO: #7534 - Move re-resolution requests into the pick_first LB policy
12841291
// to ensure one resolution request per pass instead of per subconn failure.
12851292
ac.cc.resolveNow(resolver.ResolveNowOptions{})
12861293
ac.mu.Lock()
1287-
if acCtx.Err() != nil {
1294+
if ctx.Err() != nil {
12881295
// addrConn was torn down.
12891296
ac.mu.Unlock()
12901297
return
@@ -1305,13 +1312,13 @@ func (ac *addrConn) resetTransportAndUnlock() {
13051312
ac.mu.Unlock()
13061313
case <-b:
13071314
timer.Stop()
1308-
case <-acCtx.Done():
1315+
case <-ctx.Done():
13091316
timer.Stop()
13101317
return
13111318
}
13121319

13131320
ac.mu.Lock()
1314-
if acCtx.Err() == nil {
1321+
if ctx.Err() == nil {
13151322
ac.updateConnectivityState(connectivity.Idle, err)
13161323
}
13171324
ac.mu.Unlock()
@@ -1366,6 +1373,9 @@ func (ac *addrConn) tryAllAddrs(ctx context.Context, addrs []resolver.Address, c
13661373
// new transport.
13671374
func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error {
13681375
addr.ServerName = ac.cc.getServerName(addr)
1376+
1377+
var healthCheckStarted atomic.Bool
1378+
healthCheckDone := make(chan struct{})
13691379
hctx, hcancel := context.WithCancel(ctx)
13701380

13711381
onClose := func(r transport.GoAwayReason) {
@@ -1394,6 +1404,9 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
13941404
// Always go idle and wait for the LB policy to initiate a new
13951405
// connection attempt.
13961406
ac.updateConnectivityState(connectivity.Idle, nil)
1407+
if healthCheckStarted.Load() {
1408+
<-healthCheckDone
1409+
}
13971410
}
13981411

13991412
connectCtx, cancel := context.WithDeadline(ctx, connectDeadline)
@@ -1406,29 +1419,35 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
14061419
logger.Infof("Creating new client transport to %q: %v", addr, err)
14071420
}
14081421
// newTr is either nil, or closed.
1409-
hcancel()
14101422
channelz.Warningf(logger, ac.channelz, "grpc: addrConn.createTransport failed to connect to %s. Err: %v", addr, err)
14111423
return err
14121424
}
14131425

1414-
ac.mu.Lock()
1415-
defer ac.mu.Unlock()
1426+
acMu := &ac.mu
1427+
acMu.Lock()
1428+
defer func() {
1429+
if acMu != nil {
1430+
acMu.Unlock()
1431+
}
1432+
}()
14161433
if ctx.Err() != nil {
14171434
// This can happen if the subConn was removed while in `Connecting`
14181435
// state. tearDown() would have set the state to `Shutdown`, but
14191436
// would not have closed the transport since ac.transport would not
14201437
// have been set at that point.
1421-
//
1422-
// We run this in a goroutine because newTr.Close() calls onClose()
1438+
1439+
// We unlock ac.mu because newTr.Close() calls onClose()
14231440
// inline, which requires locking ac.mu.
1424-
//
1441+
acMu.Unlock()
1442+
acMu = nil
1443+
14251444
// The error we pass to Close() is immaterial since there are no open
14261445
// streams at this point, so no trailers with error details will be sent
14271446
// out. We just need to pass a non-nil error.
14281447
//
14291448
// This can also happen when updateAddrs is called during a connection
14301449
// attempt.
1431-
go newTr.Close(transport.ErrConnClosing)
1450+
newTr.Close(transport.ErrConnClosing)
14321451
return nil
14331452
}
14341453
if hctx.Err() != nil {
@@ -1440,7 +1459,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
14401459
}
14411460
ac.curAddr = addr
14421461
ac.transport = newTr
1443-
ac.startHealthCheck(hctx) // Will set state to READY if appropriate.
1462+
healthCheckStarted.Store(ac.startHealthCheck(hctx, healthCheckDone)) // Will set state to READY if appropriate.
14441463
return nil
14451464
}
14461465

@@ -1456,7 +1475,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
14561475
// It sets addrConn to READY if the health checking stream is not started.
14571476
//
14581477
// Caller must hold ac.mu.
1459-
func (ac *addrConn) startHealthCheck(ctx context.Context) {
1478+
func (ac *addrConn) startHealthCheck(ctx context.Context, done chan<- struct{}) bool {
14601479
var healthcheckManagingState bool
14611480
defer func() {
14621481
if !healthcheckManagingState {
@@ -1465,22 +1484,22 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
14651484
}()
14661485

14671486
if ac.cc.dopts.disableHealthCheck {
1468-
return
1487+
return false
14691488
}
14701489
healthCheckConfig := ac.cc.healthCheckConfig()
14711490
if healthCheckConfig == nil {
1472-
return
1491+
return false
14731492
}
14741493
if !ac.scopts.HealthCheckEnabled {
1475-
return
1494+
return false
14761495
}
14771496
healthCheckFunc := internal.HealthCheckFunc
14781497
if healthCheckFunc == nil {
14791498
// The health package is not imported to set health check function.
14801499
//
14811500
// TODO: add a link to the health check doc in the error message.
14821501
channelz.Error(logger, ac.channelz, "Health check is requested but health check function is not set.")
1483-
return
1502+
return false
14841503
}
14851504

14861505
healthcheckManagingState = true
@@ -1506,6 +1525,7 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
15061525
}
15071526
// Start the health checking stream.
15081527
go func() {
1528+
defer close(done)
15091529
err := healthCheckFunc(ctx, newStream, setConnectivityState, healthCheckConfig.ServiceName)
15101530
if err != nil {
15111531
if status.Code(err) == codes.Unimplemented {
@@ -1515,6 +1535,7 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
15151535
}
15161536
}
15171537
}()
1538+
return true
15181539
}
15191540

15201541
func (ac *addrConn) resetConnectBackoff() {

internal/balancer/gracefulswitch/gracefulswitch.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ type Balancer struct {
6767
// balancerCurrent before the UpdateSubConnState is called on the
6868
// balancerCurrent.
6969
currentMu sync.Mutex
70+
71+
// activeGoroutines tracks all the goroutines that this balancer has started
72+
// and that should be waited on when the balancer closes.
73+
activeGoroutines sync.WaitGroup
7074
}
7175

7276
// swap swaps out the current lb with the pending lb and updates the ClientConn.
@@ -76,7 +80,9 @@ func (gsb *Balancer) swap() {
7680
cur := gsb.balancerCurrent
7781
gsb.balancerCurrent = gsb.balancerPending
7882
gsb.balancerPending = nil
83+
gsb.activeGoroutines.Add(1)
7984
go func() {
85+
defer gsb.activeGoroutines.Done()
8086
gsb.currentMu.Lock()
8187
defer gsb.currentMu.Unlock()
8288
cur.Close()
@@ -274,6 +280,7 @@ func (gsb *Balancer) Close() {
274280

275281
currentBalancerToClose.Close()
276282
pendingBalancerToClose.Close()
283+
gsb.activeGoroutines.Wait()
277284
}
278285

279286
// balancerWrapper wraps a balancer.Balancer, and overrides some Balancer
@@ -324,7 +331,12 @@ func (bw *balancerWrapper) UpdateState(state balancer.State) {
324331
defer bw.gsb.mu.Unlock()
325332
bw.lastState = state
326333

334+
// If Close() acquires the mutex before UpdateState(), the balancer
335+
// will already have been removed from the current or pending state when
336+
// reaching this point.
327337
if !bw.gsb.balancerCurrentOrPending(bw) {
338+
// Returning here ensures that (*Balancer).swap() is not invoked after
339+
// (*Balancer).Close() and therefore prevents "use after close".
328340
return
329341
}
330342

internal/testutils/pipe_listener.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package testutils
2121

2222
import (
23+
"context"
2324
"errors"
2425
"net"
2526
"time"
@@ -81,11 +82,20 @@ func (p *PipeListener) Addr() net.Addr {
8182
// Dialer dials a connection.
8283
func (p *PipeListener) Dialer() func(string, time.Duration) (net.Conn, error) {
8384
return func(string, time.Duration) (net.Conn, error) {
85+
return p.ContextDialer()(context.Background(), "")
86+
}
87+
}
88+
89+
// ContextDialer dials a connection using a context.
90+
func (p *PipeListener) ContextDialer() func(context.Context, string) (net.Conn, error) {
91+
return func(ctx context.Context, _ string) (net.Conn, error) {
8492
connChan := make(chan net.Conn)
8593
select {
8694
case p.c <- connChan:
8795
case <-p.done:
8896
return nil, errClosed
97+
case <-ctx.Done():
98+
return nil, context.Cause(ctx)
8999
}
90100
conn, ok := <-connChan
91101
if !ok {

test/clientconn_state_transition_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ func testStateTransitionSingleAddress(t *testing.T, wantStates []connectivity.St
185185

186186
dopts := []grpc.DialOption{
187187
grpc.WithTransportCredentials(insecure.NewCredentials()),
188-
grpc.WithDialer(pl.Dialer()),
188+
grpc.WithContextDialer(pl.ContextDialer()),
189189
grpc.WithConnectParams(grpc.ConnectParams{
190190
Backoff: backoff.Config{},
191191
MinConnectTimeout: 100 * time.Millisecond,

0 commit comments

Comments
 (0)