Skip to content

Commit 6bee5ae

Browse files
Add concurrent scan combinators with static list of scans
1 parent f568c76 commit 6bee5ae

File tree

2 files changed

+45
-17
lines changed

2 files changed

+45
-17
lines changed

src/Streamly/Internal/Data/Scanl/Concurrent.hs

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
module Streamly.Internal.Data.Scanl.Concurrent
1010
(
1111
parTeeWith
12+
, parDistributeScanM
1213
, parDistributeScan
14+
, parDemuxScanM
1315
, parDemuxScan
1416
)
1517
where
@@ -19,7 +21,7 @@ where
1921
import Control.Concurrent (newEmptyMVar, takeMVar, throwTo)
2022
import Control.Monad.Catch (throwM)
2123
import Control.Monad.IO.Class (MonadIO(liftIO))
22-
import Data.IORef (newIORef, readIORef)
24+
import Data.IORef (newIORef, readIORef, atomicModifyIORef)
2325
import Fusion.Plugin.Types (Fuse(..))
2426
import Streamly.Internal.Control.Concurrent (MonadAsync)
2527
import Streamly.Internal.Data.Atomics (atomicModifyIORefCAS)
@@ -30,6 +32,7 @@ import Streamly.Internal.Data.SVar.Type (adaptState)
3032
import Streamly.Internal.Data.Tuple.Strict (Tuple3'(..))
3133

3234
import qualified Data.Map.Strict as Map
35+
import qualified Streamly.Internal.Data.Stream as Stream
3336

3437
import Streamly.Internal.Data.Fold.Channel.Type
3538
import Streamly.Internal.Data.Channel.Types
@@ -162,13 +165,13 @@ data ScanState s q db f =
162165
-- >>> import Data.IORef
163166
-- >>> ref <- newIORef [Scanl.take 5 Scanl.sum, Scanl.take 5 Scanl.length :: Scanl.Scanl IO Int Int]
164167
-- >>> gen = atomicModifyIORef ref (\xs -> ([], xs))
165-
-- >>> Stream.toList $ Scanl.parDistributeScan id gen (Stream.enumerateFromTo 1 10)
168+
-- >>> Stream.toList $ Scanl.parDistributeScanM id gen (Stream.enumerateFromTo 1 10)
166169
-- ...
167170
--
168-
{-# INLINE parDistributeScan #-}
169-
parDistributeScan :: MonadAsync m =>
171+
{-# INLINE parDistributeScanM #-}
172+
parDistributeScanM :: MonadAsync m =>
170173
(Config -> Config) -> m [Scanl m a b] -> Stream m a -> Stream m [b]
171-
parDistributeScan cfg getFolds (Stream sstep state) =
174+
parDistributeScanM cfg getFolds (Stream sstep state) =
172175
Stream step ScanInit
173176

174177
where
@@ -243,6 +246,20 @@ parDistributeScan cfg getFolds (Stream sstep state) =
243246
else return $ Yield outputs (ScanDrain q db running)
244247
step _ ScanStop = return Stop
245248

249+
-- | Like 'parDistributeScanM' but takes a list of static scans.
250+
--
251+
-- >>> xs = [Scanl.take 5 Scanl.sum, Scanl.take 5 Scanl.length :: Scanl.Scanl IO Int Int]
252+
-- >>> Stream.toList $ Scanl.parDistributeScan id xs (Stream.enumerateFromTo 1 10)
253+
-- ...
254+
{-# INLINE parDistributeScan #-}
255+
parDistributeScan :: MonadAsync m =>
256+
(Config -> Config) -> [Scanl m a b] -> Stream m a -> Stream m [b]
257+
parDistributeScan cfg getFolds stream =
258+
Stream.concatEffect $ do
259+
ref <- liftIO $ newIORef getFolds
260+
let action = liftIO $ atomicModifyIORef ref (\xs -> ([], xs))
261+
return $ parDistributeScanM cfg action stream
262+
246263
{-# ANN type DemuxState Fuse #-}
247264
data DemuxState s q db f =
248265
DemuxInit
@@ -273,17 +290,17 @@ data DemuxState s q db f =
273290
-- >>> getScan k = return (fromJust $ Map.lookup k kv)
274291
-- >>> getKey x = if even x then "even" else "odd"
275292
-- >>> input = Stream.enumerateFromTo 1 10
276-
-- >>> Stream.toList $ Scanl.parDemuxScan id getKey getScan input
293+
-- >>> Stream.toList $ Scanl.parDemuxScanM id getKey getScan input
277294
-- ...
278295
--
279-
{-# INLINE parDemuxScan #-}
280-
parDemuxScan :: (MonadAsync m, Ord k) =>
296+
{-# INLINE parDemuxScanM #-}
297+
parDemuxScanM :: (MonadAsync m, Ord k) =>
281298
(Config -> Config)
282299
-> (a -> k)
283300
-> (k -> m (Scanl m a b))
284301
-> Stream m a
285302
-> Stream m [(k, b)]
286-
parDemuxScan cfg getKey getFold (Stream sstep state) =
303+
parDemuxScanM cfg getKey getFold (Stream sstep state) =
287304
Stream step DemuxInit
288305

289306
where
@@ -368,3 +385,14 @@ parDemuxScan cfg getKey getFold (Stream sstep state) =
368385
return $ Skip (DemuxDrain q db keyToChan1)
369386
else return $ Yield outputs (DemuxDrain q db keyToChan1)
370387
step _ DemuxStop = return Stop
388+
389+
-- | Like 'parDemuxScanM' but the key to scan mapping is static/pure instead of
390+
-- monadic.
391+
{-# INLINE parDemuxScan #-}
392+
parDemuxScan :: (MonadAsync m, Ord k) =>
393+
(Config -> Config)
394+
-> (a -> k)
395+
-> (k -> Scanl m a b)
396+
-> Stream m a
397+
-> Stream m [(k, b)]
398+
parDemuxScan cfg getKey getFold = parDemuxScanM cfg getKey (pure . getFold)

test/Streamly/Test/Data/Scanl/Concurrent.hs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ parDistributeScan_ScanEnd concOpts = do
4848
inpList = [1..streamLen]
4949
inpStream = Stream.fromList inpList
5050
res1 <-
51-
Scanl.parDistributeScan concOpts gen inpStream
51+
Scanl.parDistributeScanM concOpts gen inpStream
5252
& Stream.concatMap Stream.fromList
5353
& Stream.catMaybes
5454
& Stream.fold Fold.toList
@@ -66,7 +66,7 @@ parDemuxScan_ScanEnd concOpts = do
6666
inpList = [1..streamLen]
6767
inpStream = Stream.fromList inpList
6868
res <-
69-
Scanl.parDemuxScan concOpts demuxer gen inpStream
69+
Scanl.parDemuxScanM concOpts demuxer gen inpStream
7070
& Stream.concatMap Stream.fromList
7171
& fmap (\x -> (fst x,) <$> snd x)
7272
& Stream.catMaybes
@@ -82,7 +82,7 @@ parDistributeScan_StreamEnd concOpts = do
8282
inpList = [1..streamLen]
8383
inpStream = Stream.fromList inpList
8484
res1 <-
85-
Scanl.parDistributeScan concOpts gen inpStream
85+
Scanl.parDistributeScanM concOpts gen inpStream
8686
& Stream.concatMap Stream.fromList
8787
& Stream.catMaybes
8888
& Stream.fold Fold.toList
@@ -96,7 +96,7 @@ parDemuxScan_StreamEnd concOpts = do
9696
inpList = [1..streamLen]
9797
inpStream = Stream.fromList inpList
9898
res <-
99-
Scanl.parDemuxScan concOpts demuxer gen inpStream
99+
Scanl.parDemuxScanM concOpts demuxer gen inpStream
100100
& Stream.concatMap Stream.fromList
101101
& fmap (\x -> (fst x,) <$> snd x)
102102
& Stream.catMaybes
@@ -111,11 +111,11 @@ main = hspec
111111
$ modifyMaxSuccess (const 10)
112112
#endif
113113
$ describe moduleName $ do
114-
it "parDistributeScan (stream end) (maxBuffer 1)"
114+
it "parDistributeScanM (stream end) (maxBuffer 1)"
115115
$ parDistributeScan_StreamEnd (Stream.maxBuffer 1)
116-
it "parDistributeScan (scan end) (maxBuffer 1)"
116+
it "parDistributeScanM (scan end) (maxBuffer 1)"
117117
$ parDistributeScan_ScanEnd (Stream.maxBuffer 1)
118-
it "parDemuxScan (stream end) (maxBuffer 1)"
118+
it "parDemuxScanM (stream end) (maxBuffer 1)"
119119
$ parDemuxScan_StreamEnd (Stream.maxBuffer 1)
120-
it "parDemuxScan (scan end) (maxBuffer 1)"
120+
it "parDemuxScanM (scan end) (maxBuffer 1)"
121121
$ parDemuxScan_ScanEnd (Stream.maxBuffer 1)

0 commit comments

Comments
 (0)