Skip to content

Commit 6e9767b

Browse files
committed
Add slidingWindowSum
`slidingWindow` typically helps write only functions with complexity around `O(n*k)`, where `n` is the number of elements in the stream and `k` is the size of the window. In many cases, this can be reduced to `O(n)` by looking not at the window itself but instead the sum of that window in some `Semigroup`. This can be used, for example, to implement moving averages such as arithmetic, geometric, or harmonic means.
1 parent d2685a5 commit 6e9767b

File tree

4 files changed

+294
-5
lines changed

4 files changed

+294
-5
lines changed

src/Data/AnnotatedQueue.hs

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
{-# language FunctionalDependencies, ScopedTypeVariables, FlexibleInstances,
2+
BangPatterns, UndecidableInstances #-}
3+
4+
-- | An implementation of Okasaki's implicit queues holding elements of some
5+
-- semigroup. We track the sum of them all. This structure is designed to
6+
-- support efficient *sliding window* algorithms for streams.
7+
--
8+
-- References:
9+
--
10+
-- Hinze, Ralf & Paterson, Ross. (2006). Finger trees: A simple general-purpose
11+
-- data structure. J. Funct. Program.. 16. 197-217. 10.1017/S0956796805005769.
12+
--
13+
-- Okasaki, C. (1998). Purely Functional Data Structures. Cambridge: Cambridge
14+
-- University Press. doi:10.1017/CBO9780511530104
15+
16+
module Data.AnnotatedQueue
17+
( Queue
18+
, ViewL (..)
19+
, empty
20+
, viewl
21+
, drop1
22+
, singleton
23+
, snoc
24+
, measure
25+
) where
26+
27+
import Data.Semigroup (Semigroup (..))
28+
29+
data FDigit a = FOne !a | FTwo !a !a
30+
data RDigit a = RZero | ROne !a
31+
data Node s a = Node !s !a !a
32+
33+
newtype Queue s = Queue (Tree s (Elem s))
34+
instance Semigroup s => Semigroup (Queue s) where
35+
(!t) <> u = case viewl u of
36+
EmptyL -> t
37+
ViewL x xs -> (t `snoc` x) <> xs
38+
instance Semigroup s => Monoid (Queue s) where
39+
mempty = empty
40+
mappend = (<>)
41+
42+
newtype Elem a = Elem a
43+
44+
-- Debit invariant (Okasaki): the middle tree of
45+
-- a Deep node is allowed |pr| - |sf| debits, where
46+
-- pr is the prefix and sf is the suffix.
47+
data Tree s a
48+
= Zero
49+
| One !a
50+
| Two !a !a
51+
| Deep !s !(FDigit a) (Tree s (Node s a)) !(RDigit a)
52+
53+
empty :: Queue s
54+
empty = Queue Zero
55+
56+
singleton :: s -> Queue s
57+
singleton = Queue . One . Elem
58+
59+
snoc :: Semigroup s => Queue s -> s -> Queue s
60+
snoc (Queue t) s = Queue (snocTree t (Elem s))
61+
{-# INLINABLE snoc #-}
62+
63+
measure :: Semigroup s => Queue s -> Maybe s
64+
measure (Queue q) = case q of
65+
Zero -> Nothing
66+
One a -> Just (measure_ a)
67+
Two a b -> Just (measure_ a <> measure_ b)
68+
Deep s _ _ _ -> Just s
69+
{-# INLINABLE measure #-}
70+
71+
class Measurable s a | a -> s where
72+
measure_ :: a -> s
73+
instance Measurable s (Elem s) where
74+
measure_ (Elem x) = x
75+
instance Measurable s (Node s a) where
76+
measure_ (Node s _ _) = s
77+
instance (Semigroup s, Measurable s a) => Measurable s (FDigit a) where
78+
measure_ (FOne a) = measure_ a
79+
measure_ (FTwo a b) = measure_ a <> measure_ b
80+
81+
class SemiMeasurable s a | a -> s where
82+
semimeasure :: s -> a -> s
83+
instance (Semigroup s, Measurable s a) => SemiMeasurable s (RDigit a) where
84+
semimeasure s RZero = s
85+
semimeasure s (ROne a) = s <> measure_ a
86+
instance (Semigroup s, Measurable s a)
87+
=> SemiMeasurable s (Tree s a) where
88+
semimeasure s Zero = s
89+
semimeasure s (One a) = s <> measure_ a
90+
semimeasure s (Two a b) = s <> measure_ a <> measure_ b
91+
semimeasure s (Deep t _ _ _) = s <> t
92+
93+
node
94+
:: (Semigroup s, Measurable s a)
95+
=> a -> a -> Node s a
96+
node a b = Node (measure_ a <> measure_ b) a b
97+
{-# INLINABLE node #-}
98+
99+
deep :: (Semigroup s, Measurable s a) => FDigit a -> Tree s (Node s a) -> RDigit a -> Tree s a
100+
deep pr m sf = Deep (measure_ pr `semimeasure` m `semimeasure` sf) pr m sf
101+
{-# INLINABLE deep #-}
102+
103+
snocTree :: (Measurable s a, Semigroup s) => Tree s a -> a -> Tree s a
104+
-- Note: in the last case we depart slightly from Okasaki. Following Hinze
105+
-- and Paterson, we force the *old* middle immediately to prevent a chain of
106+
-- thunks from accumulating in case of multiple sequential snocs.
107+
snocTree Zero a = One a
108+
snocTree (One a) b = Two a b
109+
snocTree (Two a b) c = Deep (measure_ a <> measure_ b <> measure_ c) (FTwo a b) Zero (ROne c)
110+
snocTree (Deep s pr m RZero) q = Deep (s <> measure_ q) pr m (ROne q)
111+
snocTree (Deep s pr !m (ROne p)) !q
112+
= Deep (s <> measure_ q) pr (snocTree m (node p q)) RZero
113+
{-# INLINABLE snocTree #-}
114+
115+
{-
116+
Theorem: snocTree runs in O(1) amortized time.
117+
118+
Proof:
119+
120+
We show that snocTree costs at most 2 units of work.
121+
122+
Reminder: The debit invariant allows the middle tree of a Deep
123+
node |pr| - |sf| debits.
124+
125+
The first three cases are trivial as they don't have any
126+
debits in their inputs or outputs.
127+
128+
In the fourth case (Deep s pr m RZero), the debit allowance on `m` drops by 1.
129+
We do 1 unit of unshared work and pay off one debit on `m`, for a total of 2
130+
units of work.
131+
132+
In the last case (Deep s pr m (ROne p)), we have two possibilities, depending
133+
on the prefix:
134+
135+
1. The prefix has one element. Then the debit allowance on `m` is 0. We force
136+
`m` (for free). We do 1 unit of unshared work. We create a suspension for the
137+
recursive call and place 2 debits on it to pay for that. Since the debit
138+
allowance for the result middle only allows 1 debit, we pay one of them off
139+
now. So the amortized cost is 2.
140+
141+
2. The prefix has two elements. Then the debit allowance on `m` is 1. We pay
142+
off that debit and force `m`. We do 1 unit of unshared work. We create a
143+
suspension for the recursive call and place 2 debits on it. This is within the
144+
debit allowance for the result middle. So the amortized cost is 2.
145+
-}
146+
147+
data ViewL s = EmptyL | ViewL !s (Queue s)
148+
149+
-- Note: we need the ViewLTree constructor to be lazy in the
150+
-- tail to maintain the right amortized bounds. We include
151+
-- the measure of a nonempty tree in its view because we
152+
-- need that in the recursive case of viewlTree.
153+
data ViewLTree s a = EmptyLTree | ViewLTree !s !a (Tree s a)
154+
155+
viewl :: Semigroup s => Queue s -> ViewL s
156+
-- We could write a separate version for this top layer to avoid unnecessarily
157+
-- calculating a sum in the Two case.
158+
viewl (Queue q) = case viewlTree q of
159+
EmptyLTree -> EmptyL
160+
ViewLTree _ (Elem s) q' -> ViewL s (Queue q')
161+
{-# INLINABLE viewl #-}
162+
163+
viewlTree :: (Semigroup s, Measurable s a) => Tree s a -> ViewLTree s a
164+
-- Important note: we produce the head before forcing the tail. This
165+
-- is key to maintaining O(1) amortized time here.
166+
viewlTree Zero = EmptyLTree
167+
viewlTree (One a) = ViewLTree (measure_ a) a Zero
168+
viewlTree (Two a b) = ViewLTree (measure_ a <> measure_ b) a (One b)
169+
viewlTree (Deep s (FTwo a b) m sf) = ViewLTree s a (deep (FOne b) m sf)
170+
viewlTree (Deep s (FOne a) m sf) = ViewLTree s a $ case viewlTree m of
171+
EmptyLTree -> case sf of
172+
RZero -> Zero
173+
ROne b -> One b
174+
ViewLTree sm (Node p b c) m' -> Deep (sm `semimeasure` sf) (FTwo b c) m' sf
175+
{-# INLINABLE viewlTree #-}
176+
177+
{-
178+
Theorem: drop1 runs in O(1) amortized time.
179+
180+
Proof. We follow the general outline of Okasaki Theorem 11.1, adjusting for the
181+
need to measure (and therefore force) certain suspended middle trees in the
182+
fourth case.
183+
184+
The short version: everything is the same as in Okasaki, but if the recursive
185+
viewing reaches an FOne digit, we need to discharge up to two debits on the
186+
tree middle there, adding just a constant amount to the amortized cost of
187+
the operation.
188+
189+
The long version, in lots of detail:
190+
191+
This particular proof doesn't make use of the "debit passing" concept, because
192+
we seem to be able to get away without it. We will analyze `drop1` as taking 3
193+
units of work. When reading this proof, it may be helpful to mentally imagine
194+
breaking down `viewlTree` into `headTree` and `drop1Tree`, much like Okasaki
195+
does.
196+
197+
The first three cases are trivial, with no debits on inputs or outputs, so we
198+
can assign them each a cost of 1.
199+
200+
In the fourth case (an FTwo digit), we may have up to 2 debits on `m` we must
201+
discharge so we can measure it in `deep`, plus 1 unit of unshared work, for
202+
a total of 3.
203+
204+
In the fifth case (an FOne digit), we have two possibilities:
205+
206+
The suffix is RZero: We may have up to 1 debit on `m`, which we discharge to
207+
view it. We do 1 unit of unshared work. If `m` is nonempty, we create a
208+
suspension to take its tail `m'`, and by the inductive hypothesis create 3
209+
debits to cover that. We place two of them on `m'` and discharge the third. So
210+
the amortized cost is 3.
211+
212+
The suffix is ROne: There are no debits on `m`, so we can view it immediately.
213+
We do one unit of unshared work. If `m` is nonempty, we create a suspension to
214+
take its tail `m'`, and create 3 debits to cover that. We place one debit on
215+
`m'` and discharge the other two. The amortized cost is 3.
216+
-}
217+
218+
drop1 :: Semigroup s => Queue s -> Queue s
219+
drop1 q = case viewl q of
220+
EmptyL -> empty
221+
ViewL _ q' -> q'
222+
{-
223+
-- We could expand out the upper layer to avoid an unnecessary view allocation.
224+
-- Is that worth the extra code size?
225+
-}
226+
{-# INLINABLE drop1 #-}

src/Streaming/Prelude.hs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ module Streaming.Prelude (
134134
, show
135135
, cons
136136
, slidingWindow
137+
, slidingWindowSum
137138
, slidingWindowMin
138139
, slidingWindowMinBy
139140
, slidingWindowMinOn
@@ -272,8 +273,10 @@ import Data.Functor.Of
272273
import Data.Functor.Sum
273274
import Data.Monoid (Monoid (mappend, mempty))
274275
import Data.Ord (Ordering (..), comparing)
276+
import Data.Semigroup (Semigroup (..))
275277
import Foreign.C.Error (Errno(Errno), ePIPE)
276278
import Text.Read (readMaybe)
279+
import qualified Data.AnnotatedQueue as AQ
277280
import qualified Data.Foldable as Foldable
278281
import qualified Data.IntSet as IntSet
279282
import qualified Data.Sequence as Seq
@@ -2846,7 +2849,7 @@ mapMaybe phi = loop where
28462849
{-# INLINABLE mapMaybe #-}
28472850

28482851
{-| 'slidingWindow' accumulates the first @n@ elements of a stream,
2849-
update thereafter to form a sliding window of length @n@.
2852+
updating thereafter to form a sliding window of length @n@.
28502853
It follows the behavior of the slidingWindow function in
28512854
<https://hackage.haskell.org/package/conduit-combinators-1.0.4/docs/Data-Conduit-Combinators.html#v:slidingWindow conduit-combinators>.
28522855
@@ -2880,6 +2883,33 @@ slidingWindow n = setup (max 1 n :: Int) mempty
28802883
Right (x,rest) -> setup (m-1) (sequ Seq.|> x) rest
28812884
{-# INLINABLE slidingWindow #-}
28822885

2886+
{-| 'slidingWindowSum' accumulates the first @n@ elements of a stream
2887+
with elements in some 'Semigroup',
2888+
updating thereafter to form a sliding window of length @n@.
2889+
-}
2890+
slidingWindowSum :: (Monad m, Semigroup a)
2891+
=> Int
2892+
-> Stream (Of a) m b
2893+
-> Stream (Of a) m b
2894+
slidingWindowSum n = setup (max 1 n) AQ.empty
2895+
where
2896+
window !qu str = do
2897+
case AQ.measure qu of
2898+
Just s -> yield s
2899+
Nothing -> pure ()
2900+
e <- lift (next str)
2901+
case e of
2902+
Left r -> return r
2903+
Right (a,rest) ->
2904+
window (AQ.drop1 $ qu `AQ.snoc` a) rest
2905+
setup 0 !qu str = window qu str
2906+
setup m !qu str = do
2907+
e <- lift $ next str
2908+
case e of
2909+
Left r -> window qu (return r)
2910+
Right (x,rest) -> setup (m-1) (qu `AQ.snoc` x) rest
2911+
{-# INLINABLE slidingWindowSum #-}
2912+
28832913
-- | 'slidingWindowMin' finds the minimum in every sliding window of @n@
28842914
-- elements of a stream. If within a window there are multiple elements that are
28852915
-- the least, it prefers the first occurrence (if you prefer to have the last

streaming.cabal

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ library
204204
, Streaming.Prelude
205205
, Streaming.Internal
206206
, Data.Functor.Of
207+
other-modules:
208+
Data.AnnotatedQueue
207209
build-depends:
208210
base >=4.8 && <5
209211
, mtl >=2.1 && <2.3

0 commit comments

Comments
 (0)