Skip to content

Commit 38b1b81

Browse files
committed
Reimplement <*>
Use `coerce` for the `Functor` instance of `Elem` Using `fmap = coerce` for `Elem` speeds up `<*>` by somewhere around 20%. Benchmark results: OLD: benchmarking <*>/ix1000/500000 time 11.47 ms (11.37 ms .. 11.59 ms) 0.999 R² (0.998 R² .. 1.000 R²) mean 11.61 ms (11.52 ms .. 11.73 ms) std dev 279.9 μs (209.5 μs .. 385.6 μs) benchmarking <*>/nf100/2500/rep time 8.530 ms (8.499 ms .. 8.568 ms) 1.000 R² (1.000 R² .. 1.000 R²) mean 8.511 ms (8.498 ms .. 8.528 ms) std dev 40.40 μs (28.55 μs .. 63.84 μs) benchmarking <*>/nf100/2500/ff time 27.13 ms (26.16 ms .. 28.70 ms) 0.994 R² (0.988 R² .. 1.000 R²) mean 26.49 ms (26.29 ms .. 27.43 ms) std dev 697.1 μs (153.0 μs .. 1.443 ms) benchmarking <*>/nf500/500/rep time 8.421 ms (8.331 ms .. 8.491 ms) 0.991 R² (0.967 R² .. 1.000 R²) mean 8.518 ms (8.417 ms .. 9.003 ms) std dev 529.9 μs (40.37 μs .. 1.176 ms) variance introduced by outliers: 32% (moderately inflated) benchmarking <*>/nf500/500/ff time 33.71 ms (33.58 ms .. 33.86 ms) 1.000 R² (1.000 R² .. 1.000 R²) mean 33.69 ms (33.62 ms .. 33.76 ms) std dev 150.0 μs (119.0 μs .. 191.0 μs) benchmarking <*>/nf2500/100/rep time 8.390 ms (8.259 ms .. 8.456 ms) 0.997 R² (0.992 R² .. 1.000 R²) mean 8.544 ms (8.441 ms .. 8.798 ms) std dev 402.6 μs (21.25 μs .. 714.9 μs) variance introduced by outliers: 23% (moderately inflated) benchmarking <*>/nf2500/100/ff time 53.69 ms (53.33 ms .. 54.08 ms) 1.000 R² (1.000 R² .. 1.000 R²) mean 53.59 ms (53.38 ms .. 53.75 ms) std dev 341.2 μs (231.7 μs .. 473.9 μs) NEW benchmarking <*>/ix1000/500000 time 2.688 μs (2.607 μs .. 2.798 μs) 0.994 R² (0.988 R² .. 1.000 R²) mean 2.632 μs (2.607 μs .. 2.715 μs) std dev 129.9 ns (65.93 ns .. 242.8 ns) variance introduced by outliers: 64% (severely inflated) benchmarking <*>/nf100/2500/rep time 8.371 ms (8.064 ms .. 8.535 ms) 0.983 R² (0.947 R² .. 1.000 R²) mean 8.822 ms (8.590 ms .. 9.463 ms) std dev 991.2 μs (381.3 μs .. 1.809 ms) variance introduced by outliers: 61% (severely inflated) benchmarking <*>/nf100/2500/ff time 22.84 ms (22.74 ms .. 22.94 ms) 1.000 R² (1.000 R² .. 1.000 R²) mean 22.78 ms (22.71 ms .. 22.86 ms) std dev 183.3 μs (116.3 μs .. 291.3 μs) benchmarking <*>/nf500/500/rep time 8.320 ms (8.102 ms .. 8.514 ms) 0.995 R² (0.990 R² .. 0.999 R²) mean 8.902 ms (8.675 ms .. 9.407 ms) std dev 952.4 μs (435.5 μs .. 1.672 ms) variance introduced by outliers: 58% (severely inflated) benchmarking <*>/nf500/500/ff time 24.50 ms (24.41 ms .. 24.58 ms) 1.000 R² (1.000 R² .. 1.000 R²) mean 24.44 ms (24.41 ms .. 24.48 ms) std dev 75.08 μs (50.16 μs .. 111.3 μs) benchmarking <*>/nf2500/100/rep time 8.419 ms (8.366 ms .. 8.458 ms) 1.000 R² (1.000 R² .. 1.000 R²) mean 8.571 ms (8.525 ms .. 8.670 ms) std dev 179.5 μs (112.0 μs .. 278.1 μs) benchmarking <*>/nf2500/100/ff time 24.14 ms (24.07 ms .. 24.26 ms) 1.000 R² (1.000 R² .. 1.000 R²) mean 24.11 ms (24.07 ms .. 24.17 ms) std dev 103.8 μs (68.34 μs .. 142.0 μs)
1 parent ae97ceb commit 38b1b81

