Skip to content

Commit c13b020

Browse files
committed
Optimize PerasWeightSnapshot
1 parent 4e4ad00 commit c13b020

File tree

1 file changed

+142
-26
lines changed
  • ouroboros-consensus/src/ouroboros-consensus/Ouroboros/Consensus/Peras

1 file changed

+142
-26
lines changed

ouroboros-consensus/src/ouroboros-consensus/Ouroboros/Consensus/Peras/Weight.hs

Lines changed: 142 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
{-# LANGUAGE DataKinds #-}
2+
{-# LANGUAGE DeriveAnyClass #-}
23
{-# LANGUAGE DeriveGeneric #-}
34
{-# LANGUAGE DerivingVia #-}
5+
{-# LANGUAGE DuplicateRecordFields #-}
6+
{-# LANGUAGE FlexibleContexts #-}
47
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
8+
{-# LANGUAGE LambdaCase #-}
9+
{-# LANGUAGE MultiParamTypeClasses #-}
10+
{-# LANGUAGE OverloadedRecordDot #-}
511
{-# LANGUAGE ScopedTypeVariables #-}
612
{-# LANGUAGE TypeOperators #-}
13+
{-# LANGUAGE NoFieldSelectors #-}
714

815
-- | Data structure for tracking the weight of blocks due to Peras boosts.
916
module Ouroboros.Consensus.Peras.Weight
@@ -28,18 +35,29 @@ module Ouroboros.Consensus.Peras.Weight
2835
, weightBoostOfFragment
2936
) where
3037

31-
import Data.Foldable as Foldable (foldl')
32-
import Data.Map.Strict (Map)
33-
import qualified Data.Map.Strict as Map
38+
import Data.FingerTree.Strict (Measured (..), StrictFingerTree)
39+
import qualified Data.FingerTree.Strict as SFT
40+
import Data.Foldable (toList)
41+
import Data.List (sortOn)
3442
import GHC.Generics (Generic)
3543
import NoThunks.Class
3644
import Ouroboros.Consensus.Block
3745
import Ouroboros.Network.AnchoredFragment (AnchoredFragment)
3846
import qualified Ouroboros.Network.AnchoredFragment as AF
3947

4048
-- | Data structure for tracking the weight of blocks due to Peras boosts.
49+
--
50+
-- PRECONDITION: All boosted points tracked by this data structure must reside
51+
-- on a single linear chain, and no boosted point may be an EBB. Otherwise,
52+
-- queries on this data structure may return incorrect results.
53+
--
54+
-- For Peras (assuming an honest majority), this is guaranteed by the voting
55+
-- rules, together with the fact that Peras is not to be used with blocks where
56+
-- EBBs (if they can even exist) may receive boosts.
4157
newtype PerasWeightSnapshot blk = PerasWeightSnapshot
42-
{ getPerasWeightSnapshot :: Map (Point blk) PerasWeight
58+
{ getPerasWeightSnapshot :: StrictFingerTree PWSMeasure (BoostedPoint blk)
59+
-- ^ INVARIANT: The slots of the boosted points are strictly monotonically
60+
-- increasing.
4361
}
4462
deriving stock Eq
4563
deriving Generic
@@ -48,13 +66,56 @@ newtype PerasWeightSnapshot blk = PerasWeightSnapshot
4866
instance StandardHash blk => Show (PerasWeightSnapshot blk) where
4967
show = show . perasWeightSnapshotToList
5068

69+
data PWSMeasure = PWSMeasure
70+
{ slot :: !(WithOrigin SlotNo)
71+
-- ^ The maximum slot of all boosted points.
72+
, weight :: !PerasWeight
73+
-- ^ The sum of all weight boosts.
74+
, size :: !Int
75+
-- ^ The number of boosted points.
76+
}
77+
deriving stock Show
78+
79+
instance Semigroup PWSMeasure where
80+
m0 <> m1 =
81+
PWSMeasure
82+
{ slot = m0.slot `max` m1.slot
83+
, weight = m0.weight <> m1.weight
84+
, size = m0.size + m1.size
85+
}
86+
87+
instance Monoid PWSMeasure where
88+
mempty =
89+
PWSMeasure
90+
{ slot = Origin
91+
, weight = mempty
92+
, size = 0
93+
}
94+
95+
data BoostedPoint blk = BoostedPoint
96+
{ pt :: !(Point blk)
97+
, weight :: !PerasWeight
98+
}
99+
deriving stock (Show, Eq, Generic)
100+
deriving anyclass NoThunks
101+
102+
instance Measured PWSMeasure (BoostedPoint blk) where
103+
measure bp =
104+
PWSMeasure
105+
{ slot = pointSlot bp.pt
106+
, weight = bp.weight
107+
, size = 1
108+
}
109+
51110
-- | An empty 'PerasWeightSnapshot' not containing any boosted blocks.
52111
emptyPerasWeightSnapshot :: PerasWeightSnapshot blk
53-
emptyPerasWeightSnapshot = PerasWeightSnapshot Map.empty
112+
emptyPerasWeightSnapshot = PerasWeightSnapshot SFT.empty
54113

55114
-- | Create a weight snapshot from a list of boosted points with an associated
56115
-- weight. In case of duplicate points, their weights are combined.
57116
--
117+
-- PRECONDITION: The points lie on a singular linear chain.
118+
--
58119
-- >>> :{
59120
-- weights :: [(Point Blk, PerasWeight)]
60121
-- weights =
@@ -68,14 +129,12 @@ emptyPerasWeightSnapshot = PerasWeightSnapshot Map.empty
68129
-- >>> snap = mkPerasWeightSnapshot weights
69130
-- >>> snap
70131
-- [(Origin,PerasWeight 3),(At (Block {blockPointSlot = SlotNo 2, blockPointHash = "foo"}),PerasWeight 4),(At (Block {blockPointSlot = SlotNo 3, blockPointHash = "bar"}),PerasWeight 2)]
71-
mkPerasWeightSnapshot ::
72-
StandardHash blk =>
73-
[(Point blk, PerasWeight)] ->
74-
PerasWeightSnapshot blk
132+
mkPerasWeightSnapshot :: [(Point blk, PerasWeight)] -> PerasWeightSnapshot blk
75133
mkPerasWeightSnapshot =
76-
Foldable.foldl'
77-
(\s (pt, weight) -> addToPerasWeightSnapshot pt weight s)
78-
emptyPerasWeightSnapshot
134+
PerasWeightSnapshot
135+
. SFT.fromList
136+
. fmap (\(pt, w) -> BoostedPoint pt w)
137+
. sortOn (pointSlot . fst)
79138

80139
-- | Return the list of boosted points with their associated weight, sorted
81140
-- based on their point. Does not contain duplicate points.
@@ -94,11 +153,15 @@ mkPerasWeightSnapshot =
94153
-- >>> perasWeightSnapshotToList snap
95154
-- [(Origin,PerasWeight 3),(At (Block {blockPointSlot = SlotNo 2, blockPointHash = "foo"}),PerasWeight 4),(At (Block {blockPointSlot = SlotNo 3, blockPointHash = "bar"}),PerasWeight 2)]
96155
perasWeightSnapshotToList :: PerasWeightSnapshot blk -> [(Point blk, PerasWeight)]
97-
perasWeightSnapshotToList = Map.toAscList . getPerasWeightSnapshot
156+
perasWeightSnapshotToList (PerasWeightSnapshot ft) =
157+
(\(BoostedPoint pt w) -> (pt, w)) <$> toList ft
98158

99159
-- | Add weight for the given point to the 'PerasWeightSnapshot'. If the point
100160
-- already has some weight, it is added on top.
101161
--
162+
-- PRECONDITION: The point must lie on the same linear chain as the points
163+
-- already part of the 'PerasWeightSnapshot'.
164+
--
102165
-- >>> :{
103166
-- weights :: [(Point Blk, PerasWeight)]
104167
-- weights =
@@ -125,7 +188,17 @@ addToPerasWeightSnapshot ::
125188
PerasWeightSnapshot blk ->
126189
PerasWeightSnapshot blk
127190
addToPerasWeightSnapshot pt weight =
128-
PerasWeightSnapshot . Map.insertWith (<>) pt weight . getPerasWeightSnapshot
191+
\(PerasWeightSnapshot ft) ->
192+
let (l, r) = SFT.split (\m -> m.slot > pointSlot pt) ft
193+
in PerasWeightSnapshot $ insert l <> r
194+
where
195+
insert l = case SFT.viewr l of
196+
SFT.EmptyR -> SFT.singleton $ BoostedPoint pt weight
197+
l' SFT.:> BoostedPoint pt' weight'
198+
-- We already track some weight of @pt@.
199+
| pt == pt' -> l' SFT.|> BoostedPoint pt' (weight <> weight')
200+
-- Otherwise, insert @pt@ as a new boosted point.
201+
| otherwise -> l SFT.|> BoostedPoint pt weight
129202

130203
-- | Prune the given 'PerasWeightSnapshot' by removing the weight of all blocks
131204
-- strictly older than the given slot.
@@ -154,11 +227,8 @@ prunePerasWeightSnapshot ::
154227
SlotNo ->
155228
PerasWeightSnapshot blk ->
156229
PerasWeightSnapshot blk
157-
prunePerasWeightSnapshot slot =
158-
PerasWeightSnapshot . Map.dropWhileAntitone isTooOld . getPerasWeightSnapshot
159-
where
160-
isTooOld :: Point blk -> Bool
161-
isTooOld pt = pointSlot pt < NotOrigin slot
230+
prunePerasWeightSnapshot slot (PerasWeightSnapshot ft) =
231+
PerasWeightSnapshot $ SFT.dropUntil (\m -> m.slot >= NotOrigin slot) ft
162232

163233
-- | Get the weight boost for a point, or @'mempty' :: 'PerasWeight'@ otherwise.
164234
--
@@ -183,8 +253,12 @@ weightBoostOfPoint ::
183253
forall blk.
184254
StandardHash blk =>
185255
PerasWeightSnapshot blk -> Point blk -> PerasWeight
186-
weightBoostOfPoint (PerasWeightSnapshot weightByPoint) pt =
187-
Map.findWithDefault mempty pt weightByPoint
256+
weightBoostOfPoint (PerasWeightSnapshot ft) pt =
257+
case SFT.viewr $ SFT.takeUntil (\m -> m.slot > pointSlot pt) ft of
258+
SFT.EmptyR -> mempty
259+
_ SFT.:> BoostedPoint pt' weight'
260+
| pt == pt' -> weight'
261+
| otherwise -> mempty
188262

189263
-- | Get the weight boost for a fragment, ie the sum of all
190264
-- 'weightBoostOfPoint' for all points on the fragment (excluding the anchor).
@@ -230,11 +304,53 @@ weightBoostOfFragment ::
230304
PerasWeightSnapshot blk ->
231305
AnchoredFragment h ->
232306
PerasWeight
233-
weightBoostOfFragment weightSnap frag =
234-
-- TODO think about whether this could be done in sublinear complexity
235-
foldMap
236-
(weightBoostOfPoint weightSnap . castPoint . blockPoint)
237-
(AF.toOldestFirst frag)
307+
weightBoostOfFragment (PerasWeightSnapshot ft) = \case
308+
AF.Empty{} -> mempty
309+
frag@(oldestHdr AF.:< _) -> (measure boostingInfix).weight
310+
where
311+
-- /Not/ @'AF.lastSlot' frag@ as we want to ignore the anchor.
312+
oldestSlot = NotOrigin $ blockSlot oldestHdr
313+
314+
-- The infix of @ft@ which only contains boosted points which are also on
315+
-- @frag@ (via @isOnFrag@).
316+
boostingInfix :: StrictFingerTree PWSMeasure (BoostedPoint blk)
317+
boostingInfix = case SFT.viewr ft' of
318+
SFT.EmptyR -> ft'
319+
t SFT.:> bp
320+
| isOnFrag bp.pt -> ft'
321+
| otherwise -> go 0 (measure ft').size t
322+
where
323+
-- The suffix of @ft@ without boosted points which are too old to be on
324+
-- @frag@.
325+
ft' = SFT.dropUntil (\m -> m.slot >= oldestSlot) ft
326+
327+
-- Binary search on @ft'@ to find the longest prefix of @ft'@ where all
328+
-- boosted points satisfy @isOnFrag@.
329+
--
330+
-- PRECONDITION: @0 <= lb < ub@.
331+
go ::
332+
-- @lb@: All boosted points of the size @lb@ prefix of @ft'@ satisfy
333+
-- @isOnFrag@.
334+
Int ->
335+
-- @ub@: At least one boosted point of the size @ub@ prefix of @ft'@
336+
-- does not satisfy @isOnFrag@.
337+
Int ->
338+
-- The size @ub - 1@ prefix of @ft'@.
339+
StrictFingerTree PWSMeasure (BoostedPoint blk) ->
340+
StrictFingerTree PWSMeasure (BoostedPoint blk)
341+
go lb ub t
342+
| lb == ub - 1 = t
343+
| isOnFrag t'Pt = go mid ub t
344+
| otherwise = go lb mid t'
345+
where
346+
mid = (lb + ub) `div` 2
347+
(t', t'Pt) = case SFT.viewr $ SFT.takeUntil (\m -> m.size > mid) ft' of
348+
t'' SFT.:> bp -> (t'', bp.pt)
349+
-- @ft'@ is non-empty here, and we have @0 <= lb < mid@.
350+
SFT.EmptyR -> error "unreachable"
351+
352+
isOnFrag :: Point blk -> Bool
353+
isOnFrag pt = AF.pointOnFragment (castPoint pt) frag
238354

239355
-- $setup
240356
-- >>> import Ouroboros.Consensus.Block

0 commit comments

Comments
 (0)