@@ -16,6 +16,8 @@ package cache
1616
1717import (
1818 "context"
19+ "errors"
20+ "fmt"
1921 "sync"
2022 "testing"
2123 "time"
@@ -501,6 +503,7 @@ type mockWatcher struct {
501503 wg sync.WaitGroup
502504 mu sync.Mutex
503505 lastStartRev int64
506+ progressErr error
504507}
505508
506509func newMockWatcher (buf int ) * mockWatcher {
@@ -522,7 +525,7 @@ func (m *mockWatcher) Watch(ctx context.Context, _ string, opts ...clientv3.OpOp
522525 return out
523526}
524527
525- func (m * mockWatcher ) RequestProgress (_ context.Context ) error { return nil }
528+ func (m * mockWatcher ) RequestProgress (_ context.Context ) error { return m . progressErr }
526529
527530func (m * mockWatcher ) Close () error {
528531 m .closeOnce .Do (func () { close (m .responses ) })
@@ -600,6 +603,7 @@ func (m *mockWatcher) streamResponses(ctx context.Context, out chan<- clientv3.W
600603type kvStub struct {
601604 queued []* clientv3.GetResponse
602605 defaultResp * clientv3.GetResponse
606+ defaultErr error
603607}
604608
605609func newKVStub (resps ... * clientv3.GetResponse ) * kvStub {
@@ -610,7 +614,11 @@ func newKVStub(resps ...*clientv3.GetResponse) *kvStub {
610614 }
611615}
612616
613- func (s * kvStub ) Get (ctx context.Context , key string , _ ... clientv3.OpOption ) (* clientv3.GetResponse , error ) {
617+ func (s * kvStub ) Get (_ context.Context , key string , opts ... clientv3.OpOption ) (* clientv3.GetResponse , error ) {
618+ if s .defaultErr != nil {
619+ return nil , s .defaultErr
620+ }
621+
614622 if len (s .queued ) > 0 {
615623 next := s .queued [0 ]
616624 s .queued = s .queued [1 :]
@@ -692,3 +700,171 @@ func verifySnapshot(t *testing.T, cache *Cache, want []*mvccpb.KeyValue) {
692700 t .Fatalf ("cache snapshot mismatch (-want +got):\n %s" , diff )
693701 }
694702}
703+
704+ type noopProgressNotifier struct {}
705+
706+ func (n * noopProgressNotifier ) RequestProgress (_ context.Context ) error {
707+ return nil
708+ }
709+
710+ func newTestProgressRequestor () * conditionalProgressRequestor {
711+ return newConditionalProgressRequestor (& noopProgressNotifier {}, realClock {}, 100 * time .Millisecond )
712+ }
713+
714+ func newCacheForWaitTest (serverRev int64 , localRev int64 , pr progressRequestor ) (* Cache , * store ) {
715+ cfg := defaultConfig ()
716+ st := newStore (cfg .BTreeDegree , cfg .HistoryWindowSize )
717+ if localRev > 0 {
718+ st .Restore (nil , localRev )
719+ }
720+ kv := & kvStub {
721+ defaultResp : & clientv3.GetResponse {Header : & pb.ResponseHeader {Revision : serverRev }},
722+ }
723+ return & Cache {
724+ kv : kv ,
725+ store : st ,
726+ prefix : "/" ,
727+ progressRequestor : pr ,
728+ cfg : cfg ,
729+ }, st
730+ }
731+
732+ func TestWaitTillRevision (t * testing.T ) {
733+ t .Run ("cache_already_caught_up" , func (t * testing.T ) {
734+ c , _ := newCacheForWaitTest (10 , 10 , newTestProgressRequestor ())
735+
736+ if err := c .waitTillRevision (context .Background (), 10 ); err != nil {
737+ t .Fatalf ("unexpected error: %v" , err )
738+ }
739+ })
740+
741+ t .Run ("local_rev_sufficient_skips_server_call" , func (t * testing.T ) {
742+ cfg := defaultConfig ()
743+ st := newStore (cfg .BTreeDegree , cfg .HistoryWindowSize )
744+ st .Restore (nil , 10 )
745+ c := & Cache {
746+ kv : & kvStub {defaultErr : fmt .Errorf ("should not be called" )},
747+ store : st ,
748+ prefix : "/" ,
749+ progressRequestor : newTestProgressRequestor (),
750+ cfg : cfg ,
751+ }
752+
753+ if err := c .waitTillRevision (context .Background (), 5 ); err != nil {
754+ t .Fatalf ("unexpected error: %v" , err )
755+ }
756+ })
757+
758+ t .Run ("cache_catches_up" , func (t * testing.T ) {
759+ c , st := newCacheForWaitTest (15 , 5 , newTestProgressRequestor ())
760+
761+ go func () {
762+ time .Sleep (200 * time .Millisecond )
763+ st .Restore (nil , 10 )
764+ }()
765+
766+ ctx , cancel := context .WithTimeout (context .Background (), 2 * time .Second )
767+ defer cancel ()
768+ if err := c .waitTillRevision (ctx , 10 ); err != nil {
769+ t .Fatalf ("unexpected error: %v" , err )
770+ }
771+ })
772+
773+ t .Run ("rev_zero_cache_caught_up" , func (t * testing.T ) {
774+ c , _ := newCacheForWaitTest (10 , 10 , newTestProgressRequestor ())
775+
776+ if err := c .waitTillRevision (context .Background (), 0 ); err != nil {
777+ t .Fatalf ("unexpected error: %v" , err )
778+ }
779+ })
780+
781+ t .Run ("rev_zero_waits_for_server_rev" , func (t * testing.T ) {
782+ c , st := newCacheForWaitTest (10 , 5 , newTestProgressRequestor ())
783+
784+ go func () {
785+ time .Sleep (200 * time .Millisecond )
786+ st .Restore (nil , 10 )
787+ }()
788+
789+ ctx , cancel := context .WithTimeout (context .Background (), 2 * time .Second )
790+ defer cancel ()
791+ if err := c .waitTillRevision (ctx , 0 ); err != nil {
792+ t .Fatalf ("unexpected error: %v" , err )
793+ }
794+ })
795+
796+ t .Run ("context_cancelled" , func (t * testing.T ) {
797+ c , _ := newCacheForWaitTest (10 , 5 , newTestProgressRequestor ())
798+
799+ ctx , cancel := context .WithTimeout (context .Background (), 200 * time .Millisecond )
800+ defer cancel ()
801+ err := c .waitTillRevision (ctx , 10 )
802+ if ! errors .Is (err , context .DeadlineExceeded ) {
803+ t .Fatalf ("got %v, want context.DeadlineExceeded" , err )
804+ }
805+ })
806+
807+ t .Run ("timeout" , func (t * testing.T ) {
808+ c , _ := newCacheForWaitTest (10 , 5 , newTestProgressRequestor ())
809+
810+ ctx , cancel := context .WithTimeout (context .Background (), 10 * time .Second )
811+ defer cancel ()
812+ err := c .waitTillRevision (ctx , 10 )
813+ if ! errors .Is (err , ErrCacheTimeout ) {
814+ t .Fatalf ("got %v, want ErrCacheTimeout" , err )
815+ }
816+ })
817+ }
818+
819+ func TestWaitTillRevisionTriggersProgressRequests (t * testing.T ) {
820+ fc := newFakeClock ()
821+ pr := newTestConditionalProgressRequestor (fc , 50 * time .Millisecond )
822+ c , st := newCacheForWaitTest (15 , 5 , pr )
823+
824+ // Start progress requestor
825+ ctx , cancel := context .WithCancel (context .Background ())
826+ defer cancel ()
827+ go pr .run (ctx )
828+
829+ // Wait for goroutine to start
830+ time .Sleep (10 * time .Millisecond )
831+
832+ // Initially, no progress requests should be sent (no waiters)
833+ fc .Advance (100 * time .Millisecond )
834+ if err := pollConditionNoChange (func () bool {
835+ return pr .progressRequestsSentCount .Load () == 0
836+ }); err != nil {
837+ t .Fatal ("expected no progress requests without active waiters" )
838+ }
839+
840+ // Start waiting - this should trigger progress requests
841+ errCh := make (chan error , 1 )
842+ go func () {
843+ errCh <- c .waitTillRevision (context .Background (), 10 )
844+ }()
845+
846+ // Advance time and wait for progress requests to start
847+ fc .Advance (50 * time .Millisecond )
848+ time .Sleep (10 * time .Millisecond )
849+
850+ // Verify progress requests are being sent while waiting
851+ if pr .progressRequestsSentCount .Load () == 0 {
852+ t .Fatal ("expected progress requests during wait" )
853+ }
854+
855+ // Complete the wait
856+ st .Restore (nil , 15 )
857+
858+ if err := <- errCh ; err != nil {
859+ t .Fatalf ("unexpected error: %v" , err )
860+ }
861+
862+ // After completion, progress requests should stop
863+ finalCount := pr .progressRequestsSentCount .Load ()
864+ fc .Advance (100 * time .Millisecond )
865+ if err := pollConditionNoChange (func () bool {
866+ return pr .progressRequestsSentCount .Load () == finalCount
867+ }); err != nil {
868+ t .Fatalf ("expected no new progress requests after completion, got %d initially, then changed" , finalCount )
869+ }
870+ }
0 commit comments