Skip to content

Commit 510b17b

Browse files
IliaBulavintsevWondertanIliaBulavintsev
authored
feat(sync): non-recursive singleflight protection for Head (#229)
Closes: #154 This PR adds singleflight protection to prevent duplicate header requests in the Exchange component. The focus is on minimizing redundant requests, especially around the Head method, which is most likely to experience simultaneous calls. I moved out single-flight protection out of exchange to be in sync.Syncer. This way the single-flight protection works for every Head implementation: p2p, core. Besides, I avoid using the x/sync/singleflight as it does not handle context cancellation. This also fixes an extremely rare but possible issue with subjective initialization similar to #247 (cc @walldiss). This issue is caused by to be removed syncGetter when used together with recursion. The fix is to move singleflight handling out of sync.Syncer, it should not be responsible for this. This way we fix the issue, without the need for read/write refactoring in sync_head.go we discussed. --------- Co-authored-by: Wondertan <[email protected]> Co-authored-by: IliaBulavintsev <[email protected]>
1 parent 817429e commit 510b17b

File tree

6 files changed

+105
-143
lines changed

6 files changed

+105
-143
lines changed

sync/head_sync.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package sync
2+
3+
import (
4+
"context"
5+
"sync"
6+
7+
"github.com/celestiaorg/go-header"
8+
)
9+
10+
// syncHead is a wrapper around the header.Head that provides single-flight access to the Head.
11+
type syncHead[H header.Header[H]] struct {
12+
headMu sync.Mutex
13+
headCh chan struct{}
14+
head header.Head[H]
15+
16+
resHead H
17+
resErr error
18+
}
19+
20+
func (sh *syncHead[H]) Head(ctx context.Context, opts ...header.HeadOption[H]) (H, error) {
21+
sh.headMu.Lock()
22+
doneCh := sh.headCh
23+
acquired := doneCh == nil
24+
if acquired {
25+
doneCh = make(chan struct{})
26+
sh.headCh = doneCh
27+
}
28+
sh.headMu.Unlock()
29+
30+
if acquired {
31+
head, err := sh.head.Head(ctx, opts...)
32+
33+
sh.headMu.Lock()
34+
sh.resHead, sh.resErr = head, err
35+
sh.headCh = nil
36+
sh.headMu.Unlock()
37+
38+
close(doneCh)
39+
return head, err
40+
}
41+
42+
select {
43+
case <-doneCh:
44+
sh.headMu.Lock()
45+
defer sh.headMu.Unlock()
46+
return sh.resHead, sh.resErr
47+
case <-ctx.Done():
48+
var zero H
49+
return zero, ctx.Err()
50+
}
51+
}

sync/head_sync_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package sync
2+
3+
import (
4+
"context"
5+
"math/rand/v2"
6+
"sync"
7+
"sync/atomic"
8+
"testing"
9+
"time"
10+
11+
"github.com/stretchr/testify/assert"
12+
13+
libhead "github.com/celestiaorg/go-header"
14+
"github.com/celestiaorg/go-header/headertest"
15+
)
16+
17+
func Test_syncHead(t *testing.T) {
18+
counter := &headCounter{}
19+
sh := &syncHead[*headertest.DummyHeader]{head: counter}
20+
21+
callsN := 1000
22+
wg := sync.WaitGroup{}
23+
wg.Add(callsN)
24+
for i := 0; i < callsN; i++ {
25+
go func() {
26+
defer wg.Done()
27+
time.Sleep(time.Duration(rand.IntN(1000)) * time.Microsecond)
28+
_, _ = sh.Head(context.Background())
29+
}()
30+
}
31+
32+
wg.Wait()
33+
34+
assert.Less(t, int(counter.cntr.Load()), callsN/100) // <1% of calls should go through
35+
}
36+
37+
type headCounter struct {
38+
cntr atomic.Int64
39+
}
40+
41+
func (h *headCounter) Head(
42+
ctx context.Context,
43+
h2 ...libhead.HeadOption[*headertest.DummyHeader],
44+
) (*headertest.DummyHeader, error) {
45+
time.Sleep(time.Millisecond * 1)
46+
h.cntr.Add(1)
47+
return &headertest.DummyHeader{}, nil
48+
}

sync/sync.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ var log = logging.Logger("header/sync")
3333
type Syncer[H header.Header[H]] struct {
3434
sub header.Subscriber[H] // to subscribe for new Network Heads
3535
store syncStore[H] // to store all the headers to
36-
getter syncGetter[H] // to fetch headers from
36+
head syncHead[H]
37+
getter header.Getter[H] // to fetch headers from
3738
metrics *metrics
3839

3940
// stateLk protects state which represents the current or latest sync
@@ -80,7 +81,8 @@ func NewSyncer[H header.Header[H]](
8081
return &Syncer[H]{
8182
sub: sub,
8283
store: syncStore[H]{Store: store},
83-
getter: syncGetter[H]{Getter: getter},
84+
head: syncHead[H]{head: getter},
85+
getter: getter,
8486
metrics: metrics,
8587
triggerSync: make(chan struct{}, 1), // should be buffered
8688
Params: &params,

sync/sync_getter.go

Lines changed: 0 additions & 52 deletions
This file was deleted.

sync/sync_getter_test.go

Lines changed: 0 additions & 72 deletions
This file was deleted.

sync/sync_head.go

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,11 @@ func (s *Syncer[H]) Head(ctx context.Context, _ ...header.HeadOption[H]) (H, err
3030
return sbjHead, nil
3131
}
3232

33-
// single-flight protection ensure only one Head is requested at the time
34-
if !s.getter.Lock() {
35-
// means that other routine held the lock and set the subjective head
36-
return s.subjectiveHead(ctx)
37-
}
38-
defer s.getter.Unlock()
39-
4033
s.metrics.outdatedHead(s.ctx)
4134

4235
reqCtx, cancel := context.WithTimeout(ctx, headRequestTimeout)
4336
defer cancel()
44-
netHead, err := s.getter.Head(reqCtx, header.WithTrustedHead[H](sbjHead))
37+
netHead, err := s.head.Head(reqCtx, header.WithTrustedHead[H](sbjHead))
4538
if err != nil {
4639
log.Warnw(
4740
"failed to get recent head, returning current subjective",
@@ -86,16 +79,8 @@ func (s *Syncer[H]) subjectiveHead(ctx context.Context) (H, error) {
8679
}
8780
// otherwise, request head from a trusted peer
8881
log.Infow("stored head header expired", "height", storeHead.Height())
89-
// single-flight protection
90-
// ensure only one Head is requested at the time
91-
if !s.getter.Lock() {
92-
// means that other routine held the lock and set the subjective head for us,
93-
// so just recursively get it
94-
return s.subjectiveHead(ctx)
95-
}
96-
defer s.getter.Unlock()
9782

98-
trustHead, err := s.getter.Head(ctx)
83+
trustHead, err := s.head.Head(ctx)
9984
if err != nil {
10085
return trustHead, err
10186
}

0 commit comments

Comments
 (0)