Skip to content

Commit 2546efe

Browse files
committed
Merge pull request #104 from treeowl/ap
Make <*> fast
2 parents ae97ceb + 41b7cb4 commit 2546efe

File tree

2 files changed

+260
-14
lines changed

2 files changed

+260
-14
lines changed

Data/Sequence.hs

Lines changed: 243 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ import Data.Functor.Identity (Identity(..))
194194

195195
infixr 5 `consTree`
196196
infixl 5 `snocTree`
197+
infixr 5 `appendTree0`
197198

198199
infixr 5 ><
199200
infixr 5 <|, :<
@@ -258,10 +259,236 @@ instance Monad Seq where
258259

259260
-- Applicative instance for 'Seq'. '<*>' dispatches on the sizes of its
-- operands and hands the large cases to 'apShort' and 'apty'.
instance Applicative Seq where
    pure = singleton

    -- An empty operand yields an empty result; the other operand is still
    -- forced with 'seq' before 'empty' is returned.
    Seq Empty <*> xs = xs `seq` empty
    fs <*> Seq Empty = fs `seq` empty
    -- Exactly one argument: map "apply to x" over the functions.
    fs <*> Seq (Single (Elem x)) = fmap ($ x) fs
    fs <*> xs
        -- Fewer than four functions: append one mapped copy of xs per
        -- function, left to right.
        | length fs < 4 = foldl' add empty fs
        -- At least four functions, but only two or three arguments.
        | length xs < 4 = apShort fs xs
        -- Both operands have at least four elements.
        | otherwise = apty fs xs
      where
        add ys f = ys >< fmap f xs

    xs *> ys = replicateSeq (length xs) ys
264273

274+
-- '<*>' for the case where the first argument holds at least two functions
-- and the second argument holds exactly two or three elements. Any other
-- length of the second argument is reported with 'error'.
apShort :: Seq (a -> b) -> Seq a -> Seq b
apShort (Seq funs) args = Seq applied
  where
    -- Pick the specialised worker according to how many arguments there are.
    applied = case toList args of
        [x, y]    -> ap2FT funs (x, y)
        [x, y, z] -> ap3FT funs (x, y, z)
        _         -> error "apShort: not 2-3"
281+
282+
-- Apply every function in the tree to both components of the pair.
-- 'trimTree' peels off the first and last functions, which supply the
-- two-element prefix and suffix of the result; each function left in between
-- becomes a 'Node2' of size 2, with 'mapMulFT' doubling the cached sizes.
-- The tree must hold at least two functions, since 'trimTree' rejects
-- 'Empty' and 'Single'.
ap2FT :: FingerTree (Elem (a->b)) -> (a,a) -> FingerTree (Elem b)
ap2FT fs (x, y) = Deep (size fs * 2) (both firstf) middle (both lastf)
  where
    (Elem firstf, inner, Elem lastf) = trimTree fs
    -- The digit produced by a single boundary function.
    both g = Two (Elem (g x)) (Elem (g y))
    middle = mapMulFT 2 (\(Elem g) -> Node2 2 (Elem (g x)) (Elem (g y))) inner
289+
290+
-- Apply every function in the tree to all three components of the triple.
-- 'trimTree' peels off the first and last functions, which supply the
-- three-element prefix and suffix of the result; each function left in
-- between becomes a 'Node3' of size 3, with 'mapMulFT' tripling the cached
-- sizes. The tree must hold at least two functions, since 'trimTree' rejects
-- 'Empty' and 'Single'.
ap3FT :: FingerTree (Elem (a->b)) -> (a,a,a) -> FingerTree (Elem b)
ap3FT fs (x, y, z) = Deep (size fs * 3) (triple firstf) middle (triple lastf)
  where
    (Elem firstf, inner, Elem lastf) = trimTree fs
    -- The digit produced by a single boundary function.
    triple g = Three (Elem (g x)) (Elem (g y)) (Elem (g z))
    middle =
        mapMulFT 3 (\(Elem g) -> Node3 3 (Elem (g x)) (Elem (g y)) (Elem (g z))) inner
