Skip to content

Commit 61f292c

Browse files
committed
Add performant Foldable functions for Data.Tree
Implement Foldable functions for Data.Tree, which only had the bare minimum foldMap so far. Add tests and benchmarks for the added functions. Benchmarks show very good improvements.
1 parent 468aa9d commit 61f292c

File tree

4 files changed

+199
-9
lines changed

4 files changed

+199
-9
lines changed

containers-tests/benchmarks/Tree.hs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
module Main where
2+
3+
import Control.DeepSeq (NFData, rnf)
4+
import Control.Exception (evaluate)
5+
import Data.Coerce (coerce)
6+
import Data.Foldable (fold, foldl', toList)
7+
import Data.Monoid (All(..))
8+
import Test.Tasty.Bench (Benchmark, Benchmarkable, bench, bgroup, defaultMain, whnf, nf)
9+
import qualified Data.Tree as T
10+
11+
-- | Benchmark entry point for the 'Foldable' functions of "Data.Tree".
--
-- Every group runs one fold over each input tree: a binary tree and a line
-- tree, each built with n = 1000 and n = 1000000.
main :: IO ()
main = do
  -- Force both input sets to normal form before any benchmark runs
  -- (presumably so that tree construction stays out of the measurements).
  evaluate (rnf ts `seq` rnf tsBool)
  defaultMain benchmarks
  where
    -- Same order as applying each builder to each size in turn.
    ts = [mk n | mk <- [binaryTree, lineTree], n <- [1000, 1000000]]

    -- The same shapes with every label replaced by True.
    tsBool = map allTrue ts
    allTrue t = t { getT = True <$ getT t }

    benchmarks =
      [ bgroup "fold" $ forTs tsBool $ whnf fold . (coerce :: T.Tree Bool -> T.Tree All)
      , bgroup "foldMap" $ forTs tsBool $ whnf (foldMap All)
      , bgroup "foldr_1" $ forTs tsBool $ whnf (foldr (&&) True)
      , bgroup "foldr_2" $ forTs ts $ whnf (length . foldr (:) [])
      , bgroup "foldr_3" $ forTs ts $ whnf (\t -> foldr (\x k acc -> if acc < 0 then acc else k $! acc + x) id t 0)
      , bgroup "foldl'" $ forTs ts $ whnf (foldl' (+) 0)
      , bgroup "foldr1" $ forTs tsBool $ whnf (foldr1 (&&))
      , bgroup "foldl1" $ forTs ts $ whnf (foldl1 (+))
      , bgroup "toList" $ forTs ts $ nf toList
      , bgroup "elem" $ forTs ts $ whnf (elem 0)
      , bgroup "maximum" $ forTs ts $ whnf maximum
      , bgroup "sum" $ forTs ts $ whnf sum
      ]
31+
32+
-- | Build one benchmark per labelled input, named after the input's label
-- and running the given action on the wrapped tree.
forTs :: [Tree a] -> (T.Tree a -> Benchmarkable) -> [Benchmark]
forTs inputs run = map toBench inputs
  where
    toBench (Tree label t) = bench label (run t)
34+
35+
-- | A benchmark input: a "Data.Tree" tree paired with the label under which
-- its benchmark is reported (see 'forTs').
data Tree a = Tree
36+
-- getLabel: name passed to 'bench' for this input.
{ getLabel :: String
37+
-- getT: the tree that the benchmarked fold is run on.
, getT :: T.Tree a
38+
}
39+
40+
-- | Fully evaluates the label first, then the wrapped tree.
instance NFData a => NFData (Tree a) where
  rnf input = rnf (getLabel input) `seq` rnf (getT input)
42+
43+
-- | A binary tree rooted at 1 in which node @x@ has the children @2*x@ and
-- @2*x+1@, as long as they do not exceed @n@. Labelled @"bin,n=<n>"@.
binaryTree :: Int -> Tree Int
binaryTree n =
  Tree
    { getLabel = "bin,n=" ++ show n
    , getT = T.unfoldTree children 1
    }
  where
    children x = (x, takeWhile (<= n) [2 * x, 2 * x + 1])
48+
49+
-- | A degenerate tree rooted at 1 in which node @x@ has the single child
-- @x+1@ while @x+1 <= n@. Labelled @"line,n=<n>"@.
lineTree :: Int -> Tree Int
lineTree n =
  Tree
    { getLabel = "line,n=" ++ show n
    , getT = T.unfoldTree next 1
    }
  where
    next x
      | x + 1 <= n = (x, [x + 1])
      | otherwise = (x, [])

containers-tests/containers-tests.cabal

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,14 @@ benchmark map-benchmarks
152152
main-is: Map.hs
153153
ghc-options: -O2
154154

155+
benchmark tree-benchmarks
156+
import: benchmark-deps
157+
default-language: Haskell2010
158+
type: exitcode-stdio-1.0
159+
hs-source-dirs: benchmarks
160+
main-is: Tree.hs
161+
ghc-options: -O2
162+
155163
benchmark sequence-benchmarks
156164
import: benchmark-deps
157165
default-language: Haskell2010

containers-tests/tests/tree-properties.hs

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,37 @@ import Data.Tree as T
55
import Control.Applicative (Const(Const, getConst), pure, (<$>), (<*>), liftA2)
66

77
import Test.Tasty
8+
import Test.Tasty.HUnit
89
import Test.Tasty.QuickCheck
910
import Test.QuickCheck.Function (apply)
10-
import Test.QuickCheck.Poly (A, B, C)
11+
import Test.QuickCheck.Poly (A, B, C, OrdA)
1112
import Control.Monad.Fix (MonadFix (..))
1213
import Control.Monad (ap)
14+
import Data.Foldable (foldl', toList)
15+
import Data.Traversable (foldMapDefault)
1316

1417
default (Int)
1518

1619
-- Test entry point: the HUnit case for foldr comes first, followed by the
-- QuickCheck properties (Monad/Applicative/MonadFix, then the Foldable ones).
main :: IO ()
1720
main = defaultMain $ testGroup "tree-properties"
1821
[
19-
testProperty "monad_id1" prop_monad_id1
22+
testCase "foldr" test_foldr
23+
, testProperty "monad_id1" prop_monad_id1
2024
, testProperty "monad_id2" prop_monad_id2
2125
, testProperty "monad_assoc" prop_monad_assoc
2226
, testProperty "ap_ap" prop_ap_ap
2327
, testProperty "ap_liftA2" prop_ap_liftA2
2428
, testProperty "monadFix_ls" prop_monadFix_ls
29+
, testProperty "toList" prop_toList
30+
, testProperty "foldMap" prop_foldMap
31+
, testProperty "foldl'" prop_foldl'
32+
, testProperty "foldr1" prop_foldr1
33+
, testProperty "foldl1" prop_foldl1
34+
, testProperty "foldr_infinite" prop_foldr_infinite
35+
, testProperty "maximum" prop_maximum
36+
, testProperty "minimum" prop_minimum
37+
, testProperty "sum" prop_sum
38+
, testProperty "product" prop_product
2539
]
2640

2741
{--------------------------------------------------------------------
@@ -52,10 +66,25 @@ instance Arbitrary a => Arbitrary (Tree a) where
5266
shrink = genericShrink
5367
#endif
5468

69+
----------------------------------------------------------------
70+
-- Utilities
71+
----------------------------------------------------------------
72+
73+
-- | A free magma: 'Inj' wraps a single value and ':*' pairs two magmas
-- without simplifying. Folding with ':*' therefore records both the order
-- and the nesting in which elements were combined, so the derived 'Eq' can
-- tell a right-nested result from a left-nested one.
data Magma a
74+
= Inj a
75+
| Magma a :* Magma a
76+
deriving (Eq, Show)
77+
5578
----------------------------------------------------------------
5679
-- Unit tests
5780
----------------------------------------------------------------
5881

82+
-- | 'foldr' visits labels in preorder: a node first, then its subtrees from
-- left to right. Checked on a single node, a chain and a two-level tree.
test_foldr :: Assertion
test_foldr = do
  preorder (leaf 1) @?= [1]
  preorder (Node 1 [Node 2 [leaf 3]]) @?= [1..3]
  preorder (Node 1 [Node 2 [leaf 3, leaf 4], Node 5 [leaf 6, leaf 7]]) @?= [1..7]
  where
    preorder :: Tree Int -> [Int]
    preorder = foldr (:) []

    leaf :: Int -> Tree Int
    leaf x = Node x []
87+
5988
----------------------------------------------------------------
6089
-- QuickCheck
6190
----------------------------------------------------------------
@@ -101,3 +130,39 @@ prop_monadFix_ls val ta ti =
101130
f :: (Int -> Int) -> Int -> Tree (Int -> Int)
102131
f q y = let t = apply ti y
103132
in fmap (\w -> fact w q) t
133+
134+
-- | 'toList' agrees with rebuilding the list through 'foldr'.
prop_toList :: Tree A -> Property
prop_toList t = toList t === viaFoldr
  where
    viaFoldr = foldr (:) [] t
136+
137+
-- | Collecting labels with 'foldMap' matches both 'toList' and the
-- 'Traversable'-derived 'foldMapDefault'.
prop_foldMap :: Tree A -> Property
prop_foldMap t =
  (collected === toList t) .&&. (collected === foldMapDefault one t)
  where
    one x = [x]
    collected = foldMap one t
141+
142+
-- | Prepending each label during a strict left fold yields the labels in
-- reverse 'toList' order.
prop_foldl' :: Tree A -> Property
prop_foldl' t = foldl' prepend [] t === reverse (toList t)
  where
    prepend acc x = x : acc
144+
145+
-- | 'foldr1' over the tree combines elements in the same order and nesting
-- as 'foldr1' over its list of labels (observed through 'Magma').
prop_foldr1 :: Tree A -> Property
prop_foldr1 t = foldr1 (:*) injTree === foldr1 (:*) injList
  where
    injTree = Inj <$> t
    injList = map Inj (toList t)
147+
148+
-- | 'foldl1' over the tree combines elements in the same order and nesting
-- as 'foldl1' over its list of labels (observed through 'Magma').
prop_foldl1 :: Tree A -> Property
prop_foldl1 t = foldl1 (:*) injTree === foldl1 (:*) injList
  where
    injTree = Inj <$> t
    injList = map Inj (toList t)
150+
151+
-- | 'foldr' is lazy enough to take a finite prefix of a possibly infinite
-- tree. The inequality itself always holds; the property fails only if
-- evaluating the prefix does not terminate.
prop_foldr_infinite :: NonNegative Int -> Property
prop_foldr_infinite (NonNegative limit) =
  forAllShow genInf describe prefixIsBounded
  where
    -- The tree may be infinite, so it is never printed.
    describe _ = "<possibly infinite tree>"

    prefixIsBounded tree = length (take limit (foldr (:) [] tree)) <= limit

    -- Each node gets either a finite or an infinite list of children.
    genInf = fmap (Node ()) (oneof [listOf genInf, infiniteListOf genInf])
157+
158+
-- | The instance's 'maximum' agrees with the list 'maximum' of the labels.
prop_maximum :: Tree OrdA -> Property
prop_maximum t = maximum t === maximum labels
  where
    labels = toList t
160+
161+
-- | The instance's 'minimum' agrees with the list 'minimum' of the labels.
prop_minimum :: Tree OrdA -> Property
prop_minimum t = minimum t === minimum labels
  where
    labels = toList t
163+
164+
-- | The instance's 'sum' agrees with the list 'sum' of the labels.
prop_sum :: Tree OrdA -> Property
prop_sum t = sum t === sum labels
  where
    labels = toList t
166+
167+
-- | The instance's 'product' agrees with the list 'product' of the labels.
prop_product :: Tree OrdA -> Property
prop_product t = product t === product labels
  where
    labels = toList t

containers/src/Data/Tree.hs

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE BangPatterns #-}
12
{-# LANGUAGE PatternGuards #-}
23
{-# LANGUAGE CPP #-}
34
#if __GLASGOW_HASKELL__
@@ -52,7 +53,8 @@ module Data.Tree(
5253

5354
) where
5455

55-
import Data.Foldable (toList)
56+
import Data.Foldable (fold, foldl', toList)
57+
import Data.Traversable (foldMapDefault)
5658
import Control.Monad (liftM)
5759
import Control.Monad.Fix (MonadFix (..), fix)
5860
import Data.Sequence (Seq, empty, singleton, (<|), (|>), fromList,
@@ -179,16 +181,59 @@ mfixTree f
179181
[0..] children)
180182

181183
instance Traversable Tree where
182-
traverse f (Node x ts) = liftA2 Node (f x) (traverse (traverse f) ts)
184+
-- The effect for the root label is sequenced before those of the subtrees,
-- which are traversed left to right. The local go closes over f, so the
-- recursion passes only the tree.
traverse f = go
185+
where go (Node x ts) = liftA2 Node (f x) (traverse go ts)
186+
{-# INLINE traverse #-}
183187

188+
-- | Folds in preorder
189+
190+
-- See Note [Implemented Foldable Tree functions]
184191
instance Foldable Tree where
185-
foldMap f (Node x ts) = f x `mappend` foldMap (foldMap f) ts
192+
fold = foldMap id
193+
{-# INLINABLE fold #-}
194+
195+
foldMap = foldMapDefault
196+
{-# INLINE foldMap #-}
197+
198+
foldr f z = \t -> go t z -- Use a lambda to allow inlining with two arguments
199+
where
200+
go (Node x ts) = f x . foldr (\t k -> go t . k) id ts
201+
-- This is equivalent to the following simpler definition, but has been found to optimize
202+
-- better in benchmarks:
203+
-- go (Node x ts) z' = f x (foldr go z' ts)
204+
{-# INLINE foldr #-}
205+
206+
-- The bang pattern forces the accumulator at every node before that node's
-- label is combined in and its children are folded.
foldl' f = go
207+
where go !z (Node x ts) = foldl' go (f z x) ts
208+
{-# INLINE foldl' #-}
209+
210+
-- go threads a continuation k that is handed the most recently visited
-- label (prev). For a node with children, each child's continuation combines
-- prev with the fold of the next sibling; for a leaf, k is applied to the
-- leaf's own label. Starting from id means the last label in preorder is
-- returned unchanged, so no seed value is needed.
foldr1 f = go id
211+
where go k (Node x ts) = foldr (\n k' prev -> f prev (go k' n)) k ts x
212+
{-# INLINE foldr1 #-}
213+
214+
-- Lazy left fold seeded with the root label; each child subtree is folded
-- into the accumulator in turn (the inner foldl is this instance's).
foldl1 f (Node x ts) = foldl (foldl f) x ts
186215

187216
-- Every tree has a root label, so a tree is never empty.
null _ = False
188217
{-# INLINE null #-}
189218

190-
toList = flatten
191-
{-# INLINE toList #-}
219+
elem = any . (==)
220+
{-# INLINABLE elem #-}
221+
222+
-- maximum, minimum, sum and product below are all strict left folds seeded
-- with the root label, via the foldl1' helper.
maximum = foldl1' max
223+
{-# INLINABLE maximum #-}
224+
225+
minimum = foldl1' min
226+
{-# INLINABLE minimum #-}
227+
228+
sum = foldl1' (+)
229+
{-# INLINABLE sum #-}
230+
231+
product = foldl1' (*)
232+
{-# INLINABLE product #-}
233+
234+
-- | Strict left fold without a separate seed: the root label is the initial
-- accumulator and each subtree is folded into it in turn with 'foldl''.
--
-- Written as a lambda after the first argument so that it can inline once
-- the combining function is supplied.
foldl1' :: (a -> a -> a) -> Tree a -> a
foldl1' step = \(Node root subtrees) -> foldl' (foldl' step) root subtrees
{-# INLINE foldl1' #-}
192237

193238
instance NFData a => NFData (Tree a) where
194239
rnf (Node x ts) = rnf x `seq` rnf ts
@@ -262,8 +307,7 @@ draw (Node x ts0) = lines x ++ drawSubTrees ts0
262307
--
263308
-- > flatten (Node 1 [Node 2 [], Node 3 []]) == [1,2,3]
264309
flatten :: Tree a -> [a]
265-
flatten t = squish t []
266-
where squish (Node x ts) xs = x:Prelude.foldr squish xs ts
310+
-- Delegates to the Foldable class's toList. The instance no longer overrides
-- toList, so this is the class default.
flatten = toList
267311

268312
-- | Returns the list of nodes at each level of the tree.
269313
--
@@ -415,3 +459,23 @@ unfoldForestQ f aQ = case viewl aQ of
415459
splitOnto as (_:bs) q = case viewr q of
416460
q' :> a -> splitOnto (a:as) bs q'
417461
EmptyR -> error "unfoldForestQ"
462+
463+
--------------------------------------------------------------------------------
464+
465+
-- Note [Implemented Foldable Tree functions]
466+
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
467+
--
468+
-- Implemented:
469+
--
470+
-- foldMap, foldr, foldl': Basic functions.
471+
-- fold, elem: Implemented same as the default definition, but INLINABLE to
472+
-- allow specialization.
473+
-- foldr1, foldl1, null, maximum, minimum: Implemented more efficiently than
474+
-- defaults since trees are non-empty.
475+
-- sum, product: Implemented as strict left folds. Defaults use the lazy foldMap
476+
-- before base 4.15.1.
477+
--
478+
-- Not implemented:
479+
--
480+
-- foldMap', toList, length: Defaults perform well.
481+
-- foldr', foldl: Unlikely to be used.

0 commit comments

Comments
 (0)