Skip to content

Commit 31e1234

Browse files
committed
Make zipWith faster
Make `zipWith` build its result with the structure of its first argument, splitting up its second argument as it goes. This allows fast random access to the elements of the results immediately, without having to build large portions of the structure. It also seems to be slightly faster than the old implementation when the entire result is used, presumably by avoiding rebalancing costs. I believe most of this code will also help implement a fast `(<*>)`. Use the same approach to implement `zipWith3` and `zipWith4`. Clean up a couple warnings. Many thanks to Carter Schonwald for suggesting that I use the structure of the first sequence to structure the result, and for helping me come up with the splitTraverse approach. Benchmarks: Zipping two 100000 element lists and extracting the 50000th element takes about 11.4ms with the new implementation, as opposed to 88ms with the old. Zipping two 10000 element sequences and forcing the result to normal form takes 4.0ms now rather than 19.7ms. The indexing gains show up for even very short sequences, but the new implementation really starts to look good once the size gets to around 1000--presumably it handles cache effects better than the old one. Note that the naive approach of converting sequences to lists, zipping them, and then converting back, actually works very well for forcing short sequences to normal form, even better than the new implementation. But it starts to lose a lot of ground by the time the size gets to around 10000, and its performance on the indexing tests is bad.
1 parent f22d14b commit 31e1234

File tree

1 file changed

+92
-14
lines changed

1 file changed

+92
-14
lines changed

Data/Sequence.hs

Lines changed: 92 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -676,10 +676,10 @@ replicateM n x
676676

677677
-- | @'replicateSeq' n xs@ concatenates @n@ copies of @xs@.
678678
replicateSeq :: Int -> Seq a -> Seq a
679-
replicateSeq n xs
679+
replicateSeq n s
680680
| n < 0 = error "replicateSeq takes a nonnegative integer argument"
681681
| n == 0 = empty
682-
| otherwise = go n xs
682+
| otherwise = go n s
683683
where
684684
-- Invariant: k >= 1
685685
go 1 xs = xs
@@ -1702,6 +1702,75 @@ reverseNode :: (a -> a) -> Node a -> Node a
17021702
reverseNode f (Node2 s a b) = Node2 s (f b) (f a)
17031703
reverseNode f (Node3 s a b c) = Node3 s (f c) (f b) (f a)
17041704