297+
298+
-- '<*>' for the case where each argument has at least four elements.
-- The argument tree is first made rigid with 'rigidify'. The first and last
-- functions are mapped over its prefix and suffix to form the outer digits
-- of the result, and 'aptyMiddle' builds everything in between.
-- A second argument that is not a 'Deep' tree is reported with 'error'.
apty :: Seq (a -> b) -> Seq a -> Seq b
apty (Seq funs) (Seq args@Deep{}) =
    Seq (Deep (argSize * size funs) resultPr resultMid resultSf)
  where
    (Elem firstf, innerFuns, Elem lastf) = trimTree funs
    rigid@(Deep argSize argPr _argMid argSf) = rigidify args
    resultPr  = fmap (fmap firstf) argPr
    resultMid = aptyMiddle (fmap firstf) (fmap lastf) fmap innerFuns rigid
    resultSf  = fmap (fmap lastf) argSf
apty _ _ = error "apty: expects a Deep constructor"
309+
310+
-- | 'aptyMiddle' carries most of the load when computing @fs\<*\>xs@.
-- It builds the centre of the result finger tree. The part corresponding to
-- the prefix of @xs@ and the part corresponding to the suffix of @xs@ are
-- left out, and the caller is responsible for adding them.
--
-- While the middle of the argument tree is itself 'Deep', the function
-- recurses one level down after squashing the current prefix and suffix into
-- the digits of that middle tree. Once the middle is 'Empty' or 'Single',
-- the argument tree is converted into a single 2-3 node, 'mapMulFT' is used
-- to produce one copy of it per remaining function, and the pieces are
-- appended together.
aptyMiddle
  :: Sized c
  => (c -> d)
  -> (c -> d)
  -> ((a -> b) -> c -> d)
  -> FingerTree (Elem (a -> b))
  -> FingerTree c
  -> FingerTree (Node d)
-- Not at the bottom yet: the middle tree is still Deep.
aptyMiddle firstf lastf map23 fs (Deep s pr (Deep sm prm mm sfm) sf) =
    Deep (sm + s * (size fs + 1)) -- note: sm = s - size pr - size sf
         (fmap (fmap firstf) prm)
         (aptyMiddle (fmap firstf)
                     (fmap lastf)
                     (fmap . map23)
                     fs
                     (Deep s (squashL pr prm) mm (squashR sfm sf)))
         (fmap (fmap lastf) sfm)

-- At the bottom: the middle tree is Empty or Single.
aptyMiddle firstf lastf map23 fs (Deep s pr m sf) =
    leftPart `appendTree0` (middle `appendTree0` rightPart)
  where
    -- What the first function still owes: the middle and suffix of xs.
    leftPart = fmap (fmap firstf) m `snocTree` fmap firstf (digitToNode sf)
    -- What the last function still owes: the prefix and middle of xs.
    rightPart = fmap lastf (digitToNode pr) `consTree` fmap (fmap lastf) m
    -- One mapped copy of the whole argument tree per remaining function.
    blocks = mapMulFT s (\(Elem f) -> fmap (fmap (map23 f)) converted) fs
    middle = case trimTree blocks of
        (firstBlock, innerBlocks, lastBlock) ->
            Deep (size firstBlock + size innerBlocks + size lastBlock)
                 (nodeToDigit firstBlock) innerBlocks (nodeToDigit lastBlock)
    -- The entire argument tree packed into a single node.
    converted = case m of
        Empty    -> Node2 s prNode sfNode
        Single q -> Node3 s prNode q sfNode
        Deep{}   -> error "aptyMiddle: impossible"
    prNode = digitToNode pr
    sfNode = digitToNode sf

