diff --git a/store/store.go b/store/store.go index 70b95b1e..1e2d5011 100644 --- a/store/store.go +++ b/store/store.go @@ -55,6 +55,8 @@ type Store[H header.Header[H]] struct { contiguousHead atomic.Pointer[H] // pending keeps headers pending to be written in one batch pending *batch[H] + // syncCh is a channel used to synchronize writes + syncCh chan chan struct{} Params Parameters } @@ -109,6 +111,7 @@ func newStore[H header.Header[H]](ds datastore.Batching, opts ...Option) (*Store writes: make(chan []H, 16), writesDn: make(chan struct{}), pending: newBatch[H](params.WriteBatchSize), + syncCh: make(chan chan struct{}), Params: params, }, nil } @@ -153,6 +156,28 @@ func (s *Store[H]) Stop(ctx context.Context) error { return s.metrics.Close() } +// Sync ensures all pending writes are synchronized. It blocks until the operation completes or fails. +func (s *Store[H]) Sync(ctx context.Context) error { + waitCh := make(chan struct{}) + select { + case s.syncCh <- waitCh: + case <-s.writesDn: + return errStoppedStore + case <-ctx.Done(): + return ctx.Err() + } + + select { + case <-waitCh: + case <-s.writesDn: + return errStoppedStore + case <-ctx.Done(): + return ctx.Err() + } + + return nil +} + func (s *Store[H]) Height() uint64 { return s.heightSub.Height() } @@ -305,6 +330,12 @@ func (s *Store[H]) HasAt(_ context.Context, height uint64) bool { // DeleteTo implements [header.Store] interface. func (s *Store[H]) DeleteTo(ctx context.Context, to uint64) error { + // ensure all the pending headers are synchronized + err := s.Sync(ctx) + if err != nil { + return err + } + head, err := s.Head(ctx) if err != nil { return fmt.Errorf("header/store: reading head: %w", err) @@ -468,7 +499,8 @@ func (s *Store[H]) Append(ctx context.Context, headers ...H) error { func (s *Store[H]) flushLoop() { defer close(s.writesDn) ctx := context.Background() - for headers := range s.writes { + + flush := func(headers []H) { s.ensureInit(headers) // add headers to the pending and ensure they are accessible s.pending.Append(headers...) @@ -482,7 +514,7 @@ func (s *Store[H]) flushLoop() { // don't flush and continue if pending batch is not grown enough, // and Store is not stopping(headers == nil) if s.pending.Len() < s.Params.WriteBatchSize && headers != nil { - continue + return } startTime := time.Now() @@ -506,15 +538,37 @@ func (s *Store[H]) flushLoop() { s.metrics.flush(ctx, time.Since(startTime), s.pending.Len(), false) // reset pending s.pending.Reset() + } - if headers == nil { - // a signal to stop - return + for { + select { + case dn := <-s.syncCh: + for { + select { + case headers := <-s.writes: + flush(headers) + if headers == nil { + // a signal to stop + return + } + continue + default: + } + + close(dn) + break + } + case headers := <-s.writes: + flush(headers) + if headers == nil { + // a signal to stop + return + } } } } -// flush writes the given batch to datastore. +// flush writes given headers to datastore func (s *Store[H]) flush(ctx context.Context, headers ...H) error { ln := len(headers) if ln == 0 { diff --git a/store/store_test.go b/store/store_test.go index 0338a539..6a2cf409 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -675,6 +675,32 @@ func TestStore_DeleteTo_MoveHeadAndTail(t *testing.T) { assert.Equal(t, suite.Head().Height(), head.Height()) } +func TestStore_DeleteTo_Synchronized(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + suite := headertest.NewTestSuite(t) + + ds := sync.MutexWrap(datastore.NewMapDatastore()) + store := NewTestStore(t, ctx, ds, suite.Head(), WithWriteBatchSize(10)) + + err := store.Append(ctx, suite.GenDummyHeaders(50)...) + require.NoError(t, err) + + err = store.Append(ctx, suite.GenDummyHeaders(50)...) + require.NoError(t, err) + + err = store.Append(ctx, suite.GenDummyHeaders(50)...) + require.NoError(t, err) + + err = store.DeleteTo(ctx, 100) + require.NoError(t, err) + + tail, err := store.Tail(ctx) + require.NoError(t, err) + require.EqualValues(t, 100, tail.Height()) +} + func TestStorePendingCacheMiss(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) t.Cleanup(cancel)