diff --git a/epoch.go b/epoch.go index 43184b09..868f8e66 100644 --- a/epoch.go +++ b/epoch.go @@ -214,9 +214,6 @@ func (e *Epoch) init() error { func (e *Epoch) initOldestNotFinalizedNotarization() { rebroadcastFinalizationVotes := func() { - e.lock.Lock() - defer e.lock.Unlock() - if err := e.rebroadcastPastFinalizeVotes(); err != nil { e.Logger.Error("Could not rebroadcast past finalization votes", zap.Error(err)) } @@ -224,13 +221,13 @@ func (e *Epoch) initOldestNotFinalizedNotarization() { e.oldestNotFinalizedNotarization = NewNotarizationTime( e.FinalizeRebroadcastTimeout, e.haveNotFinalizedNotarizedRound, - rebroadcastFinalizationVotes, e.getRound) + rebroadcastFinalizationVotes, + e.getRound, + &e.lock, + ) } func (e *Epoch) getRound() uint64 { - e.lock.Lock() - defer e.lock.Unlock() - return e.round } @@ -2869,9 +2866,6 @@ func (e *Epoch) locateQuorumRecordByRound(targetRound uint64) *VerifiedQuorumRou } func (e *Epoch) haveNotFinalizedNotarizedRound() (uint64, bool) { - e.lock.Lock() - defer e.lock.Unlock() - var minRoundNum uint64 var found bool for _, round := range e.rounds { diff --git a/util.go b/util.go index f47fb0a3..4a1ff8a4 100644 --- a/util.go +++ b/util.go @@ -279,6 +279,9 @@ type NotarizationTime struct { latestRound uint64 lastRebroadcastTime time.Time oldestNotFinalizedRound uint64 + + // epoch lock + lock *sync.Mutex } func NewNotarizationTime( @@ -286,6 +289,7 @@ func NewNotarizationTime( haveUnFinalizedNotarization func() (uint64, bool), rebroadcastFinalizationVotes func(), getRound func() uint64, + lock *sync.Mutex, ) NotarizationTime { return NotarizationTime{ finalizeVoteRebroadcastTimeout: finalizeVoteRebroadcastTimeout, @@ -293,10 +297,14 @@ func NewNotarizationTime( rebroadcastFinalizationVotes: rebroadcastFinalizationVotes, getRound: getRound, checkInterval: finalizeVoteRebroadcastTimeout / 3, + lock: lock, } } func (nt *NotarizationTime) CheckForNotFinalizedNotarizedBlocks(now time.Time) { + nt.lock.Lock() + defer nt.lock.Unlock() + // If we have recently checked, don't check again if !nt.lastSampleTime.IsZero() && nt.lastSampleTime.Add(nt.checkInterval).After(now) { return @@ -305,7 +313,6 @@ func (nt *NotarizationTime) CheckForNotFinalizedNotarizedBlocks(now time.Time) { nt.lastSampleTime = now round := nt.getRound() - // As long as we make some progress, we don't check for a round not finalized. if round > nt.latestRound { nt.latestRound = round diff --git a/util_test.go b/util_test.go index a9340f71..609bcf5a 100644 --- a/util_test.go +++ b/util_test.go @@ -6,6 +6,7 @@ package simplex_test import ( "context" "fmt" + "sync" "testing" "time" @@ -400,13 +401,16 @@ func TestNotarizationTime(t *testing.T) { rebroadcastFinalizationVotes := func() { invoked++ } + lock := &sync.Mutex{} nt := NewNotarizationTime( defaultFinalizeVoteRebroadcastTimeout, haveNotFinalizedRound, rebroadcastFinalizationVotes, func() uint64 { return round - }) + }, + lock, + ) // First call should set the time and the round. have = true