Skip to content

Commit d5fc4cb

Browse files
committed
MonadSTM: added TSem
Fixes IntersectMBO/ouroboros-network#2587
1 parent 699b56a commit d5fc4cb

File tree

2 files changed

+144
-4
lines changed

2 files changed

+144
-4
lines changed

io-classes/src/Control/Monad/Class/MonadSTM.hs

Lines changed: 142 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ module Control.Monad.Class.MonadSTM
3434
, TBQueueDefault (..)
3535
-- * Default 'TArray' implementation
3636
, TArrayDefault (..)
37+
-- * Default 'TSem' implementation
38+
, TSemDefault (..)
3739
-- * MonadThrow aliases
3840
, throwSTM
3941
, catchSTM
@@ -53,8 +55,9 @@ import qualified Control.Concurrent.STM.TArray as STM
5355
import qualified Control.Concurrent.STM.TBQueue as STM
5456
import qualified Control.Concurrent.STM.TMVar as STM
5557
import qualified Control.Concurrent.STM.TQueue as STM
58+
import qualified Control.Concurrent.STM.TSem as STM
5659
import qualified Control.Concurrent.STM.TVar as STM
57-
import Control.Monad (MonadPlus (..))
60+
import Control.Monad (MonadPlus (..), when)
5861
import qualified Control.Monad.STM as STM
5962