File tree

1 file changed

+258
-3
lines changed

1 file changed

+258
-3
lines changed

Data/Sequence.hs

Lines changed: 258 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ import Data.Functor.Identity (Identity(..))
194194

195195
infixr 5 `consTree`
196196
infixl 5 `snocTree`
197+
infixr 5 `appendTree0`
197198

198199
infixr 5 ><
199200
infixr 5 <|, :<
@@ -258,10 +259,255 @@ instance Monad Seq where
258259

259260
instance Applicative Seq where
260261
pure = singleton
261-
fs <*> xs = foldl' add empty fs
262+
263+
Seq Empty <*> xs = xs `seq` empty
264+
fs <*> Seq Empty = fs `seq` empty
265+
fs <*> Seq (Single (Elem x)) = fmap ($ x) fs
266+
fs <*> xs
267+
| length fs < 4 = foldl' add empty fs
262268
where add ys f = ys >< fmap f xs
269+
fs <*> xs | length xs < 4 = apShort fs xs
270+
fs <*> xs = apty fs xs
271+
263272
xs *> ys = replicateSeq (length xs) ys
264273

274+
-- <*> when the length of the first argument is at least two and
275+
-- the length of the second is two or three.
276+
apShort :: Seq (a -> b) -> Seq a -> Seq b
277+
apShort (Seq fs) xs = Seq $ case toList xs of
278+
[a,b] -> ap2FT fs (a,b)
279+
[a,b,c] -> ap3FT fs (a,b,c)
280+
_ -> error "apShort: not 2-6"
281+
282+
ap2FT :: FingerTree (Elem (a->b)) -> (a,a) -> FingerTree (Elem b)
283+
ap2FT fs (x,y) = Deep (size fs * 2)
284+
(Two (Elem $ firstf x) (Elem $ firstf y))
285+
(mapMulFT 2 (\(Elem f) -> Node2 2 (Elem (f x)) (Elem (f y))) m)
286+
(Two (Elem $ lastf x) (Elem $ lastf y))
287+
where
288+
(Elem firstf, m, Elem lastf) = trimTree fs
289+
290+
ap3FT :: FingerTree (Elem (a->b)) -> (a,a,a) -> FingerTree (Elem b)
291+
ap3FT fs (x,y,z) = Deep (size fs * 3)
292+
(Three (Elem $ firstf x) (Elem $ firstf y) (Elem $ firstf z))
293+
(mapMulFT 3 (\(Elem f) -> Node3 3 (Elem (f x)) (Elem (f y)) (Elem (f z))) m)
294+
(Three (Elem $ lastf x) (Elem $ lastf y) (Elem $ lastf z))
295+
where
296+
(Elem firstf, m, Elem lastf) = trimTree fs
297+
298+
-- <*> when the length of each argument is at least four.
299+
apty :: Seq (a -> b) -> Seq a -> Seq b
300+
apty (Seq fs) (Seq xs@Deep{}) = Seq $
301+
runApState (fmap firstf) (fmap lastf) fmap fs' (ApState xs' xs' xs')
302+
where
303+
(Elem firstf, fs', Elem lastf) = trimTree fs
304+
xs' = rigidify xs
305+
apty _ _ = error "apty: expects a Deep constructor"
306+
307+
data ApState a = ApState (FingerTree a) (FingerTree a) (FingerTree a)
308+
309+
-- | 'runApState' uses three copies of the @xs@ tree to produce the @fs<*>xs@
310+
-- tree. It pulls left digits off the left tree, right digits off the right tree,
311+
-- and squashes down the other four digits. Once it gets to the bottom, it turns
312+
-- the middle tree into a 2-3 tree, applies 'mapMulFT' to produce the main body,
313+
-- and glues all the pieces together.
314+
runApState
315+
:: Sized c =>
316+
(c -> d)
317+
-> (c -> d)
318+
-> ((a -> b) -> c -> d)
319+
-> FingerTree (Elem (a -> b))
320+
-> ApState c
321+
-> FingerTree d
322+
-- Not at the bottom yet
323+
runApState firstf
324+
lastf
325+
map23
326+
fs
327+
(ApState
328+
(Deep sl
329+
prl
330+
(Deep sml prml mml sfml)
331+
sfl)
332+
(Deep sm
333+
prm
334+
(Deep _smm prmm mmm sfmm)
335+
sfm)
336+
(Deep sr
337+
prr
338+
(Deep smr prmr mmr sfmr)
339+
sfr))
340+
= Deep (sl + sr + sm * size fs)
341+
(fmap firstf prl)
342+
(runApState (fmap firstf)
343+
(fmap lastf)
344+
(\f -> fmap (map23 f))
345+
fs
346+
nextState)
347+
(fmap lastf sfr)
348+
where nextState =
349+
ApState
350+
(Deep (sml + size sfl) prml mml (squashR sfml sfl))
351+
(Deep sm (squashL prm prmm) mmm (squashR sfmm sfm))
352+
(Deep (smr + size prr) (squashL prr prmr) mmr sfmr)
353+
354+
-- At the bottom
355+
runApState firstf
356+
lastf
357+
map23
358+
fs
359+
(ApState
360+
(Deep sl prl ml sfl)
361+
(Deep sm prm mm sfm)
362+
(Deep sr prr mr sfr))
363+
= Deep (sl + sr + sm * size fs)
364+
(fmap firstf prl)
365+
((fmap (fmap firstf) ml `snocTree` fmap firstf (digitToNode sfl))
366+
`appendTree0` middle `appendTree0`
367+
(fmap lastf (digitToNode prr) `consTree` fmap (fmap lastf) mr))
368+
(fmap lastf sfr)
369+
where middle = case trimTree $ mapMulFT sm (\(Elem f) -> fmap (fmap (map23 f)) converted) fs of
370+
(firstMapped, restMapped, lastMapped) ->
371+
Deep (size firstMapped + size restMapped + size lastMapped)
372+
(nodeToDigit firstMapped) restMapped (nodeToDigit lastMapped)
373+
converted = case mm of
374+
Empty -> Node2 sm lconv rconv
375+
Single q -> Node3 sm lconv q rconv
376+
Deep{} -> error "runApState: a tree is shallower than the middle tree"
377+
lconv = digitToNode prm
378+
rconv = digitToNode sfm
379+
380+
runApState _ _ _ _ _ = error "runApState: ApState must hold Deep finger trees of the same depth"
381+
382+
{-# SPECIALIZE
383+
runApState
384+
:: (Node c -> d)
385+
-> (Node c -> d)
386+
-> ((a -> b) -> Node c -> d)
387+
-> FingerTree (Elem (a -> b))
388+
-> ApState (Node c)
389+
-> FingerTree d
390+
#-}
391+
{-# SPECIALIZE
392+
runApState
393+
:: (Elem c -> d)
394+
-> (Elem c -> d)
395+
-> ((a -> b) -> Elem c -> d)
396+
-> FingerTree (Elem (a -> b))
397+
-> ApState (Elem c)
398+
-> FingerTree d
399+
#-}
400+
401+
digitToNode :: Sized a => Digit a -> Node a
402+
digitToNode (Two a b) = node2 a b
403+
digitToNode (Three a b c) = node3 a b c
404+
digitToNode _ = error "digitToNode: not representable as a node"
405+
406+
type Digit23 = Digit
407+
type Digit12 = Digit
408+
409+
-- Squash the first argument down onto the left side of the second.
410+
squashL :: Sized a => Digit23 a -> Digit12 (Node a) -> Digit23 (Node a)
411+
squashL (Two a b) (One n) = Two (node2 a b) n
412+
squashL (Two a b) (Two n1 n2) = Three (node2 a b) n1 n2
413+
squashL (Three a b c) (One n) = Two (node3 a b c) n
414+
squashL (Three a b c) (Two n1 n2) = Three (node3 a b c) n1 n2
415+
squashL _ _ = error "squashL: wrong digit types"
416+
417+
-- Squash the second argument down onto the right side of the first
418+
squashR :: Sized a => Digit12 (Node a) -> Digit23 a -> Digit23 (Node a)
419+
squashR (One n) (Two a b) = Two n (node2 a b)
420+
squashR (Two n1 n2) (Two a b) = Three n1 n2 (node2 a b)
421+
squashR (One n) (Three a b c) = Two n (node3 a b c)
422+
squashR (Two n1 n2) (Three a b c) = Three n1 n2 (node3 a b c)
423+
squashR _ _ = error "squashR: wrong digit types"
424+
425+
-- | /O(m*n)/ (incremental) Takes an /O(m)/ function and a finger tree of size
426+
-- /n/ and maps the function over the tree leaves. Unlike the usual 'fmap', the
427+
-- function is applied to the "leaves" of the 'FingerTree' (i.e., given a
428+
-- @FingerTree (Elem a)@, it applies the function to elements of type @Elem
429+
-- a@), replacing the leaves with subtrees of at least the same height, e.g.,
430+
-- @Node(Node(Elem y))@. The multiplier argument serves to make the annotations
431+
-- match up properly.
432+
mapMulFT :: Int -> (a -> b) -> FingerTree a -> FingerTree b
433+
mapMulFT _ _ Empty = Empty
434+
mapMulFT _mul f (Single a) = Single (f a)
435+
mapMulFT mul f (Deep s pr m sf) = Deep (mul * s) (fmap f pr) (mapMulFT mul (mapMulNode mul f) m) (fmap f sf)
436+
437+
mapMulNode :: Int -> (a -> b) -> Node a -> Node b
438+
mapMulNode mul f (Node2 s a b) = Node2 (mul * s) (f a) (f b)
439+
mapMulNode mul f (Node3 s a b c) = Node3 (mul * s) (f a) (f b) (f c)
440+
441+
442+
trimTree :: Sized a => FingerTree a -> (a, FingerTree a, a)
443+
trimTree Empty = error "trim: empty tree"
444+
trimTree Single{} = error "trim: singleton"
445+
trimTree t = case splitTree 0 t of
446+
Split _ hd r ->
447+
case splitTree (size r - 1) r of
448+
Split m tl _ -> (hd, m, tl)
449+
450+
-- | /O(log n)/ (incremental) Takes the extra flexibility out of a 'FingerTree'
451+
-- to make it a genuine 2-3 finger tree. The result of 'rigidify' will have
452+
-- only 'Two' and 'Three' digits at the top level and only 'One' and 'Two'
453+
-- digits elsewhere. It gives an error if the tree has fewer than four
454+
-- elements.
455+
rigidify :: Sized a => FingerTree a -> FingerTree a
456+
-- Note that 'rigidify' may call itself, but it will do so at most
457+
-- once: each call to 'rigidify' will either fix the whole tree or fix one digit
458+
-- and leave the other alone. The patterns below just fix up the top level of
459+
-- the tree; 'rigidify' delegates the hard work to 'thin'.
460+
461+
-- The top of the tree is fine.
462+
rigidify (Deep s pr@Two{} m sf@Three{}) = Deep s pr (thin m) sf
463+
rigidify (Deep s pr@Three{} m sf@Three{}) = Deep s pr (thin m) sf
464+
rigidify (Deep s pr@Two{} m sf@Two{}) = Deep s pr (thin m) sf
465+
rigidify (Deep s pr@Three{} m sf@Two{}) = Deep s pr (thin m) sf
466+
467+
-- One of the Digits is a Four.
468+
rigidify (Deep s (Four a b c d) m sf) =
469+
rigidify $ Deep s (Two a b) (node2 c d `consTree` m) sf
470+
rigidify (Deep s pr m (Four a b c d)) =
471+
rigidify $ Deep s pr (m `snocTree` node2 a b) (Two c d)
472+
473+
-- One of the Digits is a One. If the middle is empty, we can only rigidify the
474+
-- tree if the other Digit is a Three.
475+
rigidify (Deep s (One a) Empty (Three b c d)) = Deep s (Two a b) Empty (Two c d)
476+
rigidify (Deep s (One a) m sf) = rigidify $ case viewLTree m of
477+
Just2 (Node2 _ b c) m' -> Deep s (Three a b c) m' sf
478+
Just2 (Node3 _ b c d) m' -> Deep s (Two a b) (node2 c d `consTree` m') sf
479+
Nothing2 -> error "rigidify: small tree"
480+
rigidify (Deep s (Three a b c) Empty (One d)) = Deep s (Two a b) Empty (Two c d)
481+
rigidify (Deep s pr m (One e)) = rigidify $ case viewRTree m of
482+
Just2 m' (Node2 _ a b) -> Deep s pr m' (Three a b e)
483+
Just2 m' (Node3 _ a b c) -> Deep s pr (m' `snocTree` node2 a b) (Two c e)
484+
Nothing2 -> error "rigidify: small tree"
485+
rigidify Empty = error "rigidify: empty tree"
486+
rigidify Single{} = error "rigidify: singleton"
487+
488+
-- | /O(log n)/ (incremental) Rejigger a finger tree so the digits are all ones
489+
-- and twos.
490+
thin :: Sized a => FingerTree a -> FingerTree a
491+
-- Note that 'thin' may call itself at most once before passing the job on to
492+
-- 'thin12'. 'thin12' will produce a 'Deep' constructor immediately before
493+
-- calling 'thin'.
494+
thin Empty = Empty
495+
thin (Single a) = Single a
496+
thin t@(Deep s pr m sf) =
497+
case pr of
498+
One{} -> thin12 t
499+
Two{} -> thin12 t
500+
Three a b c -> thin $ Deep s (One a) (node2 b c `consTree` m) sf
501+
Four a b c d -> thin $ Deep s (Two a b) (node2 c d `consTree` m) sf
502+
503+
thin12 :: Sized a => FingerTree a -> FingerTree a
504+
thin12 (Deep s pr m sf@One{}) = Deep s pr (thin m) sf
505+
thin12 (Deep s pr m sf@Two{}) = Deep s pr (thin m) sf
506+
thin12 (Deep s pr m (Three a b c)) = Deep s pr (thin $ m `snocTree` node2 a b) (One c)
507+
thin12 (Deep s pr m (Four a b c d)) = Deep s pr (thin $ m `snocTree` node2 a b) (Two c d)
508+
thin12 _ = error "thin12 expects a Deep FingerTree."
509+
510+
265511
instance MonadPlus Seq where
266512
mzero = empty
267513
mplus = (><)
@@ -559,7 +805,12 @@ instance Sized (Elem a) where
559805
size _ = 1
560806

561807
instance Functor Elem where
808+
#if __GLASGOW_HASKELL__ >= 708
809+
-- This cuts the time for <*> by around a fifth.
810+
fmap = coerce
811+
#else
562812
fmap f (Elem x) = Elem (f x)
813+
#endif
563814

564815
instance Foldable Elem where
565816
foldMap f (Elem x) = f x
@@ -732,7 +983,9 @@ Seq xs >< Seq ys = Seq (appendTree0 xs ys)
732983

733984
-- The appendTree/addDigits gunk below is machine generated
734985

735-
appendTree0 :: FingerTree (Elem a) -> FingerTree (Elem a) -> FingerTree (Elem a)
986+
{-# SPECIALIZE appendTree0 :: FingerTree (Elem a) -> FingerTree (Elem a) -> FingerTree (Elem a) #-}
987+
{-# SPECIALIZE appendTree0 :: FingerTree (Node a) -> FingerTree (Node a) -> FingerTree (Node a) #-}
988+
appendTree0 :: Sized a => FingerTree a -> FingerTree a -> FingerTree a
736989
appendTree0 Empty xs =
737990
xs
738991
appendTree0 xs Empty =
@@ -744,7 +997,9 @@ appendTree0 xs (Single x) =
744997
appendTree0 (Deep s1 pr1 m1 sf1) (Deep s2 pr2 m2 sf2) =
745998
Deep (s1 + s2) pr1 (addDigits0 m1 sf1 pr2 m2) sf2
746999

747-
addDigits0 :: FingerTree (Node (Elem a)) -> Digit (Elem a) -> Digit (Elem a) -> FingerTree (Node (Elem a)) -> FingerTree (Node (Elem a))
1000+
{-# SPECIALIZE addDigits0 :: FingerTree (Node (Elem a)) -> Digit (Elem a) -> Digit (Elem a) -> FingerTree (Node (Elem a)) -> FingerTree (Node (Elem a)) #-}
1001+
{-# SPECIALIZE addDigits0 :: FingerTree (Node (Node a)) -> Digit (Node a) -> Digit (Node a) -> FingerTree (Node (Node a)) -> FingerTree (Node (Node a)) #-}
1002+
addDigits0 :: Sized a => FingerTree (Node a) -> Digit a -> Digit a -> FingerTree (Node a) -> FingerTree (Node a)
7481003
addDigits0 m1 (One a) (One b) m2 =
7491004
appendTree1 m1 (node2 a b) m2
7501005
addDigits0 m1 (One a) (Two b c) m2 =

0 commit comments

Comments
 (0)