@@ -517,15 +517,20 @@ func (s) TestControlChannelConnectivityStateTransitions(t *testing.T) {
517517 // Start an RLS server
518518 rlsServer , _ := rlstest .SetupFakeRLSServer (t , nil )
519519
520- // Setup callback to count invocations with synchronization
521- callbackCount := 0
520+ // Setup callback to count invocations
522521 var mu sync.Mutex
523- callbackInvoked := make (chan struct {}, 10 )
522+ var callbackCount int
523+ // Buffered channel to collect callback invocations without blocking
524+ callbackInvoked := make (chan struct {}, tt .wantCallbackCount + 5 )
524525 callback := func () {
525526 mu .Lock ()
526527 callbackCount ++
527528 mu .Unlock ()
528- callbackInvoked <- struct {}{}
529+ // Non-blocking send - if channel is full, we still counted it
530+ select {
531+ case callbackInvoked <- struct {}{}:
532+ default :
533+ }
529534 }
530535
531536 // Create control channel
@@ -535,53 +540,79 @@ func (s) TestControlChannelConnectivityStateTransitions(t *testing.T) {
535540 }
536541 defer ctrlCh .close ()
537542
538- // Wait for initial READY state by checking connectivity state buffer
543+ // Wait for initial READY state using state change notifications
539544 ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
540545 defer cancel ()
541- initialReady := false
542- for ! initialReady {
543- select {
544- case <- ctx .Done ():
545- t .Fatal ("Timeout waiting for initial READY state" )
546- default :
547- if ctrlCh .cc .GetState () == connectivity .Ready {
548- initialReady = true
549- } else {
550- time .Sleep (10 * time .Millisecond )
546+
547+ readyCh := make (chan struct {})
548+ go func () {
549+ for {
550+ state := ctrlCh .cc .GetState ()
551+ if state == connectivity .Ready {
552+ close (readyCh )
553+ return
554+ }
555+ if ! ctrlCh .cc .WaitForStateChange (ctx , state ) {
556+ return
551557 }
552558 }
559+ }()
560+
561+ select {
562+ case <- readyCh :
563+ // Initial READY state achieved
564+ case <- ctx .Done ():
565+ t .Fatal ("Timeout waiting for initial READY state" )
553566 }
554567
555- // Inject the test state sequence
568+ // Process states sequentially, waiting for callbacks when expected
569+ seenTransientFailure := false
570+ expectedCallbacks := 0
571+
556572 for _ , state := range tt .states {
573+ // Inject the state
557574 ctrlCh .OnMessage (state )
558- }
559575
560- // Wait for expected callbacks to be invoked
561- for i := 0 ; i < tt .wantCallbackCount ; i ++ {
562- select {
563- case <- callbackInvoked :
564- // Callback received as expected
565- case <- time .After (defaultTestTimeout ):
566- t .Fatalf ("Timeout waiting for callback %d/%d" , i + 1 , tt .wantCallbackCount )
576+ // Track if we're in a failure state
577+ if state == connectivity .TransientFailure {
578+ seenTransientFailure = true
567579 }
568- }
569580
570- // Ensure no extra callbacks are invoked
571- select {
572- case <- callbackInvoked :
573- t .Fatal ("Received more callbacks than expected" )
574- case <- time .After (100 * time .Millisecond ):
575- // Expected: no more callbacks
581+ // If transitioning to READY after a failure, wait for callback
582+ if state == connectivity .Ready && seenTransientFailure {
583+ expectedCallbacks ++
584+ select {
585+ case <- callbackInvoked :
586+ // Callback received as expected
587+ seenTransientFailure = false
588+ case <- time .After (defaultTestTimeout ):
589+ mu .Lock ()
590+ got := callbackCount
591+ mu .Unlock ()
592+ t .Fatalf ("Timeout waiting for callback %d/%d after TRANSIENT_FAILURE→READY (got %d callbacks so far)" , expectedCallbacks , tt .wantCallbackCount , got )
593+ }
594+ }
576595 }
577596
597+ // Verify final callback count matches expected
578598 mu .Lock ()
579599 gotCallbackCount := callbackCount
580600 mu .Unlock ()
581601
582602 if gotCallbackCount != tt .wantCallbackCount {
583603 t .Errorf ("Got %d callback invocations, want %d" , gotCallbackCount , tt .wantCallbackCount )
584604 }
605+
606+ // Ensure no extra callbacks are invoked
607+ select {
608+ case <- callbackInvoked :
609+ mu .Lock ()
610+ final := callbackCount
611+ mu .Unlock ()
612+ t .Fatalf ("Received more callbacks than expected: got %d, want %d" , final , tt .wantCallbackCount )
613+ case <- time .After (50 * time .Millisecond ):
614+ // Expected: no more callbacks
615+ }
585616 })
586617 }
587618}
0 commit comments