6063
import Control.Monad.Cont (ContT (..))
@@ -160,6 +163,12 @@ class ( Monad m
160163

161164
type TArray m :: Type -> Type -> Type
162165

166+
type TSem m :: Type
167+
newTSem :: Integer -> STM m (TSem m)
168+
waitTSem :: TSem m -> STM m ()
169+
signalTSem :: TSem m -> STM m ()
170+
signalTSemN :: Natural -> TSem m -> STM m ()
171+
163172
-- Helpful derived functions with default implementations
164173

165174
newTVarIO :: a -> m (TVar m a)
@@ -294,6 +303,22 @@ class ( Monad m
294303
=> TBQueue m a -> a -> STM m ()
295304
unGetTBQueue = unGetTBQueueDefault
296305

306+
default newTSem :: TSem m ~ TSemDefault m
307+
=> Integer -> STM m (TSem m)
308+
newTSem = newTSemDefault
309+
310+
default waitTSem :: TSem m ~ TSemDefault m
311+
=> TSem m -> STM m ()
312+
waitTSem = waitTSemDefault
313+
314+
default signalTSem :: TSem m ~ TSemDefault m
315+
=> TSem m -> STM m ()
316+
signalTSem = signalTSemDefault
317+
318+
default signalTSemN :: TSem m ~ TSemDefault m
319+
=> Natural -> TSem m -> STM m ()
320+
signalTSemN = signalTSemNDefault
321+
297322

298323
stateTVarDefault :: MonadSTM m => TVar m s -> (s -> (a, s)) -> STM m a
299324
stateTVarDefault var f = do
@@ -332,13 +357,15 @@ class MonadSTM m
332357
labelTBQueue :: TBQueue m a -> String -> STM m ()
333358
labelTArray :: (Ix i, Show i)
334359
=> TArray m i e -> String -> STM m ()
360+
labelTSem :: TSem m -> String -> STM m ()
335361

336362
labelTVarIO :: TVar m a -> String -> m ()
337363
labelTMVarIO :: TMVar m a -> String -> m ()
338364
labelTQueueIO :: TQueue m a -> String -> m ()
339365
labelTBQueueIO :: TBQueue m a -> String -> m ()
340366
labelTArrayIO :: (Ix i, Show i)
341367
=> TArray m i e -> String -> m ()
368+
labelTSemIO :: TSem m -> String -> m ()
342369

343370
--
344371
-- default implementations
@@ -356,6 +383,10 @@ class MonadSTM m
356383
=> TBQueue m a -> String -> STM m ()
357384
labelTBQueue = labelTBQueueDefault
358385

386+
default labelTSem :: TSem m ~ TSemDefault m
387+
=> TSem m -> String -> STM m ()
388+
labelTSem = labelTSemDefault
389+
359390
default labelTArray :: ( TArray m ~ TArrayDefault m
360391
, Ix i
361392
, Show i
@@ -379,6 +410,9 @@ class MonadSTM m
379410
=> TArray m i e -> String -> m ()
380411
labelTArrayIO = \v l -> atomically (labelTArray v l)
381412

413+
default labelTSemIO :: TSem m -> String -> m ()
414+
labelTSemIO = \v l -> atomically (labelTSem v l)
415+
382416

383417
-- | This type class is indented for 'io-sim', where one might want to access
384418
-- 'TVar' in the underlying 'ST' monad.
@@ -471,14 +505,25 @@ class MonadInspectSTM m
471505
-> (Maybe [a] -> [a] -> InspectMonad m TraceValue)
472506
-> STM m ()
473507

474-
default traceTMVar :: ( TMVar m a ~ TMVarDefault m a
475-
)
508+
traceTSem :: proxy m
509+
-> TSem m
510+
-> (Maybe Integer -> Integer -> InspectMonad m TraceValue)
511+
-> STM m ()
512+
513+
default traceTMVar :: TMVar m a ~ TMVarDefault m a
476514
=> proxy m
477515
-> TMVar m a
478516
-> (Maybe (Maybe a) -> (Maybe a) -> InspectMonad m TraceValue)
479517
-> STM m ()
480518
traceTMVar = traceTMVarDefault
481519

520+
default traceTSem :: TSem m ~ TSemDefault m
521+
=> proxy m
522+
-> TSem m
523+
-> (Maybe Integer -> Integer -> InspectMonad m TraceValue)
524+
-> STM m ()
525+
traceTSem = traceTSemDefault
526+
482527

483528
traceTVarIO :: proxy m
484529
-> TVar m a
@@ -500,6 +545,11 @@ class MonadInspectSTM m
500545
-> (Maybe [a] -> [a] -> InspectMonad m TraceValue)
501546
-> m ()
502547

548+
traceTSemIO :: proxy m
549+
-> TSem m
550+
-> (Maybe Integer -> Integer -> InspectMonad m TraceValue)
551+
-> m ()
552+
503553
default traceTVarIO :: proxy m
504554
-> TVar m a
505555
-> (Maybe a -> a -> InspectMonad m TraceValue)
@@ -524,6 +574,12 @@ class MonadInspectSTM m
524574
-> m ()
525575
traceTBQueueIO = \p v f -> atomically (traceTBQueue p v f)
526576

577+
default traceTSemIO :: proxy m
578+
-> TSem m
579+
-> (Maybe Integer -> Integer -> InspectMonad m TraceValue)
580+
-> m ()
581+
traceTSemIO = \p v f -> atomically (traceTSem p v f)
582+
527583

528584
--
529585
-- Instance for IO uses the existing STM library implementations
@@ -539,6 +595,7 @@ instance MonadSTM IO where
539595
type TQueue IO = STM.TQueue
540596
type TBQueue IO = STM.TBQueue
541597
type TArray IO = STM.TArray
598+
type TSem IO = STM.TSem
542599

543600
newTVar = STM.newTVar
544601
readTVar = STM.readTVar
@@ -579,6 +636,10 @@ instance MonadSTM IO where
579636
isEmptyTBQueue = STM.isEmptyTBQueue
580637
isFullTBQueue = STM.isFullTBQueue
581638
unGetTBQueue = STM.unGetTBQueue
639+
newTSem = STM.newTSem
640+
waitTSem = STM.waitTSem
641+
signalTSem = STM.signalTSem
642+
signalTSemN = STM.signalTSemN
582643

583644
newTVarIO = STM.newTVarIO
584645
readTVarIO = STM.readTVarIO
@@ -595,12 +656,14 @@ instance MonadLabelledSTM IO where
595656
labelTQueue = \_ _ -> return ()
596657
labelTBQueue = \_ _ -> return ()
597658
labelTArray = \_ _ -> return ()
659+
labelTSem = \_ _ -> return ()
598660

599661
labelTVarIO = \_ _ -> return ()
600662
labelTMVarIO = \_ _ -> return ()
601663
labelTQueueIO = \_ _ -> return ()
602664
labelTBQueueIO = \_ _ -> return ()
603665
labelTArrayIO = \_ _ -> return ()
666+
labelTSemIO = \_ _ -> return ()
604667

605668
-- | noop instance
606669
--
@@ -609,11 +672,13 @@ instance MonadTraceSTM IO where
609672
traceTMVar = \_ _ _ -> return ()
610673
traceTQueue = \_ _ _ -> return ()
611674
traceTBQueue = \_ _ _ -> return ()
675+
traceTSem = \_ _ _ -> return ()
612676

613677
traceTVarIO = \_ _ _ -> return ()
614678
traceTMVarIO = \_ _ _ -> return ()
615679
traceTQueueIO = \_ _ _ -> return ()
616680
traceTBQueueIO = \_ _ _ -> return ()
681+
traceTSemIO = \_ _ _ -> return ()
617682

618683
-- | Wrapper around 'BlockedIndefinitelyOnSTM' that stores a call stack
619684
data BlockedIndefinitely = BlockedIndefinitely {
@@ -981,6 +1046,44 @@ labelTArrayDefault (TArray arr) name = do
9811046
let as = Array.assocs arr
9821047
traverse_ (\(i, v) -> labelTVar v (name ++ ":" ++ show i)) as
9831048

1049+
1050+
--
1051+
-- Default `TSem` implementation
1052+
--
1053+
1054+
newtype TSemDefault m = TSem (TVar m Integer)
1055+
1056+
labelTSemDefault :: MonadLabelledSTM m => TSemDefault m -> String -> STM m ()
1057+
labelTSemDefault (TSem t) = labelTVar t
1058+
1059+
traceTSemDefault :: MonadTraceSTM m
1060+
=> proxy m
1061+
-> TSemDefault m
1062+
-> (Maybe Integer -> Integer -> InspectMonad m TraceValue)
1063+
-> STM m ()
1064+
traceTSemDefault proxy (TSem t) k = traceTVar proxy t k
1065+
1066+
newTSemDefault :: MonadSTM m => Integer -> STM m (TSemDefault m)
1067+
newTSemDefault i = TSem <$> (newTVar $! i)
1068+
1069+
waitTSemDefault :: MonadSTM m => TSemDefault m -> STM m ()
1070+
waitTSemDefault (TSem t) = do
1071+
i <- readTVar t
1072+
when (i <= 0) retry
1073+
writeTVar t $! (i-1)
1074+
1075+
signalTSemDefault :: MonadSTM m => TSemDefault m -> STM m ()
1076+
signalTSemDefault (TSem t) = do
1077+
i <- readTVar t
1078+
writeTVar t $! i+1
1079+
1080+
signalTSemNDefault :: MonadSTM m => Natural -> TSemDefault m -> STM m ()
1081+
signalTSemNDefault 0 _ = return ()
1082+
signalTSemNDefault 1 s = signalTSemDefault s
1083+
signalTSemNDefault n (TSem t) = do
1084+
i <- readTVar t
1085+
writeTVar t $! i+(toInteger n)
1086+
9841087
-- | 'throwIO' specialised to @stm@ monad.
9851088
--
9861089
throwSTM :: (MonadSTM m, MonadThrow.MonadThrow (STM m), Exception e)
@@ -1094,6 +1197,12 @@ instance MonadSTM m => MonadSTM (ContT r m) where
10941197

10951198
type TArray (ContT r m) = TArray m
10961199

1200+
type TSem (ContT r m) = TSem m
1201+
newTSem = WrappedSTM . newTSem
1202+
waitTSem = WrappedSTM . waitTSem
1203+
signalTSem = WrappedSTM . signalTSem
1204+
signalTSemN = WrappedSTM .: signalTSemN
1205+
10971206

10981207
instance MonadSTM m => MonadSTM (ReaderT r m) where
10991208
type STM (ReaderT r m) = WrappedSTM Reader r m
@@ -1149,6 +1258,12 @@ instance MonadSTM m => MonadSTM (ReaderT r m) where
11491258

11501259
type TArray (ReaderT r m) = TArray m
11511260

1261+
type TSem (ReaderT r m) = TSem m
1262+
newTSem = WrappedSTM . newTSem
1263+
waitTSem = WrappedSTM . waitTSem
1264+
signalTSem = WrappedSTM . signalTSem
1265+
signalTSemN = WrappedSTM .: signalTSemN
1266+
11521267

11531268
instance (Monoid w, MonadSTM m) => MonadSTM (WriterT w m) where
11541269
type STM (WriterT w m) = WrappedSTM Writer w m
@@ -1204,6 +1319,12 @@ instance (Monoid w, MonadSTM m) => MonadSTM (WriterT w m) where
12041319

12051320
type TArray (WriterT w m) = TArray m
12061321

1322+
type TSem (WriterT w m) = TSem m
1323+
newTSem = WrappedSTM . newTSem
1324+
waitTSem = WrappedSTM . waitTSem
1325+
signalTSem = WrappedSTM . signalTSem
1326+
signalTSemN = WrappedSTM .: signalTSemN
1327+
12071328

12081329
instance MonadSTM m => MonadSTM (StateT s m) where
12091330
type STM (StateT s m) = WrappedSTM State s m
@@ -1259,6 +1380,12 @@ instance MonadSTM m => MonadSTM (StateT s m) where
12591380

12601381
type TArray (StateT s m) = TArray m
12611382

1383+
type TSem (StateT s m) = TSem m
1384+
newTSem = WrappedSTM . newTSem
1385+
waitTSem = WrappedSTM . waitTSem
1386+
signalTSem = WrappedSTM . signalTSem
1387+
signalTSemN = WrappedSTM .: signalTSemN
1388+
12621389

12631390
instance MonadSTM m => MonadSTM (ExceptT e m) where
12641391
type STM (ExceptT e m) = WrappedSTM Except e m
@@ -1314,6 +1441,12 @@ instance MonadSTM m => MonadSTM (ExceptT e m) where
13141441

13151442
type TArray (ExceptT e m) = TArray m
13161443

1444+
type TSem (ExceptT e m) = TSem m
1445+
newTSem = WrappedSTM . newTSem
1446+
waitTSem = WrappedSTM . waitTSem
1447+
signalTSem = WrappedSTM . signalTSem
1448+
signalTSemN = WrappedSTM .: signalTSemN
1449+
13171450

13181451
instance (Monoid w, MonadSTM m) => MonadSTM (RWST r w s m) where
13191452
type STM (RWST r w s m) = WrappedSTM RWS (r, w, s) m
@@ -1369,6 +1502,12 @@ instance (Monoid w, MonadSTM m) => MonadSTM (RWST r w s m) where
13691502

13701503
type TArray (RWST r w s m) = TArray m
13711504

1505+
type TSem (RWST r w s m) = TSem m
1506+
newTSem = WrappedSTM . newTSem
1507+
waitTSem = WrappedSTM . waitTSem
1508+
signalTSem = WrappedSTM . signalTSem
1509+
signalTSemN = WrappedSTM .: signalTSemN
1510+
13721511

13731512
(.:) :: (c -> d) -> (a -> b -> c) -> (a -> b -> d)
13741513
(f .: g) x y = f (g x y)

io-sim/src/Control/Monad/IOSim/Types.hs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ import Control.Monad.Class.MonadMVar
7676
import Control.Monad.Class.MonadST
7777
import Control.Monad.Class.MonadSTM (MonadInspectSTM (..),
7878
MonadLabelledSTM (..), MonadSTM, MonadTraceSTM (..),
79-
TArrayDefault, TMVarDefault, TraceValue)
79+
TArrayDefault, TMVarDefault, TSemDefault, TraceValue)
8080
import qualified Control.Monad.Class.MonadSTM as MonadSTM
8181
import Control.Monad.Class.MonadSay
8282
import Control.Monad.Class.MonadTest
@@ -399,6 +399,7 @@ instance MonadSTM (IOSim s) where
399399
type TQueue (IOSim s) = TQueueDefault (IOSim s)
400400
type TBQueue (IOSim s) = TBQueueDefault (IOSim s)
401401
type TArray (IOSim s) = TArrayDefault (IOSim s)
402+
type TSem (IOSim s) = TSemDefault (IOSim s)
402403

403404
atomically action = IOSim $ oneShot $ \k -> Atomically action k
404405

0 commit comments

Comments
 (0)