Skip to content

Commit 311d13d

Browse files
committed
Use a fake GADT for sequence folds and traversals
1 parent 3117213 commit 311d13d

File tree

4 files changed

+260
-15
lines changed

4 files changed

+260
-15
lines changed

containers-tests/containers-tests.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ library
106106
Data.Map.Strict.Internal
107107
Data.Sequence
108108
Data.Sequence.Internal
109+
Data.Sequence.Internal.Depth
109110
Data.Sequence.Internal.Sorting
110111
Data.Set
111112
Data.Set.Internal

containers/containers.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ Library
7070
Data.Graph
7171
Data.Sequence
7272
Data.Sequence.Internal
73+
Data.Sequence.Internal.Depth
7374
Data.Sequence.Internal.Sorting
7475
Data.Tree
7576
Utils.Containers.Internal.BitUtil

containers/src/Data/Sequence/Internal.hs

Lines changed: 138 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{- Disabled debugging aid: OPTIONS_GHC -ddump-simpl -}
12
{-# LANGUAGE CPP #-}
23
#include "containers.h"
34
{-# LANGUAGE BangPatterns #-}
@@ -7,6 +8,7 @@
78
{-# LANGUAGE DeriveLift #-}
89
{-# LANGUAGE StandaloneDeriving #-}
910
{-# LANGUAGE FlexibleInstances #-}
11+
{-# LANGUAGE GADTs #-}
1012
{-# LANGUAGE InstanceSigs #-}
1113
{-# LANGUAGE ScopedTypeVariables #-}
1214
{-# LANGUAGE TemplateHaskellQuotes #-}
@@ -193,6 +195,7 @@ module Data.Sequence.Internal (
193195
node2,
194196
node3,
195197
#endif
198+
bongo
196199
) where
197200

198201
import Utils.Containers.Internal.Prelude hiding (
@@ -210,7 +213,7 @@ import Control.Applicative ((<$>), (<**>), Alternative,
210213
import qualified Control.Applicative as Applicative
211214
import Control.DeepSeq (NFData(rnf))
212215
import Control.Monad (MonadPlus(..))
213-
import Data.Monoid (Monoid(..))
216+
import Data.Monoid (Monoid(..), Endo(..), Dual(..))
214217
import Data.Functor (Functor(..))
215218
import Utils.Containers.Internal.State (State(..), execState)
216219
import Data.Foldable (foldr', toList)
@@ -250,6 +253,7 @@ import Data.Functor.Identity (Identity(..))
250253
import Utils.Containers.Internal.StrictPair (StrictPair (..), toPair)
251254
import Control.Monad.Zip (MonadZip (..))
252255
import Control.Monad.Fix (MonadFix (..), fix)
256+
import Data.Sequence.Internal.Depth (Depth_ (..), Depth2_ (..))
253257

254258
default ()
255259

@@ -394,16 +398,38 @@ fmapSeq f (Seq xs) = Seq (fmap (fmap f) xs)
394398
#-}
395399
#endif
396400

401+
--type Depth = Depth_ Elem Node
402+
-- 'Depth_' and 'Depth2_' with their @node@ parameter fixed to this
-- module's 'Node' type.
type Depth = Depth_ Node
403+
type Depth2 = Depth2_ Node
404+
397405
instance Foldable Seq where
398406
#ifdef __GLASGOW_HASKELL__
399407
-- Fold a 'Seq' by walking its finger tree while carrying a 'Depth' witness.
-- The witness starts at 'Bottom' (elements are @Elem a@) and gains one
-- 'Deeper' for each step into the middle tree of a 'Deep' node.
foldMap :: forall m a. Monoid m => (a -> m) -> Seq a -> m
400-
foldMap = coerce (foldMap :: (Elem a -> m) -> FingerTree (Elem a) -> m)
408+
foldMap f (Seq t0) = foldMapFT Bottom t0
409+
where
-- Fold one element of the tree at the given depth. Matching on the depth
-- reveals whether @t@ is a leaf ('Elem') or a 'Node' of shallower things.
410+
foldMapBlob :: Depth (Elem a) t -> t -> m
411+
foldMapBlob Bottom (Elem a) = f a
412+
foldMapBlob (Deeper w) (Node2 _ x y) = foldMapBlob w x <> foldMapBlob w y
413+
foldMapBlob (Deeper w) (Node3 _ x y z) = foldMapBlob w x <> foldMapBlob w y <> foldMapBlob w z
414+
-- Fold a finger tree whose elements sit at the given depth. The bang
-- pattern forces the depth even when the tree is empty.
415+
foldMapFT :: Depth (Elem a) t -> FingerTree t -> m
416+
foldMapFT !_ EmptyT = mempty
417+
foldMapFT w (Single t) = foldMapBlob w t
418+
foldMapFT w (Deep _ pr m sf) =
419+
foldMap (foldMapBlob w) pr
420+
<> foldMapFT (Deeper w) m
421+
<> foldMap (foldMapBlob w) sf
401422

402423
-- Right fold expressed through 'foldMap' in the 'Endo' monoid.
foldr :: forall a b. (a -> b -> b) -> b -> Seq a -> b
403-
foldr = coerce (foldr :: (Elem a -> b -> b) -> b -> FingerTree (Elem a) -> b)
424+
-- We define this explicitly so we can inline the foldMap. And we don't
425+
-- define it as a coercion of the FingerTree version because we want users
426+
-- to have the option of (effectively) inlining it explicitly.
-- @coerce f@ views @f :: a -> b -> b@ as @a -> Endo b@; the resulting
-- 'Endo' is then applied to the initial value @z@.
427+
foldr f z t = appEndo (GHC.Exts.inline foldMap (coerce f) t) z
404428

405429
-- Left fold expressed through 'foldMap' in the @Dual (Endo b)@ monoid;
-- 'Dual' flips the order in which the per-element functions are composed.
foldl :: forall b a. (b -> a -> b) -> b -> Seq a -> b
406-
foldl = coerce (foldl :: (b -> Elem a -> b) -> b -> FingerTree (Elem a) -> b)
430+
-- Should we define this by hand to associate optimally? Or is GHC
431+
-- clever enough to do that for us?
432+
foldl f z t = appEndo (getDual (GHC.Exts.inline foldMap (Dual . Endo . flip f) t)) z
407433

408434
foldr' :: forall a b. (a -> b -> b) -> b -> Seq a -> b
409435
foldr' = coerce (foldr' :: (Elem a -> b -> b) -> b -> FingerTree (Elem a) -> b)
@@ -442,7 +468,37 @@ instance Foldable Seq where
442468
instance Traversable Seq where
443469
#if __GLASGOW_HASKELL__
444470
{-# INLINABLE traverse #-}
445-
#endif
471+
-- Traverse a 'Seq' by walking its finger tree with a 'Depth2' witness that
-- tracks the element type before (@t@) and after (@u@) the traversal at the
-- current level. Cached size fields (@s@) are copied over unchanged.
traverse :: forall f a b. Applicative f => (a -> f b) -> Seq a -> f (Seq b)
472+
traverse f (Seq t0) = Seq <$> traverseFT Bottom2 t0
473+
where
-- Traverse a finger tree whose elements sit at the given depth; the middle
-- tree of a 'Deep' node is traversed one level deeper ('Deeper2').
474+
traverseFT :: Depth2 (Elem a) t (Elem b) u -> FingerTree t -> f (FingerTree u)
475+
traverseFT !_ EmptyT = pure EmptyT
476+
traverseFT w (Single t) = Single <$> traverseBlob w t
477+
traverseFT w (Deep s pr m sf) = liftA3 (Deep s)
478+
(traverse (traverseBlob w) pr)
479+
(traverseFT (Deeper2 w) m)
480+
(traverse (traverseBlob w) sf)
481+
482+
-- Traverse a 2-3 tree, given its height.
483+
traverseBlob :: Depth2 (Elem a) t (Elem b) u -> t -> f u
484+
traverseBlob Bottom2 (Elem a) = Elem <$> f a
485+
486+
-- We have a special case here to avoid needing to `fmap Elem` over
487+
-- each of the leaves, in case that's not free in the relevant functor.
488+
-- We still end up using extra fmaps for the very first level of the
489+
-- FingerTree and the Seq constructor. While we *could* avoid that,
490+
-- doing so requires a good bit of extra code to save *at most* nine
491+
-- fmap applications for the sequence. It would also save on Depth
492+
-- comparisons, but I doubt that matters very much.
493+
traverseBlob (Deeper2 Bottom2) (Node2 s (Elem x) (Elem y))
494+
= liftA2 (\x' y' -> Node2 s (Elem x') (Elem y')) (f x) (f y)
495+
traverseBlob (Deeper2 Bottom2) (Node3 s (Elem x) (Elem y) (Elem z))
496+
= liftA3 (\x' y' z' -> Node3 s (Elem x') (Elem y') (Elem z'))
497+
(f x) (f y) (f z)
498+
-- General case: nodes whose children are themselves nodes.
499+
traverseBlob (Deeper2 w) (Node2 s x y) = liftA2 (Node2 s) (traverseBlob w x) (traverseBlob w y)
500+
traverseBlob (Deeper2 w) (Node3 s x y z) = liftA3 (Node3 s) (traverseBlob w x) (traverseBlob w y) (traverseBlob w z)
501+
#else
446502
traverse _ (Seq EmptyT) = pure (Seq EmptyT)
447503
traverse f' (Seq (Single (Elem x'))) =
448504
(\x'' -> Seq (Single (Elem x''))) <$> f' x'
@@ -514,6 +570,7 @@ instance Traversable Seq where
514570
:: Applicative f
515571
=> (Node a -> f (Node b)) -> Node (Node a) -> f (Node (Node b))
516572
traverseNodeN f t = traverse f t
573+
#endif
517574

518575
instance NFData a => NFData (Seq a) where
519576
rnf (Seq xs) = rnf xs
@@ -1078,7 +1135,33 @@ instance Sized a => Sized (FingerTree a) where
10781135
size (Single x) = size x
10791136
size (Deep v _ _ _) = v
10801137

1138+
-- We don't fold FingerTrees directly, but instead coerce them to
1139+
-- Seqs and fold those. This seems backwards! Why do it? We certainly
1140+
-- *could* fold FingerTrees directly, but we'd need a slightly different
1141+
-- version of the Depth GADT to do so. While that's not a big deal,
1142+
-- it is a bit annoying. Note: we need the current version of Depth
1143+
-- to deal with the Sized issues for indexed folds.
10811144
instance Foldable FingerTree where
1145+
#ifdef __GLASGOW_HASKELL__
1146+
-- Fold a 'FingerTree' directly, carrying a 'Depth' witness whose bottom
-- type is the element type @a@ itself (no 'Elem' wrapper is assumed).
--
-- NOTE(review): the comment preceding this instance says FingerTrees are
-- folded by coercing them to Seqs, but this definition folds the tree
-- directly and the coercion-based definition below is commented out.
-- Confirm which is intended and update that comment accordingly.
foldMap :: forall m a. Monoid m => (a -> m) -> FingerTree a -> m
1147+
foldMap f = foldMapFT Bottom
1148+
where
-- Fold one element of the tree at the given depth; at 'Bottom' the element
-- is an @a@ and is handed straight to @f@.
1149+
foldMapBlob :: Depth a t -> t -> m
1150+
foldMapBlob Bottom a = f a
1151+
foldMapBlob (Deeper w) (Node2 _ x y) = foldMapBlob w x <> foldMapBlob w y
1152+
foldMapBlob (Deeper w) (Node3 _ x y z) = foldMapBlob w x <> foldMapBlob w y <> foldMapBlob w z
1153+
-- Fold a finger tree whose elements sit at the given depth; the middle
-- tree of a 'Deep' node is folded one level deeper.
1154+
foldMapFT :: Depth a t -> FingerTree t -> m
1155+
foldMapFT !_ EmptyT = mempty
1156+
foldMapFT w (Single t) = foldMapBlob w t
1157+
foldMapFT w (Deep _ pr m sf) =
1158+
foldMap (foldMapBlob w) pr
1159+
<> foldMapFT (Deeper w) m
1160+
<> foldMap (foldMapBlob w) sf
1161+
1162+
-- foldMap = coerce (foldMap :: (a -> m) -> Seq a -> m)
1163+
{-# INLINABLE foldMap #-}
1164+
#else
10821165
foldMap _ EmptyT = mempty
10831166
foldMap f' (Single x') = f' x'
10841167
foldMap f' (Deep _ pr' m' sf') =
@@ -1105,8 +1188,6 @@ instance Foldable FingerTree where
11051188

11061189
foldMapNodeN :: Monoid m => (Node a -> m) -> Node (Node a) -> m
11071190
foldMapNodeN f t = foldNode (<>) f t
1108-
#if __GLASGOW_HASKELL__
1109-
{-# INLINABLE foldMap #-}
11101191
#endif
11111192

11121193
foldr _ z' EmptyT = z'
@@ -1270,7 +1351,7 @@ foldDigit _ f (One a) = f a
12701351
foldDigit (<+>) f (Two a b) = f a <+> f b
12711352
foldDigit (<+>) f (Three a b c) = f a <+> f b <+> f c
12721353
foldDigit (<+>) f (Four a b c d) = f a <+> f b <+> f c <+> f d
1273-
{-# INLINE foldDigit #-}
1354+
{-# INLINABLE foldDigit #-}
12741355

12751356
instance Foldable Digit where
12761357
foldMap = foldDigit mappend
@@ -3203,15 +3284,56 @@ foldWithIndexNode (<+>) f s (Node3 _ a b c) = f s a <+> f sPsa b <+> f sPsab c
32033284
-- element in the sequence.
32043285
--
32053286
-- @since 0.5.8
3206-
foldMapWithIndex :: Monoid m => (Int -> a -> m) -> Seq a -> m
3287+
-- GHC implementation: walk the finger tree with a 'Depth' witness while
-- threading the index of the leftmost element of the current subtree.
foldMapWithIndex :: forall m a. Monoid m => (Int -> a -> m) -> Seq a -> m
3288+
#ifdef __GLASGOW_HASKELL__
3289+
foldMapWithIndex f (Seq t) = foldMapWithIndexFT Bottom 0 t
3290+
where
-- Fold a finger tree at the given depth; @s@ is the index of its first
-- element. Both the depth and the index are forced even for 'EmptyT'.
3291+
foldMapWithIndexFT :: Depth (Elem a) t -> Int -> FingerTree t -> m
3292+
foldMapWithIndexFT !_ !_ EmptyT = mempty
3293+
foldMapWithIndexFT d s (Single xs) = foldMapWithIndexBlob d s xs
-- Matching on 'Sizzy' brings the @Sized t@ dictionary into scope, which
-- the @size@ calls in the where clause rely on.
3294+
foldMapWithIndexFT d s (Deep _ pr m sf) = case depthSized d of { Sizzy ->
3295+
foldWithIndexDigit (<>) (foldMapWithIndexBlob d) s pr <>
3296+
foldMapWithIndexFT (Deeper d) sPspr m <>
3297+
foldWithIndexDigit (<>) (foldMapWithIndexBlob d) sPsprm sf
3298+
where
-- Starting indices of the middle tree and of the suffix, computed strictly.
3299+
!sPspr = s + size pr
3300+
!sPsprm = sPspr + size m
3301+
}
3302+
-- Fold one element of the tree at the given depth; @k@ is the index of its
-- first leaf. Sibling offsets are obtained with 'sizeBlob'.
3303+
foldMapWithIndexBlob :: Depth (Elem a) t -> Int -> t -> m
3304+
foldMapWithIndexBlob Bottom k (Elem a) = f k a
3305+
foldMapWithIndexBlob (Deeper yop) k (Node2 _s t1 t2) =
3306+
foldMapWithIndexBlob yop k t1 <>
3307+
foldMapWithIndexBlob yop (k + sizeBlob yop t1) t2
3308+
foldMapWithIndexBlob (Deeper yop) k (Node3 _s t1 t2 t3) =
3309+
foldMapWithIndexBlob yop k t1 <>
3310+
foldMapWithIndexBlob yop (k + st1) t2 <>
3311+
foldMapWithIndexBlob yop (k + st1t2) t3
3312+
where
3313+
st1 = sizeBlob yop t1
3314+
st1t2 = st1 + sizeBlob yop t2
3315+
{-# INLINABLE foldMapWithIndex #-}
3316+
3317+
-- Evidence that @a@ has a 'Sized' instance: pattern matching on the
-- 'Sizzy' constructor makes that instance available to the matching code.
data Sizzy a where
3318+
Sizzy :: Sized a => Sizzy a
3319+
3320+
-- Produce 'Sized' evidence for the element type at a given depth. Both
-- equations return 'Sizzy'; the match on the depth constructor is what
-- refines @t@ so that an instance can be selected (presumably the 'Sized'
-- instances for 'Elem' and 'Node' defined elsewhere in this module).
depthSized :: Depth (Elem a) t -> Sizzy t
3321+
depthSized Bottom = Sizzy
3322+
depthSized (Deeper _) = Sizzy
3323+
3324+
-- 'size' of a tree element whose type is only known through its depth
-- witness. Each equation is just 'size'; the match on the depth is needed
-- to refine @t@ before 'size' can be resolved.
sizeBlob :: Depth (Elem a) t -> t -> Int
3325+
sizeBlob Bottom = size
3326+
sizeBlob (Deeper _) = size
3327+
3328+
#else
32073329
foldMapWithIndex f' (Seq xs') = foldMapWithIndexTreeE (lift_elem f') 0 xs'
32083330
where
32093331
lift_elem :: (Int -> a -> m) -> (Int -> Elem a -> m)
3210-
#ifdef __GLASGOW_HASKELL__
3332+
# ifdef __GLASGOW_HASKELL__
32113333
lift_elem g = coerce g
3212-
#else
3334+
# else
32133335
lift_elem g = \s (Elem a) -> g s a
3214-
#endif
3336+
# endif
32153337
{-# INLINE lift_elem #-}
32163338
-- We have to specialize these functions by hand, unfortunately, because
32173339
-- GHC does not specialize until *all* instances are determined.
@@ -3250,9 +3372,6 @@ foldMapWithIndex f' (Seq xs') = foldMapWithIndexTreeE (lift_elem f') 0 xs'
32503372

32513373
foldMapWithIndexNodeN :: Monoid m => (Int -> Node a -> m) -> Int -> Node (Node a) -> m
32523374
foldMapWithIndexNodeN f i t = foldWithIndexNode (<>) f i t
3253-
3254-
#if __GLASGOW_HASKELL__
3255-
{-# INLINABLE foldMapWithIndex #-}
32563375
#endif
32573376

32583377
-- | 'traverseWithIndex' is a version of 'traverse' that also offers
@@ -4997,3 +5116,7 @@ fromList2 n = execState (replicateA n (State ht))
49975116
where
49985117
ht (x:xs) = (xs, x)
49995118
ht [] = error "fromList2: short list"
5119+
5120+
-- Concatenate the lists stored in a sequence via an explicitly inlined
-- 'foldMap'; the function itself is kept out of line by the NOINLINE pragma.
--
-- NOTE(review): this looks like a scratch definition for inspecting the
-- code GHC generates for the inlined fold, and it is exported from the
-- module. Confirm whether it is meant to ship.
{-# NOINLINE bongo #-}
5121+
bongo :: Seq [a] -> [a]
5122+
bongo xs = GHC.Exts.inline foldMap id xs
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
-- To inspect the generated code for this module, pass -ddump-prep on the
-- command line; the flag is deliberately not enabled here so that ordinary
-- builds of the library do not emit compiler dumps.
2+
{-# LANGUAGE GADTs #-}
3+
{-# LANGUAGE KindSignatures #-}
4+
{-# LANGUAGE PatternSynonyms #-}
5+
{-# LANGUAGE RoleAnnotations #-}
6+
{-# LANGUAGE Trustworthy #-}
7+
{-# LANGUAGE TypeOperators #-}
8+
{-# LANGUAGE ViewPatterns #-}
9+
10+
-- | This module defines efficient representations of GADTs that are shaped
11+
-- like (strict) unary natural numbers. That is, each type looks, from the
12+
-- outside, something like this:
13+
--
14+
-- @
15+
-- data NatLike ... where
16+
-- ZeroLike :: NatLike ...
17+
-- SuccLike :: !(NatLike ...) -> NatLike ...
18+
-- @
19+
--
20+
-- but in fact it is represented by a single machine word. We put these in a
21+
-- separate module to confine the highly unsafe magic used in the
22+
-- implementation.
23+
--
24+
-- Caution: Unlike the GADTs they represent, the types in this module are
25+
-- bounded by @maxBound \@Word@, and attempting to take a successor of the
26+
-- maximum bound will throw an overflow error. That's okay for our purposes
27+
-- of implementing certain functions in "Data.Sequence.Internal"—the spine
28+
-- of a well-formed sequence can only reach a length of around the word
29+
-- size, not even close to @maxBound \@Word@.
30+
31+
module Data.Sequence.Internal.Depth
32+
( Depth_ (Bottom, Deeper)
33+
, Depth2_ (Bottom2, Deeper2)
34+
) where
35+
36+
import Data.Kind (Type)
37+
import Unsafe.Coerce (unsafeCoerce)
38+
39+
-- | @Depth_@ is an optimized representation of the following GADT:
40+
--
41+
-- @
42+
-- data Depth_ node a t where
43+
-- Bottom :: Depth_ node a a
44+
-- Deeper :: !(Depth_ node a t) -> Depth_ node a (node t)
45+
-- @
46+
--
47+
-- "Data.Sequence.Internal" fills in the @node@ parameter with its @Node@
48+
-- constructor; we have to be more general in this module because we don't
49+
-- have access to that.
50+
--
51+
-- @Depth_@ is represented internally as a 'Word' for performance, and the
52+
-- 'Bottom' and 'Deeper' pattern synonyms implement the above GADT interface.
53+
-- The implementation is "safe"—in the very unlikely event of arithmetic
54+
-- overflow, an error will be thrown. This decision is subject to change;
55+
-- arithmetic overflow on 64-bit systems requires somewhat absurdly long
56+
-- computations on sequences constructed with extensive amounts of internal
57+
-- sharing (e.g., using the '*>' operator repeatedly).
58+
-- The payload is the depth as a 'Word'; all three type parameters are
-- phantom.
newtype Depth_ (node :: Type -> Type) (a :: Type) (t :: Type)
59+
= Depth_ Word
-- Nominal roles keep 'Data.Coerce.coerce' from changing the phantom
-- indices, which the pattern synonyms below trust.
60+
type role Depth_ nominal nominal nominal
61+
62+
-- | The depth is 0.
63+
-- A successful match provides the equality @t ~ a@.
pattern Bottom :: () => t ~ a => Depth_ node a t
-- Matching goes through 'checkBottom', which inspects the underlying word.
64+
pattern Bottom <- (checkBottom -> AtBottom)
65+
where
-- As an expression, 'Bottom' is the word 0.
66+
Bottom = Depth_ 0
67+
68+
-- | The depth is non-zero.
69+
-- A successful match provides @t ~ node t'@ and binds the depth one less.
pattern Deeper :: () => t ~ node t' => Depth_ node a t' -> Depth_ node a t
70+
pattern Deeper d <- (checkBottom -> NotBottom d)
71+
where
-- As an expression, 'Deeper' increments the word, calling 'error' instead
-- of wrapping around when the depth is already @maxBound@.
72+
Deeper (Depth_ d)
73+
| d == maxBound = error "Depth overflow"
74+
| otherwise = Depth_ (d + 1)
75+
76+
{-# COMPLETE Bottom, Deeper #-}
77+
78+
-- A genuine GADT view of 'Depth_', used as the target of the view patterns
-- in 'Bottom' and 'Deeper': matching on its constructors is what provides
-- the type equalities.
data CheckedBottom node a t where
79+
AtBottom :: CheckedBottom node a a
80+
NotBottom :: !(Depth_ node a t) -> CheckedBottom node a (node t)
81+
82+
-- Convert the word representation into the 'CheckedBottom' view: 0 becomes
-- 'AtBottom', anything else becomes 'NotBottom' of the predecessor.
--
-- The word carries no type-level evidence, so 'unsafeCoerce' is used to
-- assert the result index. This relies on values only being built through
-- the exported 'Bottom' and 'Deeper' patterns (the 'Depth_' data
-- constructor is not in the export list).
checkBottom :: Depth_ node a t -> CheckedBottom node a t
83+
checkBottom (Depth_ 0) = unsafeCoerce AtBottom
84+
checkBottom (Depth_ d) = unsafeCoerce (NotBottom (Depth_ (d - 1)))
85+
86+
87+
-- | A version of 'Depth_' for implementing traversals. Conceptually,
88+
--
89+
-- @
90+
-- data Depth2_ node a t b u where
91+
-- Bottom2 :: Depth2_ node a a b b
92+
-- Deeper2 :: !(Depth2_ node a t b u) -> Depth2_ node a (node t) b (node u)
93+
-- @
94+
-- The payload is the depth as a 'Word'; all five type parameters are
-- phantom.
newtype Depth2_ (node :: Type -> Type) (a :: Type) (t :: Type) (b :: Type) (u :: Type)
95+
= Depth2_ Word
-- Nominal roles keep 'Data.Coerce.coerce' from changing the phantom
-- indices, which the pattern synonyms below trust.
96+
type role Depth2_ nominal nominal nominal nominal nominal
97+
98+
-- | The depth is 0.
99+
-- A successful match provides the equalities @t ~ a@ and @u ~ b@.
pattern Bottom2 :: () => (t ~ a, u ~ b) => Depth2_ node a t b u
-- Matching goes through 'checkBottom2', which inspects the underlying word.
100+
pattern Bottom2 <- (checkBottom2 -> AtBottom2)
101+
where
-- As an expression, 'Bottom2' is the word 0.
102+
Bottom2 = Depth2_ 0
103+
104+
-- | The depth is non-zero.
105+
-- A successful match provides @t ~ node t'@ and @u ~ node u'@ and binds the
-- depth one less.
pattern Deeper2 :: () => (t ~ node t', u ~ node u') => Depth2_ node a t' b u' -> Depth2_ node a t b u
106+
pattern Deeper2 d <- (checkBottom2 -> NotBottom2 d)
107+
where
-- As an expression, 'Deeper2' increments the word, calling 'error' instead
-- of wrapping around when the depth is already @maxBound@.
108+
Deeper2 (Depth2_ d)
109+
| d == maxBound = error "Depth2 overflow"
110+
| otherwise = Depth2_ (d + 1)
111+
112+
{-# COMPLETE Bottom2, Deeper2 #-}
113+
114+
-- A genuine GADT view of 'Depth2_', used as the target of the view patterns
-- in 'Bottom2' and 'Deeper2': matching on its constructors is what provides
-- the type equalities.
data CheckedBottom2 node a t b u where
115+
AtBottom2 :: CheckedBottom2 node a a b b
116+
NotBottom2 :: !(Depth2_ node a t b u) -> CheckedBottom2 node a (node t) b (node u)
117+
118+
-- Convert the word representation into the 'CheckedBottom2' view: 0 becomes
-- 'AtBottom2', anything else becomes 'NotBottom2' of the predecessor.
--
-- The word carries no type-level evidence, so 'unsafeCoerce' is used to
-- assert the result indices. This relies on values only being built through
-- the exported 'Bottom2' and 'Deeper2' patterns (the 'Depth2_' data
-- constructor is not in the export list).
checkBottom2 :: Depth2_ node a t b u -> CheckedBottom2 node a t b u
119+
checkBottom2 (Depth2_ 0) = unsafeCoerce AtBottom2
120+
checkBottom2 (Depth2_ d) = unsafeCoerce (NotBottom2 (Depth2_ (d - 1)))

0 commit comments

Comments
 (0)