aptyMiddle _ _ _ _ _ = error "aptyMiddle: expected Deep finger tree"
362+
363+
{-# SPECIALIZE
364+
aptyMiddle
365+
:: (Node c -> d)
366+
-> (Node c -> d)
367+
-> ((a -> b) -> Node c -> d)
368+
-> FingerTree (Elem (a -> b))
369+
-> FingerTree (Node c)
370+
-> FingerTree (Node d)
371+
#-}
372+
{-# SPECIALIZE
373+
aptyMiddle
374+
:: (Elem c -> d)
375+
-> (Elem c -> d)
376+
-> ((a -> b) -> Elem c -> d)
377+
-> FingerTree (Elem (a -> b))
378+
-> FingerTree (Elem c)
379+
-> FingerTree (Node d)
380+
#-}
381+
382+
-- Turn a two- or three-element digit into the corresponding node, built with
-- the 'node2' / 'node3' helpers. A digit of any other shape cannot be
-- expressed as a node and is reported with 'error'.
digitToNode :: Sized a => Digit a -> Node a
digitToNode digit = case digit of
    Two a b     -> node2 a b
    Three a b c -> node3 a b c
    _           -> error "digitToNode: not representable as a node"
386+
387+
-- Plain aliases of 'Digit'. They only document which digit shapes a function
-- expects ('Two'/'Three' versus 'One'/'Two'); nothing is enforced by the
-- type checker.
type Digit23 = Digit
type Digit12 = Digit

-- Squash the first argument down onto the left side of the second: the
-- two or three elements of the first digit are packed into one node, which
-- is placed in front of the one or two nodes of the second digit. Any other
-- combination of digit shapes is reported with 'error'.
squashL :: Sized a => Digit23 a -> Digit12 (Node a) -> Digit23 (Node a)
squashL outer inner = case (outer, inner) of
    (Two a b,     One n)     -> Two (node2 a b) n
    (Two a b,     Two n1 n2) -> Three (node2 a b) n1 n2
    (Three a b c, One n)     -> Two (node3 a b c) n
    (Three a b c, Two n1 n2) -> Three (node3 a b c) n1 n2
    _                        -> error "squashL: wrong digit types"
397+
398+
-- Squash the second argument down onto the right side of the first: the
-- two or three elements of the second digit are packed into one node, which
-- is placed after the one or two nodes of the first digit. Any other
-- combination of digit shapes is reported with 'error'.
squashR :: Sized a => Digit12 (Node a) -> Digit23 a -> Digit23 (Node a)
squashR inner outer = case (inner, outer) of
    (One n,     Two a b)     -> Two n (node2 a b)
    (Two n1 n2, Two a b)     -> Three n1 n2 (node2 a b)
    (One n,     Three a b c) -> Two n (node3 a b c)
    (Two n1 n2, Three a b c) -> Three n1 n2 (node3 a b c)
    _                        -> error "squashR: wrong digit types"
405+
406+
-- | /O(m*n)/ (incremental) Maps an /O(m)/ function over the leaves of a
-- finger tree of size /n/. In contrast to the usual 'fmap', the function
-- sees the "leaves" of the 'FingerTree' themselves (for a
-- @FingerTree (Elem a)@ it receives values of type @Elem a@) and may replace
-- each leaf by a subtree of at least the same height, e.g.
-- @Node(Node(Elem y))@. Every cached size annotation encountered on the way
-- is multiplied by the first argument so that the annotations still match
-- the enlarged leaves.
mapMulFT :: Int -> (a -> b) -> FingerTree a -> FingerTree b
mapMulFT mul f tree = case tree of
    Empty    -> Empty
    Single a -> Single (f a)
    Deep s pr m sf ->
        Deep (mul * s)
             (fmap f pr)
             (mapMulFT mul (mapMulNode mul f) m)
             (fmap f sf)
417+
418+
-- Map a function over the children of a node while multiplying the node's
-- cached size by the given factor.
mapMulNode :: Int -> (a -> b) -> Node a -> Node b
mapMulNode mul f node = case node of
    Node2 s a b   -> Node2 (mul * s) (f a) (f b)
    Node3 s a b c -> Node3 (mul * s) (f a) (f b) (f c)
421+
422+
423+
-- Split a finger tree into a triple that its callers treat as
-- (first element, everything in between, last element). The tree is split
-- once at position 0, and the remainder is split again at its final
-- position (@size r - 1@).
-- 'Empty' and 'Single' trees are rejected with 'error'.
-- NOTE(review): the exact split semantics come from 'splitTree', which is
-- defined elsewhere in the module.
trimTree :: Sized a => FingerTree a -> (a, FingerTree a, a)
trimTree tree = case tree of
    Empty    -> error "trim: empty tree"
    Single{} -> error "trim: singleton"
    _        -> case splitTree 0 tree of
        Split _ front rest -> case splitTree (size rest - 1) rest of
            Split inner back _ -> (front, inner, back)
430+
431+
-- | /O(log n)/ (incremental) Removes the extra flexibility from a
-- 'FingerTree', turning it into a genuine 2-3 finger tree: the result has
-- only 'Two' and 'Three' digits at the top level and only 'One' and 'Two'
-- digits below. A tree with fewer than four elements cannot be put into
-- this shape and is reported with 'error'.
rigidify :: Sized a => FingerTree a -> FingerTree a
-- Each invocation either finishes the tree or repairs exactly one top-level
-- digit and recurses once. A repaired digit is always a 'Two' or a 'Three',
-- so the chain of recursive calls has length at most two (one per digit).
-- Only the top level is handled here; everything below is left to 'thin'.

-- Both top-level digits are already fine; only the middle needs thinning.
rigidify (Deep sz pre@Two{} mid suf@Three{}) = Deep sz pre (thin mid) suf
rigidify (Deep sz pre@Three{} mid suf@Three{}) = Deep sz pre (thin mid) suf
rigidify (Deep sz pre@Two{} mid suf@Two{}) = Deep sz pre (thin mid) suf
rigidify (Deep sz pre@Three{} mid suf@Two{}) = Deep sz pre (thin mid) suf

-- A 'Four' digit: keep two elements and push the other two into the middle
-- as a single node.
rigidify (Deep sz (Four a b c d) mid suf) =
    rigidify (Deep sz (Two a b) (node2 c d `consTree` mid) suf)
rigidify (Deep sz pre mid (Four a b c d)) =
    rigidify (Deep sz pre (mid `snocTree` node2 a b) (Two c d))

-- A 'One' prefix. With an empty middle this only works when the suffix is a
-- 'Three'; otherwise a node is borrowed from the front of the middle.
rigidify (Deep sz (One a) Empty (Three b c d)) = Deep sz (Two a b) Empty (Two c d)
rigidify (Deep sz (One a) mid suf) = case viewLTree mid of
    Just2 (Node2 _ b c) mid' ->
        rigidify (Deep sz (Three a b c) mid' suf)
    Just2 (Node3 _ b c d) mid' ->
        rigidify (Deep sz (Two a b) (node2 c d `consTree` mid') suf)
    Nothing2 -> error "rigidify: small tree"

-- A 'One' suffix: the mirror image of the prefix case.
rigidify (Deep sz (Three a b c) Empty (One d)) = Deep sz (Two a b) Empty (Two c d)
rigidify (Deep sz pre mid (One e)) = case viewRTree mid of
    Just2 mid' (Node2 _ a b) ->
        rigidify (Deep sz pre mid' (Three a b e))
    Just2 mid' (Node3 _ a b c) ->
        rigidify (Deep sz pre (mid' `snocTree` node2 a b) (Two c e))
    Nothing2 -> error "rigidify: small tree"

rigidify Empty = error "rigidify: empty tree"
rigidify Single{} = error "rigidify: singleton"
468+
469+
-- | /O(log n)/ (incremental) Rearranges a finger tree so that every digit is
-- a 'One' or a 'Two'.
thin :: Sized a => FingerTree a -> FingerTree a
-- 'thin' only repairs the prefix: a 'Three' or 'Four' prefix is shrunk by
-- pushing two of its elements into the middle as one node, after which a
-- single further call reaches 'thin12', which deals with the suffix and the
-- middle tree.
thin Empty = Empty
thin tree@Single{} = tree
thin (Deep sz (Three a b c) mid suf) =
    thin (Deep sz (One a) (node2 b c `consTree` mid) suf)
thin (Deep sz (Four a b c d) mid suf) =
    thin (Deep sz (Two a b) (node2 c d `consTree` mid) suf)
-- What remains is a Deep tree whose prefix is already a One or a Two.
thin tree = thin12 tree
483+
484+
-- Second half of 'thin', for a 'Deep' tree whose prefix needs no further
-- work. It shrinks a 'Three' or 'Four' suffix by pushing its first two
-- elements into the middle as one node, and thins the middle tree.
-- The 'Deep' constructor of the result is produced before the recursive
-- 'thin' on the middle is demanded. Anything other than 'Deep' is reported
-- with 'error'.
thin12 :: Sized a => FingerTree a -> FingerTree a
thin12 (Deep sz pre mid suf) = case suf of
    One{}        -> Deep sz pre (thin mid) suf
    Two{}        -> Deep sz pre (thin mid) suf
    Three a b c  -> Deep sz pre (thin (mid `snocTree` node2 a b)) (One c)
    Four a b c d -> Deep sz pre (thin (mid `snocTree` node2 a b)) (Two c d)
thin12 _ = error "thin12 expects a Deep FingerTree."
490+
491+
265492
instance MonadPlus Seq where
266493
mzero = empty
267494
mplus = (><)
@@ -559,7 +786,12 @@ instance Sized (Elem a) where
559786
size _ = 1
560787

561788
-- Functor instance for the 'Elem' wrapper.
instance Functor Elem where
#if __GLASGOW_HASKELL__ < 708
    -- Older compilers: unwrap, apply, rewrap.
    fmap f (Elem x) = Elem (f x)
#else
    -- Use 'coerce' instead of unwrapping and rewrapping by hand.
    -- This cuts the time for <*> by around a fifth.
    fmap = coerce
#endif
563795

564796
instance Foldable Elem where
565797
foldMap f (Elem x) = f x
@@ -732,7 +964,9 @@ Seq xs >< Seq ys = Seq (appendTree0 xs ys)
732964

733965
-- The appendTree/addDigits gunk below is machine generated
734966

735-
appendTree0 :: FingerTree (Elem a) -> FingerTree (Elem a) -> FingerTree (Elem a)
967+
{-# SPECIALIZE appendTree0 :: FingerTree (Elem a) -> FingerTree (Elem a) -> FingerTree (Elem a) #-}
968+
{-# SPECIALIZE appendTree0 :: FingerTree (Node a) -> FingerTree (Node a) -> FingerTree (Node a) #-}
969+
appendTree0 :: Sized a => FingerTree a -> FingerTree a -> FingerTree a
736970
appendTree0 Empty xs =
737971
xs
738972
appendTree0 xs Empty =
@@ -744,7 +978,9 @@ appendTree0 xs (Single x) =
744978
appendTree0 (Deep s1 pr1 m1 sf1) (Deep s2 pr2 m2 sf2) =
745979
Deep (s1 + s2) pr1 (addDigits0 m1 sf1 pr2 m2) sf2
746980

747-
addDigits0 :: FingerTree (Node (Elem a)) -> Digit (Elem a) -> Digit (Elem a) -> FingerTree (Node (Elem a)) -> FingerTree (Node (Elem a))
981+
{-# SPECIALIZE addDigits0 :: FingerTree (Node (Elem a)) -> Digit (Elem a) -> Digit (Elem a) -> FingerTree (Node (Elem a)) -> FingerTree (Node (Elem a)) #-}
982+
{-# SPECIALIZE addDigits0 :: FingerTree (Node (Node a)) -> Digit (Node a) -> Digit (Node a) -> FingerTree (Node (Node a)) -> FingerTree (Node (Node a)) #-}
983+
addDigits0 :: Sized a => FingerTree (Node a) -> Digit a -> Digit a -> FingerTree (Node a) -> FingerTree (Node a)
748984
addDigits0 m1 (One a) (One b) m2 =
749985
appendTree1 m1 (node2 a b) m2
750986
addDigits0 m1 (One a) (Two b c) m2 =
@@ -1841,16 +2077,9 @@ reverseNode f (Node3 s a b c) = Node3 s (f c) (f b) (f a)
18412077
-- Mapping with a splittable value
18422078
------------------------------------------------------------------------
18432079

1844-
-- For zipping, and probably also for (<*>), it is useful to build a result by
2080+
-- For zipping, it is useful to build a result by
18452081
-- traversing a sequence while splitting up something else. Specifically, we
1846-
-- traverse the first sequence while splitting up the second [and third [and
1847-
-- fourth]]. For fs <*> xs, we hope to traverse
1848-
--
1849-
-- > replicate (length fs * length xs) ()
1850-
--
1851-
-- while splitting something essentially equivalent to
1852-
--
1853-
-- > fmap (\f -> fmap f xs) fs
2082+
-- traverse the first sequence while splitting up the second.
18542083
--
18552084
-- What makes all this crazy code a good idea:
18562085
--
@@ -1874,8 +2103,8 @@ reverseNode f (Node3 s a b c) = Node3 s (f c) (f b) (f a)
18742103
-- they're actually needed. We do the same thing for Digits (splitting into
18752104
-- between one and four pieces) and Nodes (splitting into two or three). The
18762105
-- ultimate result is that we can index into, or split at, any location in zs
1877-
-- in O((log(min{i,n-i}))^2) time *immediately*, while still being able to
1878-
-- force all the thunks in O(n) time.
2106+
-- in polylogarithmic time *immediately*, while still being able to force all
2107+
-- the thunks in O(n) time.
18792108
--
18802109
-- Benchmark info, and alternatives:
18812110
--

benchmarks/Sequence.hs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
-- > ghc -DTESTING --make -O2 -fforce-recomp -i.. Sequence.hs
22
module Main where
33

4+
import Control.Applicative
45
import Control.DeepSeq
56
import Criterion.Main
67
import Data.List (foldl')
@@ -44,6 +45,22 @@ main = do
4445
, bench "nf1000" $ nf (\s -> S.fromFunction s (+1)) 1000
4546
, bench "nf10000" $ nf (\s -> S.fromFunction s (+1)) 10000
4647
]
48+
, bgroup "<*>"
    -- Indexing benchmarks: build the 1000 * 1000 element product lazily and
    -- force a single element. Each name is "ix<index>/<size of product>".
    -- The first one looks at index 500, close to the front of the result.
    [ bench "ix500/1000^2" $
        nf (\s -> ((+) <$> s <*> s) `S.index` (S.length s `div` 2)) (S.fromFunction 1000 (+1))
    -- The second one looks at the true midpoint of the product, which is
    -- the position farthest from both ends.
    , bench "ix500000/1000^2" $
        nf (\s -> ((+) <$> s <*> s) `S.index` (S.length s * S.length s `div` 2)) (S.fromFunction 1000 (+1))
    -- Full-evaluation benchmarks, named "nf<functions>/<arguments>/<source>".
    -- "rep" builds both operands with S.replicate so that the Seq instance
    -- of <*> is measured (the unqualified Prelude 'replicate' would build
    -- lists and time the list instance instead); "ff" uses S.fromFunction.
    , bench "nf100/2500/rep" $
        nf (\(s,t) -> (,) <$> S.replicate s () <*> S.replicate t ()) (100,2500)
    , bench "nf100/2500/ff" $
        nf (\(s,t) -> (,) <$> S.fromFunction s (+1) <*> S.fromFunction t (*2)) (100,2500)
    , bench "nf500/500/rep" $
        nf (\(s,t) -> (,) <$> S.replicate s () <*> S.replicate t ()) (500,500)
    , bench "nf500/500/ff" $
        nf (\(s,t) -> (,) <$> S.fromFunction s (+1) <*> S.fromFunction t (*2)) (500,500)
    , bench "nf2500/100/rep" $
        nf (\(s,t) -> (,) <$> S.replicate s () <*> S.replicate t ()) (2500,100)
    , bench "nf2500/100/ff" $
        nf (\(s,t) -> (,) <$> S.fromFunction s (+1) <*> S.fromFunction t (*2)) (2500,100)
    ]
4764
]
4865

4966
-- splitAt+append: repeatedly cut the sequence at a random point

0 commit comments

Comments
 (0)