99module Streamly.Internal.Data.Scanl.Concurrent
1010 (
1111 parTeeWith
12+ , parDistributeScanM
1213 , parDistributeScan
14+ , parDemuxScanM
1315 , parDemuxScan
1416 )
1517where
1921import Control.Concurrent (newEmptyMVar , takeMVar , throwTo )
2022import Control.Monad.Catch (throwM )
2123import Control.Monad.IO.Class (MonadIO (liftIO ))
22- import Data.IORef (newIORef , readIORef )
24+ import Data.IORef (newIORef , readIORef , atomicModifyIORef )
2325import Fusion.Plugin.Types (Fuse (.. ))
2426import Streamly.Internal.Control.Concurrent (MonadAsync )
2527import Streamly.Internal.Data.Atomics (atomicModifyIORefCAS )
@@ -30,6 +32,7 @@ import Streamly.Internal.Data.SVar.Type (adaptState)
3032import Streamly.Internal.Data.Tuple.Strict (Tuple3' (.. ))
3133
3234import qualified Data.Map.Strict as Map
35+ import qualified Streamly.Internal.Data.Stream as Stream
3336
3437import Streamly.Internal.Data.Fold.Channel.Type
3538import Streamly.Internal.Data.Channel.Types
@@ -162,13 +165,13 @@ data ScanState s q db f =
162165-- >>> import Data.IORef
163166-- >>> ref <- newIORef [Scanl.take 5 Scanl.sum, Scanl.take 5 Scanl.length :: Scanl.Scanl IO Int Int]
164167-- >>> gen = atomicModifyIORef ref (\xs -> ([], xs))
165- -- >>> Stream.toList $ Scanl.parDistributeScan id gen (Stream.enumerateFromTo 1 10)
168+ -- >>> Stream.toList $ Scanl.parDistributeScanM id gen (Stream.enumerateFromTo 1 10)
166169-- ...
167170--
168- {-# INLINE parDistributeScan #-}
169- parDistributeScan :: MonadAsync m =>
171+ {-# INLINE parDistributeScanM #-}
172+ parDistributeScanM :: MonadAsync m =>
170173 (Config -> Config ) -> m [Scanl m a b ] -> Stream m a -> Stream m [b ]
171- parDistributeScan cfg getFolds (Stream sstep state) =
174+ parDistributeScanM cfg getFolds (Stream sstep state) =
172175 Stream step ScanInit
173176
174177 where
@@ -243,6 +246,20 @@ parDistributeScan cfg getFolds (Stream sstep state) =
243246 else return $ Yield outputs (ScanDrain q db running)
244247 step _ ScanStop = return Stop
245248
249+ -- | Like 'parDistributeScanM' but takes a list of static scans.
250+ --
251+ -- >>> xs = [Scanl.take 5 Scanl.sum, Scanl.take 5 Scanl.length :: Scanl.Scanl IO Int Int]
252+ -- >>> Stream.toList $ Scanl.parDistributeScan id xs (Stream.enumerateFromTo 1 10)
253+ -- ...
254+ {-# INLINE parDistributeScan #-}
255+ parDistributeScan :: MonadAsync m =>
256+ (Config -> Config ) -> [Scanl m a b ] -> Stream m a -> Stream m [b ]
257+ parDistributeScan cfg getFolds stream =
258+ Stream. concatEffect $ do
259+ ref <- liftIO $ newIORef getFolds
260+ let action = liftIO $ atomicModifyIORef ref (\ xs -> ([] , xs))
261+ return $ parDistributeScanM cfg action stream
262+
246263{-# ANN type DemuxState Fuse #-}
247264data DemuxState s q db f =
248265 DemuxInit
@@ -273,17 +290,17 @@ data DemuxState s q db f =
273290-- >>> getScan k = return (fromJust $ Map.lookup k kv)
274291-- >>> getKey x = if even x then "even" else "odd"
275292-- >>> input = Stream.enumerateFromTo 1 10
276- -- >>> Stream.toList $ Scanl.parDemuxScan id getKey getScan input
293+ -- >>> Stream.toList $ Scanl.parDemuxScanM id getKey getScan input
277294-- ...
278295--
279- {-# INLINE parDemuxScan #-}
280- parDemuxScan :: (MonadAsync m , Ord k ) =>
296+ {-# INLINE parDemuxScanM #-}
297+ parDemuxScanM :: (MonadAsync m , Ord k ) =>
281298 (Config -> Config )
282299 -> (a -> k )
283300 -> (k -> m (Scanl m a b ))
284301 -> Stream m a
285302 -> Stream m [(k , b )]
286- parDemuxScan cfg getKey getFold (Stream sstep state) =
303+ parDemuxScanM cfg getKey getFold (Stream sstep state) =
287304 Stream step DemuxInit
288305
289306 where
@@ -368,3 +385,14 @@ parDemuxScan cfg getKey getFold (Stream sstep state) =
368385 return $ Skip (DemuxDrain q db keyToChan1)
369386 else return $ Yield outputs (DemuxDrain q db keyToChan1)
370387 step _ DemuxStop = return Stop
388+
389+ -- | Like 'parDemuxScanM' but the key to scan mapping is static/pure instead of
390+ -- monadic.
391+ {-# INLINE parDemuxScan #-}
392+ parDemuxScan :: (MonadAsync m , Ord k ) =>
393+ (Config -> Config )
394+ -> (a -> k )
395+ -> (k -> Scanl m a b )
396+ -> Stream m a
397+ -> Stream m [(k , b )]
398+ parDemuxScan cfg getKey getFold = parDemuxScanM cfg getKey (pure . getFold)
0 commit comments