Skip to content

Commit 6a47782

Browse files
committed
Implement multiple folds using scans
1 parent 0e04dd1 commit 6a47782

File tree

4 files changed

+58
-230
lines changed

4 files changed

+58
-230
lines changed

core/src/Streamly/Internal/Data/Fold/Combinators.hs

Lines changed: 26 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ module Streamly.Internal.Data.Fold.Combinators
3232
, the
3333
, mean
3434
, rollingHash
35-
, defaultSalt
35+
, Scanl.defaultSalt
3636
, rollingHashWithSalt
3737
, rollingHashFirstN
3838
-- , rollingHashLastN
@@ -67,6 +67,7 @@ module Streamly.Internal.Data.Fold.Combinators
6767
-- usually a transformation of the current element rather than an
6868
-- aggregation of all elements till now.
6969
-- , nthLast -- using RingArray array
70+
, rollingMap
7071
, rollingMapM
7172

7273
-- *** Filters
@@ -242,12 +243,10 @@ import Streamly.Internal.Data.Unfold.Type (Unfold(..))
242243
import qualified Prelude
243244
import qualified Streamly.Internal.Data.MutArray.Type as MA
244245
import qualified Streamly.Internal.Data.Array.Type as Array
245-
import qualified Streamly.Internal.Data.Fold.Type as Fold
246246
import qualified Streamly.Internal.Data.Pipe.Type as Pipe
247247
import qualified Streamly.Internal.Data.RingArray as RingArray
248248
import qualified Streamly.Internal.Data.Scanl.Combinators as Scanl
249249
import qualified Streamly.Internal.Data.Scanl.Type as Scanl
250-
import qualified Streamly.Internal.Data.Scanl.Window as Scanl
251250
import qualified Streamly.Internal.Data.Stream.Type as StreamD
252251

