Skip to content

Commit 18c93d4

Browse files
committed
Optimize PerasWeightSnapshot
1 parent 3244b19 commit 18c93d4

File tree

1 file changed

+139
-20
lines changed
  • ouroboros-consensus/src/ouroboros-consensus/Ouroboros/Consensus/Peras

1 file changed

+139
-20
lines changed

ouroboros-consensus/src/ouroboros-consensus/Ouroboros/Consensus/Peras/Weight.hs

Lines changed: 139 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
{-# LANGUAGE DataKinds #-}
2+
{-# LANGUAGE DeriveAnyClass #-}
23
{-# LANGUAGE DeriveGeneric #-}
34
{-# LANGUAGE DerivingVia #-}
5+
{-# LANGUAGE DuplicateRecordFields #-}
6+
{-# LANGUAGE FlexibleContexts #-}
47
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
8+
{-# LANGUAGE LambdaCase #-}
9+
{-# LANGUAGE MultiParamTypeClasses #-}
10+
{-# LANGUAGE OverloadedRecordDot #-}
511
{-# LANGUAGE ScopedTypeVariables #-}
612
{-# LANGUAGE TypeOperators #-}
13+
{-# LANGUAGE NoFieldSelectors #-}
714

815
-- | Data structure for tracking the weight of blocks due to Peras boosts.
916
module Ouroboros.Consensus.Peras.Weight
@@ -30,9 +37,9 @@ module Ouroboros.Consensus.Peras.Weight
3037
, takeVolatileSuffix
3138
) where
3239

33-
import Data.Foldable as Foldable (foldl')
34-
import Data.Map.Strict (Map)
35-
import qualified Data.Map.Strict as Map
40+
import Data.FingerTree.Strict (Measured (..), StrictFingerTree)
41+
import qualified Data.FingerTree.Strict as SFT
42+
import Data.Foldable as Foldable (foldl', toList)
3643
import Data.Word (Word64)
3744
import GHC.Generics (Generic)
3845
import NoThunks.Class
@@ -42,8 +49,20 @@ import Ouroboros.Network.AnchoredFragment (AnchoredFragment)
4249
import qualified Ouroboros.Network.AnchoredFragment as AF
4350

4451
-- | Data structure for tracking the weight of blocks due to Peras boosts.
52+
--
53+
-- PRECONDITION: All boosted points tracked by this data structure must reside
54+
-- on a single linear chain, and no boosted point may be an EBB. Otherwise,
55+
-- queries on this data structure may return incorrect results.
56+
--
57+
-- TODO: This isn't true across cooldowns.
58+
--
59+
-- For Peras (assuming an honest majority), this is guaranteed by the voting
60+
-- rules, together with the fact that Peras is not to be used with blocks where
61+
-- EBBs (if they can even exist) may receive boosts.
4562
newtype PerasWeightSnapshot blk = PerasWeightSnapshot
46-
{ getPerasWeightSnapshot :: Map (Point blk) PerasWeight
63+
{ getPerasWeightSnapshot :: StrictFingerTree PWSMeasure (BoostedPoint blk)
64+
-- ^ INVARIANT: The slots of the boosted points are strictly monotonically
65+
-- increasing.
4766
}
4867
deriving stock Eq
4968
deriving Generic
@@ -52,13 +71,56 @@ newtype PerasWeightSnapshot blk = PerasWeightSnapshot
5271
instance StandardHash blk => Show (PerasWeightSnapshot blk) where
5372
show = show . perasWeightSnapshotToList
5473

74+
data PWSMeasure = PWSMeasure
75+
{ slot :: !(WithOrigin SlotNo)
76+
-- ^ The maximum slot of all boosted points.
77+
, weight :: !PerasWeight
78+
-- ^ The sum of all weight boosts.
79+
, size :: !Int
80+
-- ^ The number of boosted points.
81+
}
82+
deriving stock Show
83+
84+
instance Semigroup PWSMeasure where
85+
m0 <> m1 =
86+
PWSMeasure
87+
{ slot = m0.slot `max` m1.slot
88+
, weight = m0.weight <> m1.weight
89+
, size = m0.size + m1.size
90+
}
91+
92+
instance Monoid PWSMeasure where
93+
mempty =
94+
PWSMeasure
95+
{ slot = Origin
96+
, weight = mempty
97+
, size = 0
98+
}
99+
100+
data BoostedPoint blk = BoostedPoint
101+
{ pt :: !(Point blk)
102+
, weight :: !PerasWeight
103+
}
104+
deriving stock (Show, Eq, Generic)
105+
deriving anyclass NoThunks
106+
107+
instance Measured PWSMeasure (BoostedPoint blk) where
108+
measure bp =
109+
PWSMeasure
110+
{ slot = pointSlot bp.pt
111+
, weight = bp.weight
112+
, size = 1
113+
}
114+
55115
-- | An empty 'PerasWeightSnapshot' not containing any boosted blocks.
56116
emptyPerasWeightSnapshot :: PerasWeightSnapshot blk
57-
emptyPerasWeightSnapshot = PerasWeightSnapshot Map.empty
117+
emptyPerasWeightSnapshot = PerasWeightSnapshot SFT.empty
58118

59119
-- | Create a weight snapshot from a list of boosted points with an associated
60120
-- weight. In case of duplicate points, their weights are combined.
61121
--
122+
-- PRECONDITION: The points lie on a singular linear chain.
123+
--
62124
-- >>> :{
63125
-- weights :: [(Point Blk, PerasWeight)]
64126
-- weights =
@@ -98,11 +160,15 @@ mkPerasWeightSnapshot =
98160
-- >>> perasWeightSnapshotToList snap
99161
-- [(Point Origin,PerasWeight 3),(Point (At (Block {blockPointSlot = SlotNo 2, blockPointHash = "foo"})),PerasWeight 4),(Point (At (Block {blockPointSlot = SlotNo 3, blockPointHash = "bar"})),PerasWeight 2)]
100162
perasWeightSnapshotToList :: PerasWeightSnapshot blk -> [(Point blk, PerasWeight)]
101-
perasWeightSnapshotToList = Map.toAscList . getPerasWeightSnapshot
163+
perasWeightSnapshotToList (PerasWeightSnapshot ft) =
164+
(\(BoostedPoint pt w) -> (pt, w)) <$> toList ft
102165

103166
-- | Add weight for the given point to the 'PerasWeightSnapshot'. If the point
104167
-- already has some weight, it is added on top.
105168
--
169+
-- PRECONDITION: The point must lie on the same linear chain as the points
170+
-- already part of the 'PerasWeightSnapshot'.
171+
--
106172
-- >>> :{
107173
-- weights :: [(Point Blk, PerasWeight)]
108174
-- weights =
@@ -129,7 +195,17 @@ addToPerasWeightSnapshot ::
129195
PerasWeightSnapshot blk ->
130196
PerasWeightSnapshot blk
131197
addToPerasWeightSnapshot pt weight =
132-
PerasWeightSnapshot . Map.insertWith (<>) pt weight . getPerasWeightSnapshot
198+
\(PerasWeightSnapshot ft) ->
199+
let (l, r) = SFT.split (\m -> m.slot > pointSlot pt) ft
200+
in PerasWeightSnapshot $ insert l <> r
201+
where
202+
insert l = case SFT.viewr l of
203+
SFT.EmptyR -> SFT.singleton $ BoostedPoint pt weight
204+
l' SFT.:> BoostedPoint pt' weight'
205+
-- We already track some weight of @pt@.
206+
| pt == pt' -> l' SFT.|> BoostedPoint pt' (weight <> weight')
207+
-- Otherwise, insert @pt@ as a new boosted point.
208+
| otherwise -> l SFT.|> BoostedPoint pt weight
133209

134210
-- | Prune the given 'PerasWeightSnapshot' by removing the weight of all blocks
135211
-- strictly older than the given slot.
@@ -158,11 +234,8 @@ prunePerasWeightSnapshot ::
158234
SlotNo ->
159235
PerasWeightSnapshot blk ->
160236
PerasWeightSnapshot blk
161-
prunePerasWeightSnapshot slot =
162-
PerasWeightSnapshot . Map.dropWhileAntitone isTooOld . getPerasWeightSnapshot
163-
where
164-
isTooOld :: Point blk -> Bool
165-
isTooOld pt = pointSlot pt < NotOrigin slot
237+
prunePerasWeightSnapshot slot (PerasWeightSnapshot ft) =
238+
PerasWeightSnapshot $ SFT.dropUntil (\m -> m.slot >= NotOrigin slot) ft
166239

167240
-- | Get the weight boost for a point, or @'mempty' :: 'PerasWeight'@ otherwise.
168241
--
@@ -187,8 +260,12 @@ weightBoostOfPoint ::
187260
forall blk.
188261
StandardHash blk =>
189262
PerasWeightSnapshot blk -> Point blk -> PerasWeight
190-
weightBoostOfPoint (PerasWeightSnapshot weightByPoint) pt =
191-
Map.findWithDefault mempty pt weightByPoint
263+
weightBoostOfPoint (PerasWeightSnapshot ft) pt =
264+
case SFT.viewr $ SFT.takeUntil (\m -> m.slot > pointSlot pt) ft of
265+
SFT.EmptyR -> mempty
266+
_ SFT.:> BoostedPoint pt' weight'
267+
| pt == pt' -> weight'
268+
| otherwise -> mempty
192269

193270
-- | Get the weight boost for a fragment, ie the sum of all
194271
-- 'weightBoostOfPoint' for all points on the fragment (excluding the anchor).
@@ -234,11 +311,53 @@ weightBoostOfFragment ::
234311
PerasWeightSnapshot blk ->
235312
AnchoredFragment h ->
236313
PerasWeight
237-
weightBoostOfFragment weightSnap frag =
238-
-- TODO think about whether this could be done in sublinear complexity
239-
foldMap
240-
(weightBoostOfPoint weightSnap . castPoint . blockPoint)
241-
(AF.toOldestFirst frag)
314+
weightBoostOfFragment (PerasWeightSnapshot ft) = \case
315+
AF.Empty{} -> mempty
316+
frag@(oldestHdr AF.:< _) -> (measure boostingInfix).weight
317+
where
318+
-- /Not/ @'AF.lastSlot' frag@ as we want to ignore the anchor.
319+
oldestSlot = NotOrigin $ blockSlot oldestHdr
320+
321+
-- The infix of @ft@ which only contains boosted points which are also on
322+
-- @frag@ (via @isOnFrag@).
323+
boostingInfix :: StrictFingerTree PWSMeasure (BoostedPoint blk)
324+
boostingInfix = case SFT.viewr ft' of
325+
SFT.EmptyR -> ft'
326+
t SFT.:> bp
327+
| isOnFrag bp.pt -> ft'
328+
| otherwise -> go 0 (measure ft').size t
329+
where
330+
-- The suffix of @ft@ without boosted points which are too old to be on
331+
-- @frag@.
332+
ft' = SFT.dropUntil (\m -> m.slot >= oldestSlot) ft
333+
334+
-- Binary search on @ft'@ to find the longest prefix of @ft'@ where all
335+
-- boosted points satisfy @isOnFrag@.
336+
--
337+
-- PRECONDITION: @0 <= lb < ub@.
338+
go ::
339+
-- @lb@: All boosted points of the size @lb@ prefix of @ft'@ satisfy
340+
-- @isOnFrag@.
341+
Int ->
342+
-- @ub@: At least one boosted point of the size @ub@ prefix of @ft'@
343+
-- does not satisfy @isOnFrag@.
344+
Int ->
345+
-- The size @ub - 1@ prefix of @ft'@.
346+
StrictFingerTree PWSMeasure (BoostedPoint blk) ->
347+
StrictFingerTree PWSMeasure (BoostedPoint blk)
348+
go lb ub t
349+
| lb == ub - 1 = t
350+
| isOnFrag t'Pt = go mid ub t
351+
| otherwise = go lb mid t'
352+
where
353+
mid = (lb + ub) `div` 2
354+
(t', t'Pt) = case SFT.viewr $ SFT.takeUntil (\m -> m.size > mid) ft' of
355+
t'' SFT.:> bp -> (t'', bp.pt)
356+
-- @ft'@ is non-empty here, and we have @0 <= lb < mid@.
357+
SFT.EmptyR -> error "unreachable"
358+
359+
isOnFrag :: Point blk -> Bool
360+
isOnFrag pt = AF.pointOnFragment (castPoint pt) frag
242361

243362
-- | Get the total weight for a fragment, ie the length plus the weight boost
244363
-- ('weightBoostOfFragment') of the fragment.
@@ -339,7 +458,7 @@ takeVolatileSuffix ::
339458
AnchoredFragment h ->
340459
AnchoredFragment h
341460
takeVolatileSuffix snap secParam frag
342-
| Map.null $ getPerasWeightSnapshot snap =
461+
| SFT.null snap.getPerasWeightSnapshot =
343462
-- Optimize the case where Peras is disabled.
344463
AF.anchorNewest (unPerasWeight k) frag
345464
| hasAtMostWeightK frag = frag

0 commit comments

Comments
 (0)