1705+
------------------------------------------------------------------------
1706+
-- Traversing with splittable "state"
1707+
------------------------------------------------------------------------
1708+
1709+
-- For zipping, and probably also for (<*>), it is useful to build a result by
1710+
-- traversing a sequence while splitting up something else. For zipping, we
1711+
-- traverse the first sequence while splitting up the second [and third [and
1712+
-- fourth]]. For fs <*> xs, we expect soon to traverse
1713+
--
1714+
-- > replicate (length fs * length xs) ()
1715+
--
1716+
-- while splitting something essentially equivalent to
1717+
--
1718+
-- > fmap (\f -> fmap f xs) fs
1719+
--
1720+
-- David Feuer, with excellent guidance from Carter Schonwald, December 2014
1721+
1722+
class Splittable s where
1723+
splitState :: Int -> s -> (s,s)
1724+
1725+
instance Splittable (Seq a) where
1726+
splitState = splitAt
1727+
1728+
instance (Splittable a, Splittable b) => Splittable (a, b) where
1729+
splitState i (a, b) = ((al, bl), (ar, br))
1730+
where
1731+
(al, ar) = splitState i a
1732+
(bl, br) = splitState i b
1733+
1734+
splitTraverseSeq :: (Splittable s) => (s -> a -> b) -> s -> Seq a -> Seq b
1735+
splitTraverseSeq f s (Seq xs) = Seq $ splitTraverseTree (\s' (Elem a) -> Elem (f s' a)) s xs
1736+
1737+
splitTraverseTree :: (Sized a, Splittable s) => (s -> a -> b) -> s -> FingerTree a -> FingerTree b
1738+
splitTraverseTree _f _s Empty = Empty
1739+
splitTraverseTree f s (Single xs) = Single $ f s xs
1740+
splitTraverseTree f s (Deep n pr m sf) = Deep n (splitTraverseDigit f prs pr) (splitTraverseTree (splitTraverseNode f) ms m) (splitTraverseDigit f sfs sf)
1741+
where
1742+
(prs, r) = splitState (size pr) s
1743+
(ms, sfs) = splitState (n - size pr - size sf) r
1744+
1745+
splitTraverseDigit :: (Sized a, Splittable s) => (s -> a -> b) -> s -> Digit a -> Digit b
1746+
splitTraverseDigit f s (One a) = One (f s a)
1747+
splitTraverseDigit f s (Two a b) = Two (f first a) (f second b)
1748+
where
1749+
(first, second) = splitState (size a) s
1750+
splitTraverseDigit f s (Three a b c) = Three (f first a) (f second b) (f third c)
1751+
where
1752+
(first, r) = splitState (size a) s
1753+
(second, third) = splitState (size b) r
1754+
splitTraverseDigit f s (Four a b c d) = Four (f first a) (f second b) (f third c) (f fourth d)
1755+
where
1756+
(first, s') = splitState (size a) s
1757+
(middle, fourth) = splitState (size b + size c) s'
1758+
(second, third) = splitState (size b) middle
1759+
1760+
splitTraverseNode :: (Sized a, Splittable s) => (s -> a -> b) -> s -> Node a -> Node b
1761+
splitTraverseNode f s (Node2 ns a b) = Node2 ns (f first a) (f second b)
1762+
where
1763+
(first, second) = splitState (size a) s
1764+
splitTraverseNode f s (Node3 ns a b c) = Node3 ns (f first a) (f second b) (f third c)
1765+
where
1766+
(first, r) = splitState (size a) s
1767+
(second, third) = splitState (size b) r
1768+
1769+
getSingleton :: Seq a -> a
1770+
getSingleton (Seq (Single (Elem a))) = a
1771+
getSingleton (Seq Empty) = error "getSingleton: Empty"
1772+
getSingleton _ = error "getSingleton: Not a singleton."
1773+
17051774
------------------------------------------------------------------------
17061775
-- Zipping
17071776
------------------------------------------------------------------------
@@ -1717,17 +1786,11 @@ zip = zipWith (,)
17171786
-- For example, @zipWith (+)@ is applied to two sequences to take the
17181787
-- sequence of corresponding sums.
17191788
zipWith :: (a -> b -> c) -> Seq a -> Seq b -> Seq c
1720-
zipWith f xs ys
1721-
| length xs <= length ys = zipWith' f xs ys
1722-
| otherwise = zipWith' (flip f) ys xs
1723-
1724-
-- like 'zipWith', but assumes length xs <= length ys
1725-
zipWith' :: (a -> b -> c) -> Seq a -> Seq b -> Seq c
1726-
zipWith' f xs ys = snd (mapAccumL k ys xs)
1789+
zipWith f s1 s2 = splitTraverseSeq (\s a -> f a (getSingleton s)) s2' s1'
17271790
where
1728-
k kys x = case viewl kys of
1729-
(z :< zs) -> (zs, f x z)
1730-
EmptyL -> error "zipWith': unexpected EmptyL"
1791+
minLen = min (length s1) (length s2)
1792+
s1' = take minLen s1
1793+
s2' = take minLen s2
17311794

17321795
-- | /O(min(n1,n2,n3))/. 'zip3' takes three sequences and returns a
17331796
-- sequence of triples, analogous to 'zip'.
@@ -1738,7 +1801,14 @@ zip3 = zipWith3 (,,)
17381801
-- three elements, as well as three sequences and returns a sequence of
17391802
-- their point-wise combinations, analogous to 'zipWith'.
17401803
zipWith3 :: (a -> b -> c -> d) -> Seq a -> Seq b -> Seq c -> Seq d
1741-
zipWith3 f s1 s2 s3 = zipWith ($) (zipWith f s1 s2) s3
1804+
zipWith3 f s1 s2 s3 = splitTraverseSeq (\s a ->
1805+
case s of
1806+
(b, c) -> f a (getSingleton b) (getSingleton c)) (s2', s3') s1'
1807+
where
1808+
minLen = minimum [length s1, length s2, length s3]
1809+
s1' = take minLen s1
1810+
s2' = take minLen s2
1811+
s3' = take minLen s3
17421812

17431813
-- | /O(min(n1,n2,n3,n4))/. 'zip4' takes four sequences and returns a
17441814
-- sequence of quadruples, analogous to 'zip'.
@@ -1749,7 +1819,15 @@ zip4 = zipWith4 (,,,)
17491819
-- four elements, as well as four sequences and returns a sequence of
17501820
-- their point-wise combinations, analogous to 'zipWith'.
17511821
zipWith4 :: (a -> b -> c -> d -> e) -> Seq a -> Seq b -> Seq c -> Seq d -> Seq e
1752-
zipWith4 f s1 s2 s3 s4 = zipWith ($) (zipWith ($) (zipWith f s1 s2) s3) s4
1822+
zipWith4 f s1 s2 s3 s4 = splitTraverseSeq (\s a ->
1823+
case s of
1824+
(b, (c, d)) -> f a (getSingleton b) (getSingleton c) (getSingleton d)) (s2', (s3', s4')) s1'
1825+
where
1826+
minLen = minimum [length s1, length s2, length s3, length s4]
1827+
s1' = take minLen s1
1828+
s2' = take minLen s2
1829+
s3' = take minLen s3
1830+
s4' = take minLen s4
17531831

17541832
------------------------------------------------------------------------
17551833
-- Sorting

0 commit comments

Comments
 (0)