Skip to content

Commit 991ae11

Browse files
authored
Merge pull request #880 from meooow25/tree
Faster Tree folds
2 parents f967d33 + 61f292c commit 991ae11

File tree

4 files changed

+199
-9
lines changed

4 files changed

+199
-9
lines changed

containers-tests/benchmarks/Tree.hs

Lines changed: 53 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,53 @@
1+
module Main where
2+
3+
import Control.DeepSeq (NFData, rnf)
4+
import Control.Exception (evaluate)
5+
import Data.Coerce (coerce)
6+
import Data.Foldable (fold, foldl', toList)
7+
import Data.Monoid (All(..))
8+
import Test.Tasty.Bench (Benchmark, Benchmarkable, bench, bgroup, defaultMain, whnf, nf)
9+
import qualified Data.Tree as T
10+
11+
-- | Benchmark entry point.
--
-- All input trees are fully evaluated before 'defaultMain' runs, and then one
-- benchmark group per 'Foldable' operation is run over every prepared tree.
main :: IO ()
main = do
  -- Force both families of inputs before any benchmark starts.
  evaluate (rnf intTrees `seq` rnf boolTrees)
  defaultMain
    [ boolGroup "fold" (\t -> whnf fold (coerce t :: T.Tree All))
    , boolGroup "foldMap" (whnf (foldMap All))
    , boolGroup "foldr_1" (whnf (foldr (&&) True))
    , intGroup "foldr_2" (whnf (length . foldr (:) []))
    , intGroup "foldr_3" (whnf (\t -> foldr (\x k acc -> if acc < 0 then acc else k $! acc + x) id t 0))
    , intGroup "foldl'" (whnf (foldl' (+) 0))
    , boolGroup "foldr1" (whnf (foldr1 (&&)))
    , intGroup "foldl1" (whnf (foldl1 (+)))
    , intGroup "toList" (nf toList)
    , intGroup "elem" (whnf (elem 0))
    , intGroup "maximum" (whnf maximum)
    , intGroup "sum" (whnf sum)
    ]
  where
    -- One benchmark group over the Int-labelled trees.
    intGroup name mk = bgroup name (forTs intTrees mk)
    -- One benchmark group over the Bool-labelled trees.
    boolGroup name mk = bgroup name (forTs boolTrees mk)

    -- Every tree shape at every size, shapes varying slowest.
    intTrees = [build size | build <- [binaryTree, lineTree], size <- [1000, 1000000]]
    -- The same trees with every element replaced by 'True'.
    boolTrees = [Tree label (True <$ t) | Tree label t <- intTrees]
31+
32+
-- | Build one benchmark per labelled tree. Each benchmark is named after the
-- tree's label and measures the given builder applied to the underlying tree.
forTs :: [Tree a] -> (T.Tree a -> Benchmarkable) -> [Benchmark]
forTs trees mkBench = map toBench trees
  where
    toBench (Tree name tree) = bench name (mkBench tree)
34+
35+
-- | A benchmark input: a 'T.Tree' together with a label describing it.
data Tree a = Tree
  { getLabel :: String   -- ^ Label used to name benchmarks run on this tree
  , getT     :: T.Tree a -- ^ The tree being measured
  }
39+
40+
-- | Fully evaluates the label first and then the whole underlying tree.
instance NFData a => NFData (Tree a) where
  rnf tree = rnf (getLabel tree) `seq` rnf (getT tree)
42+
43+
-- | A labelled tree rooted at @1@ in which the node holding @x@ has the
-- children @2*x@ and @2*x+1@, cut off at the first candidate greater than @n@.
binaryTree :: Int -> Tree Int
binaryTree n = Tree ("bin,n=" ++ show n) (T.unfoldTree expand 1)
  where
    -- A node's value together with the seeds of its children.
    expand x = (x, takeWhile (<= n) [2 * x, 2 * x + 1])
48+
49+
-- | A labelled tree rooted at @1@ in which the node holding @x@ has the single
-- child @x+1@ as long as that value does not exceed @n@, and no child otherwise.
lineTree :: Int -> Tree Int
lineTree n = Tree ("line,n=" ++ show n) (T.unfoldTree step 1)
  where
    -- A node's value together with the seed of its (at most one) child.
    step x = (x, [next | let next = x + 1, next <= n])

containers-tests/containers-tests.cabal

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -152,6 +152,14 @@ benchmark map-benchmarks
152152
main-is: Map.hs
153153
ghc-options: -O2
154154

155+
benchmark tree-benchmarks
156+
import: benchmark-deps
157+
default-language: Haskell2010
158+
type: exitcode-stdio-1.0
159+
hs-source-dirs: benchmarks
160+
main-is: Tree.hs
161+
ghc-options: -O2
162+
155163
benchmark sequence-benchmarks
156164
import: benchmark-deps
157165
default-language: Haskell2010

containers-tests/tests/tree-properties.hs

Lines changed: 67 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,23 +5,37 @@ import Data.Tree as T
55
import Control.Applicative (Const(Const, getConst), pure, (<$>), (<*>), liftA2)
66

77
import Test.Tasty
8+
import Test.Tasty.HUnit
89
import Test.Tasty.QuickCheck
910
import Test.QuickCheck.Function (apply)
10-
import Test.QuickCheck.Poly (A, B, C)
11+
import Test.QuickCheck.Poly (A, B, C, OrdA)
1112
import Control.Monad.Fix (MonadFix (..))
1213
import Control.Monad (ap)
14+
import Data.Foldable (foldl', toList)
15+
import Data.Traversable (foldMapDefault)
1316

1417
default (Int)
1518

1619
-- | Test entry point: runs the unit test first, then every QuickCheck
-- property, all inside the single @tree-properties@ group.
main :: IO ()
main = defaultMain (testGroup "tree-properties" (unitTests ++ properties))
  where
    unitTests =
      [ testCase "foldr" test_foldr
      ]

    properties =
      [ testProperty "monad_id1" prop_monad_id1
      , testProperty "monad_id2" prop_monad_id2
      , testProperty "monad_assoc" prop_monad_assoc
      , testProperty "ap_ap" prop_ap_ap
      , testProperty "ap_liftA2" prop_ap_liftA2
      , testProperty "monadFix_ls" prop_monadFix_ls
      , testProperty "toList" prop_toList
      , testProperty "foldMap" prop_foldMap
      , testProperty "foldl'" prop_foldl'
      , testProperty "foldr1" prop_foldr1
      , testProperty "foldl1" prop_foldl1
      , testProperty "foldr_infinite" prop_foldr_infinite
      , testProperty "maximum" prop_maximum
      , testProperty "minimum" prop_minimum
      , testProperty "sum" prop_sum
      , testProperty "product" prop_product
      ]
2640

2741
{--------------------------------------------------------------------
@@ -52,10 +66,25 @@ instance Arbitrary a => Arbitrary (Tree a) where
5266
shrink = genericShrink
5367
#endif
5468

69+
----------------------------------------------------------------
70+
-- Utilities
71+
----------------------------------------------------------------
72+
73+
-- | A free magma over @a@: injected values combined by a binary operator that
-- satisfies no laws. Folding with ':*' therefore records exactly how the
-- fold nested its operator, so two results are equal only if they associate
-- the elements identically.
data Magma a
  = Inj a                -- ^ A single injected element
  | Magma a :* Magma a   -- ^ Two sub-results combined, left and right
  deriving (Eq, Show)
77+
5578
----------------------------------------------------------------
5679
-- Unit tests
5780
----------------------------------------------------------------
5881

82+
-- | Checks, on a leaf, a three-node chain and a complete seven-node binary
-- tree, that consing with 'foldr' lists the elements in preorder.
test_foldr :: Assertion
test_foldr = mapM_ check cases
  where
    check (tree, expected) = foldr (:) [] tree @?= expected

    cases :: [(Tree Int, [Int])]
    cases =
      [ (Node 1 [], [1])
      , (Node 1 [Node 2 [Node 3 []]], [1..3])
      , (Node 1 [Node 2 [Node 3 [], Node 4 []], Node 5 [Node 6 [], Node 7 []]], [1..7])
      ]
87+
5988
----------------------------------------------------------------
6089
-- QuickCheck
6190
----------------------------------------------------------------
@@ -101,3 +130,39 @@ prop_monadFix_ls val ta ti =
101130
f :: (Int -> Int) -> Int -> Tree (Int -> Int)
102131
f q y = let t = apply ti y
103132
in fmap (\w -> fact w q) t
133+
134+
-- | 'toList' agrees with a right fold that conses every element.
prop_toList :: Tree A -> Property
prop_toList tree = toList tree === viaFoldr
  where
    viaFoldr = foldr (:) [] tree
136+
137+
-- | Collecting singletons with 'foldMap' yields the same list as 'toList'
-- and as the 'Traversable'-derived 'foldMapDefault'.
prop_foldMap :: Tree A -> Property
prop_foldMap tree =
  (collected === toList tree) .&&. (collected === foldMapDefault single tree)
  where
    single x = [x]
    collected = foldMap single tree
141+
142+
-- | A strict left fold that conses onto the accumulator visits the elements
-- in 'toList' order, so it produces that list reversed.
prop_foldl' :: Tree A -> Property
prop_foldl' tree = consedLeft === reverse (toList tree)
  where
    consedLeft = foldl' (\acc x -> x : acc) [] tree
144+
145+
-- | 'foldr1' on a tree nests its operator exactly as 'foldr1' does on the
-- list of the tree's elements.
prop_foldr1 :: Tree A -> Property
prop_foldr1 tree = overTree === overList
  where
    overTree = foldr1 (:*) (Inj <$> tree)
    overList = foldr1 (:*) [Inj x | x <- toList tree]
147+
148+
-- | 'foldl1' on a tree nests its operator exactly as 'foldl1' does on the
-- list of the tree's elements.
prop_foldl1 :: Tree A -> Property
prop_foldl1 tree = overTree === overList
  where
    overTree = foldl1 (:*) (Inj <$> tree)
    overList = foldl1 (:*) [Inj x | x <- toList tree]
150+
151+
-- | Taking a bounded prefix of a right fold must work on trees whose nodes
-- may have infinitely many children. The length check itself always holds;
-- what the property really exercises is that the prefix can be produced.
prop_foldr_infinite :: NonNegative Int -> Property
prop_foldr_infinite (NonNegative n) =
  forAllShow genInf (\_ -> "<possibly infinite tree>") boundedPrefix
  where
    boundedPrefix t = length (take n (foldr (:) [] t)) <= n

    -- Every node gets either an ordinary list of children or an infinite one.
    genInf = fmap (Node ()) (oneof [listOf genInf, infiniteListOf genInf])
157+
158+
-- | 'maximum' on a tree matches 'maximum' on the list of its elements.
prop_maximum :: Tree OrdA -> Property
prop_maximum tree = maximum tree === maximum elements
  where
    elements = toList tree
160+
161+
-- | 'minimum' on a tree matches 'minimum' on the list of its elements.
prop_minimum :: Tree OrdA -> Property
prop_minimum tree = minimum tree === minimum elements
  where
    elements = toList tree
163+
164+
-- | 'sum' on a tree matches 'sum' on the list of its elements.
prop_sum :: Tree OrdA -> Property
prop_sum tree = sum tree === sum elements
  where
    elements = toList tree
166+
167+
-- | 'product' on a tree matches 'product' on the list of its elements.
prop_product :: Tree OrdA -> Property
prop_product tree = product tree === product elements
  where
    elements = toList tree

containers/src/Data/Tree.hs

Lines changed: 71 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE BangPatterns #-}
12
{-# LANGUAGE PatternGuards #-}
23
{-# LANGUAGE CPP #-}
34
#if __GLASGOW_HASKELL__
@@ -52,7 +53,8 @@ module Data.Tree(
5253

5354
) where
5455

55-
import Data.Foldable (toList)
56+
import Data.Foldable (fold, foldl', toList)
57+
import Data.Traversable (foldMapDefault)
5658
import Control.Monad (liftM)
5759
import Control.Monad.Fix (MonadFix (..), fix)
5860
import Data.Sequence (Seq, empty, singleton, (<|), (|>), fromList,
@@ -179,16 +181,59 @@ mfixTree f
179181
[0..] children)
180182

181183
instance Traversable Tree where
  -- The effectful function is captured once and the recursion happens in a
  -- local worker, so 'traverse' itself is a one-argument definition.
  traverse f = walk
    where
      walk (Node root children) = liftA2 Node (f root) (traverse walk children)
  {-# INLINE traverse #-}
183187

188+
-- | Folds in preorder

-- See Note [Implemented Foldable Tree functions]
instance Foldable Tree where
  fold = foldMap id
  {-# INLINABLE fold #-}

  -- Derived from the Traversable instance.
  foldMap = foldMapDefault
  {-# INLINE foldMap #-}

  -- The result is a lambda so that the definition takes exactly two
  -- arguments, which allows inlining once those two are supplied.
  foldr step acc0 = \tree -> walk tree acc0
    where
      -- Each node becomes a function on the accumulator: the node's own
      -- element is applied after the composed contributions of its children.
      walk (Node root kids) = step root . foldr (\kid rest -> walk kid . rest) id kids
      -- This is equivalent to the simpler
      --   walk (Node root kids) acc = step root (foldr walk acc kids)
      -- but the composed form has been found to optimize better in benchmarks.
  {-# INLINE foldr #-}

  -- The bang forces the accumulator before each node is visited.
  foldl' step = walk
    where
      walk !acc (Node root kids) = foldl' walk (step acc root) kids
  {-# INLINE foldl' #-}

  -- Each element is handed to a continuation that either combines it with
  -- the fold of what follows or, for the last element, returns it unchanged.
  foldr1 combine = walk id
    where
      walk finish (Node root kids) =
        foldr (\kid finish' prev -> combine prev (walk finish' kid)) finish kids root
  {-# INLINE foldr1 #-}

  -- The root element seeds a lazy left fold over the subtrees.
  foldl1 combine (Node root kids) = foldl (foldl combine) root kids

  -- A tree always has a root, so it is never empty.
  null _ = False
  {-# INLINE null #-}

  elem = \x -> any (x ==)
  {-# INLINABLE elem #-}

  maximum = foldl1' max
  {-# INLINABLE maximum #-}

  minimum = foldl1' min
  {-# INLINABLE minimum #-}

  sum = foldl1' (+)
  {-# INLINABLE sum #-}

  product = foldl1' (*)
  {-# INLINABLE product #-}
233+
234+
-- | Strict left fold over a tree that uses the root element as the starting
-- accumulator, so no separate initial value is required.
foldl1' :: (a -> a -> a) -> Tree a -> a
foldl1' combine = \(Node root kids) -> foldl' (foldl' combine) root kids
{-# INLINE foldl1' #-}
192237

193238
instance NFData a => NFData (Tree a) where
194239
rnf (Node x ts) = rnf x `seq` rnf ts
@@ -262,8 +307,7 @@ draw (Node x ts0) = lines x ++ drawSubTrees ts0
262307
--
263308
-- > flatten (Node 1 [Node 2 [], Node 3 []]) == [1,2,3]
264309
flatten :: Tree a -> [a]
265-
flatten t = squish t []
266-
where squish (Node x ts) xs = x:Prelude.foldr squish xs ts
310+
flatten = toList
267311

268312
-- | Returns the list of nodes at each level of the tree.
269313
--
@@ -415,3 +459,23 @@ unfoldForestQ f aQ = case viewl aQ of
415459
splitOnto as (_:bs) q = case viewr q of
416460
q' :> a -> splitOnto (a:as) bs q'
417461
EmptyR -> error "unfoldForestQ"
462+
463+
--------------------------------------------------------------------------------
464+
465+
-- Note [Implemented Foldable Tree functions]
466+
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
467+
--
468+
-- Implemented:
469+
--
470+
-- foldMap, foldr, foldl': Basic functions.
471+
-- fold, elem: Implemented same as the default definition, but INLINABLE to
472+
-- allow specialization.
473+
-- foldr1, foldl1, null, maximum, minimum: Implemented more efficiently than
474+
-- defaults since trees are non-empty.
475+
-- sum, product: Implemented as strict left folds. Defaults use the lazy foldMap
476+
-- before base 4.15.1.
477+
--
478+
-- Not implemented:
479+
--
480+
-- foldMap', toList, length: Defaults perform well.
481+
-- foldr', foldl: Unlikely to be used.

0 commit comments

Comments
 (0)