Skip to content

Commit 5f27b3e

Browse files
committed
Fix the fold demux combinators to not restart the fold
1 parent 9969702 commit 5f27b3e

File tree

1 file changed

+47
-32
lines changed

1 file changed

+47
-32
lines changed

core/src/Streamly/Internal/Data/Fold/Container.hs

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -387,9 +387,12 @@ demuxerToContainer getKey getFold =
387387
step (Tuple' kv kv1) a = do
388388
let k = getKey a
389389
case IsMap.mapLookup k kv of
390-
Nothing -> do
391-
fld <- getFold k
392-
runFold kv kv1 fld (k, a)
390+
Nothing ->
391+
case IsMap.mapLookup k kv1 of
392+
Just _ -> pure $ Tuple' kv kv1
393+
Nothing -> do
394+
fld <- getFold k
395+
runFold kv kv1 fld (k, a)
393396
Just f -> runFold kv kv1 f (k, a)
394397

395398
final (Tuple' kv kv1) = do
@@ -406,7 +409,7 @@ demuxerToContainer getKey getFold =
406409

407410
-- | Scanning variant of 'demuxerToContainer'.
408411
{-# INLINE demuxScanGeneric #-}
409-
demuxScanGeneric :: (Monad m, IsMap f, Traversable f) =>
412+
demuxScanGeneric :: (Monad m, IsMap f, Traversable f, Ord (Key f)) =>
410413
(a -> Key f)
411414
-> (Key f -> m (Fold m a b))
412415
-> Scanl m a (m (f b), Maybe (Key f, b))
@@ -415,10 +418,10 @@ demuxScanGeneric getKey getFold =
415418

416419
where
417420

418-
initial = return $ Tuple' IsMap.mapEmpty Nothing
421+
initial = return $ Tuple3' IsMap.mapEmpty Set.empty Nothing
419422

420423
{-# INLINE runFold #-}
421-
runFold kv (Fold step1 initial1 extract1 final1) (k, a) = do
424+
runFold kv set (Fold step1 initial1 extract1 final1) (k, a) = do
422425
res <- initial1
423426
case res of
424427
Partial s -> do
@@ -427,23 +430,28 @@ demuxScanGeneric getKey getFold =
427430
$ case res1 of
428431
Partial _ ->
429432
let fld = Fold step1 (return res1) extract1 final1
430-
in Tuple' (IsMap.mapInsert k fld kv) Nothing
431-
Done b -> Tuple' (IsMap.mapDelete k kv) (Just (k, b))
433+
set1 = Set.insert k set
434+
kv1 = IsMap.mapInsert k fld kv
435+
in Tuple3' kv1 set1 Nothing
436+
Done b ->
437+
let kv1 = IsMap.mapDelete k kv
438+
set1 = Set.insert k set
439+
in Tuple3' kv1 set1 (Just (k, b))
432440
Done b ->
433441
-- Done in "initial" is possible only for the very first time
434442
-- the fold is initialized, and in that case we have not yet
435443
-- inserted it in the Map, so we do not need to delete it.
436-
return $ Tuple' kv (Just (k, b))
444+
return $ Tuple3' kv (Set.insert k set) (Just (k, b))
437445

438-
step (Tuple' kv _) a = do
446+
step (Tuple3' kv set _) a = do
439447
let k = getKey a
440448
case IsMap.mapLookup k kv of
441449
Nothing -> do
442450
fld <- getFold k
443-
runFold kv fld (k, a)
444-
Just f -> runFold kv f (k, a)
451+
runFold kv set fld (k, a)
452+
Just f -> runFold kv set f (k, a)
445453

446-
extract (Tuple' kv x) = return (Prelude.mapM f kv, x)
454+
extract (Tuple3' kv _ x) = return (Prelude.mapM f kv, x)
447455

448456
where
449457

@@ -453,7 +461,7 @@ demuxScanGeneric getKey getFold =
453461
Partial s -> e s
454462
_ -> error "demuxGeneric: unreachable code"
455463

456-
final (Tuple' kv x) = return (Prelude.mapM f kv, x)
464+
final (Tuple3' kv _ x) = return (Prelude.mapM f kv, x)
457465

458466
where
459467

@@ -647,8 +655,11 @@ demuxerToContainerIO getKey getFold =
647655
let k = getKey a
648656
case IsMap.mapLookup k kv of
649657
Nothing -> do
650-
f <- getFold k
651-
initFold kv kv1 f (k, a)
658+
case IsMap.mapLookup k kv1 of
659+
Just _ -> pure $ Tuple' kv kv1
660+
Nothing -> do
661+
f <- getFold k
662+
initFold kv kv1 f (k, a)
652663
Just ref -> do
653664
f <- liftIO $ readIORef ref
654665
runFold kv kv1 ref f (k, a)
@@ -673,7 +684,7 @@ demuxerToContainerIO getKey getFold =
673684
-- ongoing fold if you are using those concurrently in another thread.
674685
--
675686
{-# INLINE demuxScanGenericIO #-}
676-
demuxScanGenericIO :: (MonadIO m, IsMap f, Traversable f) =>
687+
demuxScanGenericIO :: (MonadIO m, IsMap f, Traversable f, Ord (Key f)) =>
677688
(a -> Key f)
678689
-> (Key f -> m (Fold m a b))
679690
-> Scanl m a (m (f b), Maybe (Key f, b))
@@ -682,10 +693,10 @@ demuxScanGenericIO getKey getFold =
682693

683694
where
684695

685-
initial = return $ Tuple' IsMap.mapEmpty Nothing
696+
initial = return $ Tuple3' IsMap.mapEmpty Set.empty Nothing
686697

687698
{-# INLINE initFold #-}
688-
initFold kv (Fold step1 initial1 extract1 final1) (k, a) = do
699+
initFold kv set (Fold step1 initial1 extract1 final1) (k, a) = do
689700
res <- initial1
690701
case res of
691702
Partial s -> do
@@ -697,12 +708,12 @@ demuxScanGenericIO getKey getFold =
697708
-- accumulator. That will reduce the allocations.
698709
let fld = Fold step1 (return res1) extract1 final1
699710
ref <- liftIO $ newIORef fld
700-
return $ Tuple' (IsMap.mapInsert k ref kv) Nothing
701-
Done b -> return $ Tuple' kv (Just (k, b))
702-
Done b -> return $ Tuple' kv (Just (k, b))
711+
return $ Tuple3' (IsMap.mapInsert k ref kv) set Nothing
712+
Done b -> pure $ Tuple3' kv (Set.insert k set) (Just (k, b))
713+
Done b -> return $ Tuple3' kv (Set.insert k set) (Just (k, b))
703714

704715
{-# INLINE runFold #-}
705-
runFold kv ref (Fold step1 initial1 extract1 final1) (k, a) = do
716+
runFold kv set ref (Fold step1 initial1 extract1 final1) (k, a) = do
706717
res <- initial1
707718
case res of
708719
Partial s -> do
@@ -711,23 +722,27 @@ demuxScanGenericIO getKey getFold =
711722
Partial _ -> do
712723
let fld = Fold step1 (return res1) extract1 final1
713724
liftIO $ writeIORef ref fld
714-
return $ Tuple' kv Nothing
725+
return $ Tuple3' kv set Nothing
715726
Done b ->
716727
let kv1 = IsMap.mapDelete k kv
717-
in return $ Tuple' kv1 (Just (k, b))
728+
set1 = Set.insert k set
729+
in return $ Tuple3' kv1 set1 (Just (k, b))
718730
Done _ -> error "demuxGenericIO: unreachable"
719731

720-
step (Tuple' kv _) a = do
732+
step (Tuple3' kv set _) a = do
721733
let k = getKey a
722734
case IsMap.mapLookup k kv of
723-
Nothing -> do
724-
f <- getFold k
725-
initFold kv f (k, a)
735+
Nothing ->
736+
if Set.member k set
737+
then return (Tuple3' kv set Nothing)
738+
else do
739+
f <- getFold k
740+
initFold kv set f (k, a)
726741
Just ref -> do
727742
f <- liftIO $ readIORef ref
728-
runFold kv ref f (k, a)
743+
runFold kv set ref f (k, a)
729744

730-
extract (Tuple' kv x) = return (Prelude.mapM f kv, x)
745+
extract (Tuple3' kv _ x) = return (Prelude.mapM f kv, x)
731746

732747
where
733748

@@ -738,7 +753,7 @@ demuxScanGenericIO getKey getFold =
738753
Partial s -> e s
739754
_ -> error "demuxGenericIO: unreachable code"
740755

741-
final (Tuple' kv x) = return (Prelude.mapM f kv, x)
756+
final (Tuple3' kv _ x) = return (Prelude.mapM f kv, x)
742757

743758
where
744759

0 commit comments

Comments
 (0)