Skip to content

Commit 9544349

Browse files
committed
Optimize PerasWeightSnapshot
1 parent 4e4ad00 commit 9544349

File tree

1 file changed

+136
-19
lines changed
  • ouroboros-consensus/src/ouroboros-consensus/Ouroboros/Consensus/Peras

1 file changed

+136
-19
lines changed

ouroboros-consensus/src/ouroboros-consensus/Ouroboros/Consensus/Peras/Weight.hs

Lines changed: 136 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
{-# LANGUAGE DataKinds #-}
2+
{-# LANGUAGE DeriveAnyClass #-}
23
{-# LANGUAGE DeriveGeneric #-}
34
{-# LANGUAGE DerivingVia #-}
5+
{-# LANGUAGE DuplicateRecordFields #-}
6+
{-# LANGUAGE FlexibleContexts #-}
47
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
8+
{-# LANGUAGE LambdaCase #-}
9+
{-# LANGUAGE MultiParamTypeClasses #-}
10+
{-# LANGUAGE OverloadedRecordDot #-}
511
{-# LANGUAGE ScopedTypeVariables #-}
612
{-# LANGUAGE TypeOperators #-}
13+
{-# LANGUAGE NoFieldSelectors #-}
714

815
-- | Data structure for tracking the weight of blocks due to Peras boosts.
916
module Ouroboros.Consensus.Peras.Weight
@@ -28,18 +35,28 @@ module Ouroboros.Consensus.Peras.Weight
2835
, weightBoostOfFragment
2936
) where
3037

31-
import Data.Foldable as Foldable (foldl')
32-
import Data.Map.Strict (Map)
33-
import qualified Data.Map.Strict as Map
38+
import Data.FingerTree.Strict (Measured (..), StrictFingerTree)
39+
import qualified Data.FingerTree.Strict as SFT
40+
import Data.Foldable as Foldable (foldl', toList)
3441
import GHC.Generics (Generic)
3542
import NoThunks.Class
3643
import Ouroboros.Consensus.Block
3744
import Ouroboros.Network.AnchoredFragment (AnchoredFragment)
3845
import qualified Ouroboros.Network.AnchoredFragment as AF
3946

4047
-- | Data structure for tracking the weight of blocks due to Peras boosts.
48+
--
49+
-- PRECONDITION: All boosted points tracked by this data structure must reside
50+
-- on a single linear chain, and no boosted point may be an EBB. Otherwise,
51+
-- queries on this data structure may return incorrect results.
52+
--
53+
-- For Peras (assuming an honest majority), this is guaranteed by the voting
54+
-- rules, together with the fact that Peras is not to be used with blocks where
55+
-- EBBs (if they can even exist) may receive boosts.
4156
newtype PerasWeightSnapshot blk = PerasWeightSnapshot
42-
{ getPerasWeightSnapshot :: Map (Point blk) PerasWeight
57+
{ getPerasWeightSnapshot :: StrictFingerTree PWSMeasure (BoostedPoint blk)
58+
-- ^ INVARIANT: The slots of the boosted points are strictly monotonically
59+
-- increasing.
4360
}
4461
deriving stock Eq
4562
deriving Generic
@@ -48,13 +65,56 @@ newtype PerasWeightSnapshot blk = PerasWeightSnapshot
4865
instance StandardHash blk => Show (PerasWeightSnapshot blk) where
4966
show = show . perasWeightSnapshotToList
5067

68+
data PWSMeasure = PWSMeasure
69+
{ slot :: !(WithOrigin SlotNo)
70+
-- ^ The maximum slot of all boosted points.
71+
, weight :: !PerasWeight
72+
-- ^ The sum of all weight boosts.
73+
, size :: !Int
74+
-- ^ The number of boosted points.
75+
}
76+
deriving stock Show
77+
78+
instance Semigroup PWSMeasure where
79+
m0 <> m1 =
80+
PWSMeasure
81+
{ slot = m0.slot `max` m1.slot
82+
, weight = m0.weight <> m1.weight
83+
, size = m0.size + m1.size
84+
}
85+
86+
instance Monoid PWSMeasure where
87+
mempty =
88+
PWSMeasure
89+
{ slot = Origin
90+
, weight = mempty
91+
, size = 0
92+
}
93+
94+
data BoostedPoint blk = BoostedPoint
95+
{ pt :: !(Point blk)
96+
, weight :: !PerasWeight
97+
}
98+
deriving stock (Show, Eq, Generic)
99+
deriving anyclass NoThunks
100+
101+
instance Measured PWSMeasure (BoostedPoint blk) where
102+
measure bp =
103+
PWSMeasure
104+
{ slot = pointSlot bp.pt
105+
, weight = bp.weight
106+
, size = 1
107+
}
108+
51109
-- | An empty 'PerasWeightSnapshot' not containing any boosted blocks.
52110
emptyPerasWeightSnapshot :: PerasWeightSnapshot blk
53-
emptyPerasWeightSnapshot = PerasWeightSnapshot Map.empty
111+
emptyPerasWeightSnapshot = PerasWeightSnapshot SFT.empty
54112

55113
-- | Create a weight snapshot from a list of boosted points with an associated
56114
-- weight. In case of duplicate points, their weights are combined.
57115
--
116+
-- PRECONDITION: The points lie on a singular linear chain.
117+
--
58118
-- >>> :{
59119
-- weights :: [(Point Blk, PerasWeight)]
60120
-- weights =
@@ -94,11 +154,15 @@ mkPerasWeightSnapshot =
94154
-- >>> perasWeightSnapshotToList snap
95155
-- [(Origin,PerasWeight 3),(At (Block {blockPointSlot = SlotNo 2, blockPointHash = "foo"}),PerasWeight 4),(At (Block {blockPointSlot = SlotNo 3, blockPointHash = "bar"}),PerasWeight 2)]
96156
perasWeightSnapshotToList :: PerasWeightSnapshot blk -> [(Point blk, PerasWeight)]
97-
perasWeightSnapshotToList = Map.toAscList . getPerasWeightSnapshot
157+
perasWeightSnapshotToList (PerasWeightSnapshot ft) =
158+
(\(BoostedPoint pt w) -> (pt, w)) <$> toList ft
98159

99160
-- | Add weight for the given point to the 'PerasWeightSnapshot'. If the point
100161
-- already has some weight, it is added on top.
101162
--
163+
-- PRECONDITION: The point must lie on the same linear chain as the points
164+
-- already part of the 'PerasWeightSnapshot'.
165+
--
102166
-- >>> :{
103167
-- weights :: [(Point Blk, PerasWeight)]
104168
-- weights =
@@ -125,7 +189,17 @@ addToPerasWeightSnapshot ::
125189
PerasWeightSnapshot blk ->
126190
PerasWeightSnapshot blk
127191
addToPerasWeightSnapshot pt weight =
128-
PerasWeightSnapshot . Map.insertWith (<>) pt weight . getPerasWeightSnapshot
192+
\(PerasWeightSnapshot ft) ->
193+
let (l, r) = SFT.split (\m -> m.slot > pointSlot pt) ft
194+
in PerasWeightSnapshot $ insert l <> r
195+
where
196+
insert l = case SFT.viewr l of
197+
SFT.EmptyR -> SFT.singleton $ BoostedPoint pt weight
198+
l' SFT.:> BoostedPoint pt' weight'
199+
-- We already track some weight of @pt@.
200+
| pt == pt' -> l' SFT.|> BoostedPoint pt' (weight <> weight')
201+
-- Otherwise, insert @pt@ as a new boosted point.
202+
| otherwise -> l SFT.|> BoostedPoint pt weight
129203

130204
-- | Prune the given 'PerasWeightSnapshot' by removing the weight of all blocks
131205
-- strictly older than the given slot.
@@ -154,11 +228,8 @@ prunePerasWeightSnapshot ::
154228
SlotNo ->
155229
PerasWeightSnapshot blk ->
156230
PerasWeightSnapshot blk
157-
prunePerasWeightSnapshot slot =
158-
PerasWeightSnapshot . Map.dropWhileAntitone isTooOld . getPerasWeightSnapshot
159-
where
160-
isTooOld :: Point blk -> Bool
161-
isTooOld pt = pointSlot pt < NotOrigin slot
231+
prunePerasWeightSnapshot slot (PerasWeightSnapshot ft) =
232+
PerasWeightSnapshot $ SFT.dropUntil (\m -> m.slot >= NotOrigin slot) ft
162233

163234
-- | Get the weight boost for a point, or @'mempty' :: 'PerasWeight'@ otherwise.
164235
--
@@ -183,8 +254,12 @@ weightBoostOfPoint ::
183254
forall blk.
184255
StandardHash blk =>
185256
PerasWeightSnapshot blk -> Point blk -> PerasWeight
186-
weightBoostOfPoint (PerasWeightSnapshot weightByPoint) pt =
187-
Map.findWithDefault mempty pt weightByPoint
257+
weightBoostOfPoint (PerasWeightSnapshot ft) pt =
258+
case SFT.viewr $ SFT.takeUntil (\m -> m.slot > pointSlot pt) ft of
259+
SFT.EmptyR -> mempty
260+
_ SFT.:> BoostedPoint pt' weight'
261+
| pt == pt' -> weight'
262+
| otherwise -> mempty
188263

189264
-- | Get the weight boost for a fragment, ie the sum of all
190265
-- 'weightBoostOfPoint' for all points on the fragment (excluding the anchor).
@@ -230,11 +305,53 @@ weightBoostOfFragment ::
230305
PerasWeightSnapshot blk ->
231306
AnchoredFragment h ->
232307
PerasWeight
233-
weightBoostOfFragment weightSnap frag =
234-
-- TODO think about whether this could be done in sublinear complexity
235-
foldMap
236-
(weightBoostOfPoint weightSnap . castPoint . blockPoint)
237-
(AF.toOldestFirst frag)
308+
weightBoostOfFragment (PerasWeightSnapshot ft) = \case
309+
AF.Empty{} -> mempty
310+
frag@(oldestHdr AF.:< _) -> (measure boostingInfix).weight
311+
where
312+
-- /Not/ @'AF.lastSlot' frag@ as we want to ignore the anchor.
313+
oldestSlot = NotOrigin $ blockSlot oldestHdr
314+
315+
-- The infix of @ft@ which only contains boosted points which are also on
316+
-- @frag@ (via @isOnFrag@).
317+
boostingInfix :: StrictFingerTree PWSMeasure (BoostedPoint blk)
318+
boostingInfix = case SFT.viewr ft' of
319+
SFT.EmptyR -> ft'
320+
t SFT.:> bp
321+
| isOnFrag bp.pt -> ft'
322+
| otherwise -> go 0 (measure ft').size t
323+
where
324+
-- The suffix of @ft@ without boosted points which are too old to be on
325+
-- @frag@.
326+
ft' = SFT.dropUntil (\m -> m.slot >= oldestSlot) ft
327+
328+
-- Binary search on @ft'@ to find the longest prefix of @ft'@ where all
329+
-- boosted points satisfy @isOnFrag@.
330+
--
331+
-- PRECONDITION: @0 <= lb < ub@.
332+
go ::
333+
-- @lb@: All boosted points of the size @lb@ prefix of @ft'@ satisfy
334+
-- @isOnFrag@.
335+
Int ->
336+
-- @ub@: At least one boosted point of the size @ub@ prefix of @ft'@
337+
-- does not satisfy @isOnFrag@.
338+
Int ->
339+
-- The size @ub - 1@ prefix of @ft'@.
340+
StrictFingerTree PWSMeasure (BoostedPoint blk) ->
341+
StrictFingerTree PWSMeasure (BoostedPoint blk)
342+
go lb ub t
343+
| lb == ub - 1 = t
344+
| isOnFrag t'Pt = go mid ub t
345+
| otherwise = go lb mid t'
346+
where
347+
mid = (lb + ub) `div` 2
348+
(t', t'Pt) = case SFT.viewr $ SFT.takeUntil (\m -> m.size > mid) ft' of
349+
t'' SFT.:> bp -> (t'', bp.pt)
350+
-- @ft'@ is non-empty here, and we have @0 <= lb < mid@.
351+
SFT.EmptyR -> error "unreachable"
352+
353+
isOnFrag :: Point blk -> Bool
354+
isOnFrag pt = AF.pointOnFragment (castPoint pt) frag
238355

239356
-- $setup
240357
-- >>> import Ouroboros.Consensus.Block

0 commit comments

Comments
 (0)