253252
import Prelude hiding
@@ -500,17 +499,7 @@ pipe (Pipe consume produce pinitial) (Fold fstep finitial fextract ffinal) =
500499
--
501500
{-# INLINE_NORMAL deleteBy #-}
502501
deleteBy :: Monad m => (a -> a -> Bool) -> a -> Fold m a (Maybe a)
503-
deleteBy eq x0 = fmap extract $ foldl' step (Tuple' False Nothing)
504-
505-
where
506-
507-
step (Tuple' False _) x =
508-
if eq x x0
509-
then Tuple' True Nothing
510-
else Tuple' False (Just x)
511-
step (Tuple' True _) x = Tuple' True (Just x)
512-
513-
extract (Tuple' _ x) = x
502+
deleteBy eq = fromScanl . Scanl.deleteBy eq
514503

515504
-- | Provide a sliding window of length 2 elements.
516505
--
@@ -550,14 +539,7 @@ slide2 (Fold step1 initial1 extract1 final1) = Fold step initial extract final
550539
--
551540
{-# INLINE uniqBy #-}
552541
uniqBy :: Monad m => (a -> a -> Bool) -> Fold m a (Maybe a)
553-
uniqBy eq = rollingMap f
554-
555-
where
556-
557-
f pre curr =
558-
case pre of
559-
Nothing -> Just curr
560-
Just x -> if x `eq` curr then Nothing else Just curr
542+
uniqBy = fromScanl . Scanl.uniqBy
561543

562544
-- | See 'uniqBy'.
563545
--
@@ -567,7 +549,7 @@ uniqBy eq = rollingMap f
567549
--
568550
{-# INLINE uniq #-}
569551
uniq :: (Monad m, Eq a) => Fold m a (Maybe a)
570-
uniq = uniqBy (==)
552+
uniq = fromScanl Scanl.uniq
571553

572554
-- | Strip all leading and trailing occurrences of an element passing a
573555
-- predicate and make all other consecutive occurrences unique.
@@ -628,17 +610,7 @@ drainBy = drainMapM
628610
--
629611
{-# INLINE the #-}
630612
the :: (Monad m, Eq a) => Fold m a (Maybe a)
631-
the = foldt' step initial id
632-
633-
where
634-
635-
initial = Partial Nothing
636-
637-
step Nothing x = Partial (Just x)
638-
step old@(Just x0) x =
639-
if x0 == x
640-
then Partial old
641-
else Done Nothing
613+
the = fromScanl Scanl.the
642614

643615
------------------------------------------------------------------------------
644616
-- To Summary
@@ -657,7 +629,7 @@ the = foldt' step initial id
657629
--
658630
{-# INLINE sum #-}
659631
sum :: (Monad m, Num a) => Fold m a a
660-
sum = Fold.fromScanl $ Scanl.cumulativeScan Scanl.incrSum
632+
sum = fromScanl Scanl.sum
661633

662634
-- | Determine the product of all elements of a stream of numbers. Returns
663635
-- multiplicative identity (@1@) when the stream is empty. The fold terminates
@@ -669,14 +641,7 @@ sum = Fold.fromScanl $ Scanl.cumulativeScan Scanl.incrSum
669641
--
670642
{-# INLINE product #-}
671643
product :: (Monad m, Num a, Eq a) => Fold m a a
672-
product = foldt' step (Partial 1) id
673-
674-
where
675-
676-
step x a =
677-
if a == 0
678-
then Done 0
679-
else Partial $ x * a
644+
product = fromScanl Scanl.product
680645

681646
------------------------------------------------------------------------------
682647
-- To Summary (Maybe)
@@ -761,17 +726,7 @@ range = fromScanl Scanl.range
761726
--
762727
{-# INLINE mean #-}
763728
mean :: (Monad m, Fractional a) => Fold m a a
764-
mean = fmap done $ foldl' step begin
765-
766-
where
767-
768-
begin = Tuple' 0 0
769-
770-
step (Tuple' x n) y =
771-
let n1 = n + 1
772-
in Tuple' (x + (y - x) / n1) n1
773-
774-
done (Tuple' x _) = x
729+
mean = fromScanl Scanl.mean
775730

776731
-- | Compute a numerically stable (population) variance over all elements in
777732
-- the input stream.
@@ -817,26 +772,15 @@ stdDev = sqrt <$> variance
817772
--
818773
{-# INLINE rollingHashWithSalt #-}
819774
rollingHashWithSalt :: (Monad m, Enum a) => Int64 -> Fold m a Int64
820-
rollingHashWithSalt = foldl' step
821-
822-
where
823-
824-
k = 2891336453 :: Int64
825-
826-
step cksum a = cksum * k + fromIntegral (fromEnum a)
827-
828-
-- | A default salt used in the implementation of 'rollingHash'.
829-
{-# INLINE defaultSalt #-}
830-
defaultSalt :: Int64
831-
defaultSalt = -2578643520546668380
775+
rollingHashWithSalt = fromScanl . Scanl.rollingHashWithSalt
832776

833777
-- | Compute an 'Int64' sized polynomial rolling hash of a stream.
834778
--
835779
-- >>> rollingHash = Fold.rollingHashWithSalt Fold.defaultSalt
836780
--
837781
{-# INLINE rollingHash #-}
838782
rollingHash :: (Monad m, Enum a) => Fold m a Int64
839-
rollingHash = rollingHashWithSalt defaultSalt
783+
rollingHash = fromScanl Scanl.rollingHash
840784

841785
-- | Compute an 'Int64' sized polynomial rolling hash of the first n elements of
842786
-- a stream.
@@ -846,7 +790,7 @@ rollingHash = rollingHashWithSalt defaultSalt
846790
-- /Pre-release/
847791
{-# INLINE rollingHashFirstN #-}
848792
rollingHashFirstN :: (Monad m, Enum a) => Int -> Fold m a Int64
849-
rollingHashFirstN n = take n rollingHash
793+
rollingHashFirstN = fromScanl . Scanl.rollingHashFirstN
850794

851795
-- XXX Compare this with the implementation in Fold.Window, preferably use the
852796
-- latter if performance is good.
@@ -860,26 +804,14 @@ rollingHashFirstN n = take n rollingHash
860804
--
861805
{-# INLINE rollingMapM #-}
862806
rollingMapM :: Monad m => (Maybe a -> a -> m b) -> Fold m a b
863-
rollingMapM f = Fold step initial extract extract
864-
865-
where
866-
867-
-- XXX We need just a postscan. We do not need an initial result here.
868-
-- Or we can supply a default initial result as an argument to rollingMapM.
869-
initial = return $ Partial (Nothing, error "Empty stream")
870-
871-
step (prev, _) cur = do
872-
x <- f prev cur
873-
return $ Partial (Just cur, x)
874-
875-
extract = return . snd
807+
rollingMapM = fromScanl . Scanl.rollingMapM
876808

877809
-- |
878810
-- >>> rollingMap f = Fold.rollingMapM (\x y -> return $ f x y)
879811
--
880812
{-# INLINE rollingMap #-}
881813
rollingMap :: Monad m => (Maybe a -> a -> b) -> Fold m a b
882-
rollingMap f = rollingMapM (\x y -> return $ f x y)
814+
rollingMap = fromScanl . Scanl.rollingMap
883815

884816
------------------------------------------------------------------------------
885817
-- Monoidal left folds
@@ -898,7 +830,7 @@ rollingMap f = rollingMapM (\x y -> return $ f x y)
898830
--
899831
{-# INLINE sconcat #-}
900832
sconcat :: (Monad m, Semigroup a) => a -> Fold m a a
901-
sconcat = foldl' (<>)
833+
sconcat = fromScanl . Scanl.sconcat
902834

903835
-- | Monoid concat. Fold an input stream consisting of monoidal elements using
904836
-- 'mappend' and 'mempty'.
@@ -915,7 +847,7 @@ sconcat = foldl' (<>)
915847
mconcat ::
916848
( Monad m
917849
, Monoid a) => Fold m a a
918-
mconcat = sconcat mempty
850+
mconcat = fromScanl Scanl.mconcat
919851

920852
-- |
921853
-- Definition:
@@ -931,7 +863,7 @@ mconcat = sconcat mempty
931863
--
932864
{-# INLINE foldMap #-}
933865
foldMap :: (Monad m, Monoid b) => (a -> b) -> Fold m a b
934-
foldMap f = lmap f mconcat
866+
foldMap = fromScanl . Scanl.foldMap
935867

936868
-- |
937869
-- Definition:
@@ -947,13 +879,7 @@ foldMap f = lmap f mconcat
947879
--
948880
{-# INLINE foldMapM #-}
949881
foldMapM :: (Monad m, Monoid b) => (a -> m b) -> Fold m a b
950-
foldMapM act = foldlM' step (pure mempty)
951-
952-
where
953-
954-
step m a = do
955-
m' <- act a
956-
return $! mappend m m'
882+
foldMapM = fromScanl . Scanl.foldMapM
957883

958884
------------------------------------------------------------------------------
959885
-- Partial Folds
@@ -969,7 +895,7 @@ foldMapM act = foldlM' step (pure mempty)
969895
-- /Pre-release/
970896
{-# INLINE drainN #-}
971897
drainN :: Monad m => Int -> Fold m a ()
972-
drainN n = take n drain
898+
drainN = fromScanl . Scanl.drainN
973899

974900
------------------------------------------------------------------------------
975901
-- To Elements
@@ -1134,16 +1060,7 @@ findIndex predicate = foldt' step (Partial 0) (const Nothing)
11341060
--
11351061
{-# INLINE findIndices #-}
11361062
findIndices :: Monad m => (a -> Bool) -> Fold m a (Maybe Int)
1137-
findIndices predicate =
1138-
-- XXX implement by combining indexing and filtering scans
1139-
fmap (either (const Nothing) Just) $ foldl' step (Left (-1))
1140-
1141-
where
1142-
1143-
step i a =
1144-
if predicate a
1145-
then Right (either id id i + 1)
1146-
else Left (either id id i + 1)
1063+
findIndices = fromScanl . Scanl.findIndices
11471064

11481065
-- | Returns the index of the latest element if the element matches the given
11491066
-- value.
@@ -1154,7 +1071,7 @@ findIndices predicate =
11541071
--
11551072
{-# INLINE elemIndices #-}
11561073
elemIndices :: (Monad m, Eq a) => a -> Fold m a (Maybe Int)
1157-
elemIndices a = findIndices (== a)
1074+
elemIndices = fromScanl . Scanl.elemIndices
11581075

11591076
-- | Returns the first index where a given value is found in the stream.
11601077
--
@@ -2256,7 +2173,7 @@ chunksBetween _low _high _f1 _f2 = undefined
22562173
-- /Pre-release/
22572174
{-# INLINE toStream #-}
22582175
toStream :: (Monad m, Monad n) => Fold m a (Stream n a)
2259-
toStream = fmap StreamD.fromList toList
2176+
toStream = fromScanl Scanl.toStream
22602177

22612178
-- This is more efficient than 'toStream'. toStream is exactly the same as
22622179
-- reversing the stream after toStreamRev.
@@ -2274,7 +2191,7 @@ toStream = fmap StreamD.fromList toList
22742191
-- xn : ... : x2 : x1 : []
22752192
{-# INLINE toStreamRev #-}
22762193
toStreamRev :: (Monad m, Monad n) => Fold m a (Stream n a)
2277-
toStreamRev = fmap StreamD.fromList toListRev
2194+
toStreamRev = fromScanl Scanl.toStreamRev
22782195

22792196
-- XXX This does not fuse. It contains a recursive step function. We will need
22802197
-- a Skip input constructor in the fold type to make it fuse.
@@ -2316,32 +2233,7 @@ bottomBy :: (MonadIO m, Unbox a) =>
23162233
(a -> a -> Ordering)
23172234
-> Int
23182235
-> Fold m a (MutArray a)
2319-
bottomBy cmp n = Fold step initial extract extract
2320-
2321-
where
2322-
2323-
initial = do
2324-
arr <- MA.emptyOf' n
2325-
if n <= 0
2326-
then return $ Done arr
2327-
else return $ Partial (arr, 0)
2328-
2329-
step (arr, i) x =
2330-
if i < n
2331-
then do
2332-
arr' <- MA.snoc arr x
2333-
MA.bubble cmp arr'
2334-
return $ Partial (arr', i + 1)
2335-
else do
2336-
x1 <- MA.unsafeGetIndex (i - 1) arr
2337-
case x `cmp` x1 of
2338-
LT -> do
2339-
MA.unsafePutIndex (i - 1) arr x
2340-
MA.bubble cmp arr
2341-
return $ Partial (arr, i)
2342-
_ -> return $ Partial (arr, i)
2343-
2344-
extract = return . fst
2236+
bottomBy cmp = fromScanl . Scanl.bottomBy cmp
23452237

23462238
-- | Get the top @n@ elements using the supplied comparison function.
23472239
--
@@ -2377,7 +2269,7 @@ topBy cmp = bottomBy (flip cmp)
23772269
-- /Pre-release/
23782270
{-# INLINE top #-}
23792271
top :: (MonadIO m, Unbox a, Ord a) => Int -> Fold m a (MutArray a)
2380-
top = bottomBy $ flip compare
2272+
top = fromScanl . Scanl.top
23812273

23822274
-- | Fold the input stream to bottom n elements.
23832275
--
@@ -2392,7 +2284,7 @@ top = bottomBy $ flip compare
23922284
-- /Pre-release/
23932285
{-# INLINE bottom #-}
23942286
bottom :: (MonadIO m, Unbox a, Ord a) => Int -> Fold m a (MutArray a)
2395-
bottom = bottomBy compare
2287+
bottom = fromScanl . Scanl.bottom
23962288

23972289
------------------------------------------------------------------------------
23982290
-- Interspersed parsing

0 commit comments

Comments
 (0)