Skip to content

Commit 699b56a

Browse files
committed
MonadSTM: added TArray
Fixes IntersectMBO/ouroboros-network#2588
1 parent 10b9966 commit 699b56a

File tree

5 files changed

+127
-10
lines changed

5 files changed

+127
-10
lines changed

io-classes/io-classes.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ library
5252
ScopedTypeVariables
5353
RankNTypes
5454
build-depends: base >=4.9 && <4.18,
55+
array,
5556
async >=2.1,
5657
bytestring,
5758
deque,

io-classes/src/Control/Monad/Class/MonadSTM.hs

Lines changed: 92 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
{-# LANGUAGE DefaultSignatures #-}
33
{-# LANGUAGE DerivingStrategies #-}
44
{-# LANGUAGE FlexibleContexts #-}
5+
{-# LANGUAGE FlexibleInstances #-}
56
{-# LANGUAGE GADTs #-}
67
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
78
{-# LANGUAGE MultiParamTypeClasses #-}
@@ -31,6 +32,8 @@ module Control.Monad.Class.MonadSTM
3132
, TQueueDefault (..)
3233
-- * Default 'TBQueue' implementation
3334
, TBQueueDefault (..)
35+
-- * Default 'TArray' implementation
36+
, TArrayDefault (..)
3437
-- * MonadThrow aliases
3538
, throwSTM
3639
, catchSTM
@@ -46,6 +49,7 @@ module Control.Monad.Class.MonadSTM
4649

4750
import Prelude hiding (read)
4851

52+
import qualified Control.Concurrent.STM.TArray as STM
4953
import qualified Control.Concurrent.STM.TBQueue as STM
5054
import qualified Control.Concurrent.STM.TMVar as STM
5155
import qualified Control.Concurrent.STM.TQueue as STM
@@ -65,7 +69,13 @@ import qualified Control.Monad.Class.MonadThrow as MonadThrow
6569

6670
import Control.Applicative (Alternative (..))
6771
import Control.Exception
72+
import Data.Array (Array, bounds)
73+
import qualified Data.Array as Array
74+
import Data.Array.Base (IArray (numElements), MArray (..),
75+
arrEleBottom, listArray, unsafeAt)
76+
import Data.Foldable (traverse_)
6877
import Data.Function (on)
78+
import Data.Ix (Ix, rangeSize)
6979
import Data.Kind (Type)
7080
import Data.Typeable (Typeable)
7181
import GHC.Stack
@@ -148,6 +158,8 @@ class ( Monad m
148158
isFullTBQueue :: TBQueue m a -> STM m Bool
149159
unGetTBQueue :: TBQueue m a -> a -> STM m ()
150160

161+
type TArray m :: Type -> Type -> Type
162+
151163
-- Helpful derived functions with default implementations
152164

153165
newTVarIO :: a -> m (TVar m a)
@@ -314,15 +326,19 @@ newEmptyTMVarM = newEmptyTMVarIO
314326
--
315327
class MonadSTM m
316328
=> MonadLabelledSTM m where
317-
labelTVar :: TVar m a -> String -> STM m ()
318-
labelTMVar :: TMVar m a -> String -> STM m ()
319-
labelTQueue :: TQueue m a -> String -> STM m ()
320-
labelTBQueue :: TBQueue m a -> String -> STM m ()
321-
322-
labelTVarIO :: TVar m a -> String -> m ()
323-
labelTMVarIO :: TMVar m a -> String -> m ()
324-
labelTQueueIO :: TQueue m a -> String -> m ()
325-
labelTBQueueIO :: TBQueue m a -> String -> m ()
329+
labelTVar :: TVar m a -> String -> STM m ()
330+
labelTMVar :: TMVar m a -> String -> STM m ()
331+
labelTQueue :: TQueue m a -> String -> STM m ()
332+
labelTBQueue :: TBQueue m a -> String -> STM m ()
333+
labelTArray :: (Ix i, Show i)
334+
=> TArray m i e -> String -> STM m ()
335+
336+
labelTVarIO :: TVar m a -> String -> m ()
337+
labelTMVarIO :: TMVar m a -> String -> m ()
338+
labelTQueueIO :: TQueue m a -> String -> m ()
339+
labelTBQueueIO :: TBQueue m a -> String -> m ()
340+
labelTArrayIO :: (Ix i, Show i)
341+
=> TArray m i e -> String -> m ()
326342

327343
--
328344
-- default implementations
@@ -340,6 +356,13 @@ class MonadSTM m
340356
=> TBQueue m a -> String -> STM m ()
341357
labelTBQueue = labelTBQueueDefault
342358

359+
default labelTArray :: ( TArray m ~ TArrayDefault m
360+
, Ix i
361+
, Show i
362+
)
363+
=> TArray m i e -> String -> STM m ()
364+
labelTArray = labelTArrayDefault
365+
343366
default labelTVarIO :: TVar m a -> String -> m ()
344367
labelTVarIO = \v l -> atomically (labelTVar v l)
345368

@@ -352,6 +375,10 @@ class MonadSTM m
352375
default labelTBQueueIO :: TBQueue m a -> String -> m ()
353376
labelTBQueueIO = \v l -> atomically (labelTBQueue v l)
354377

378+
default labelTArrayIO :: (Ix i, Show i)
379+
=> TArray m i e -> String -> m ()
380+
labelTArrayIO = \v l -> atomically (labelTArray v l)
381+
355382

356383
-- | This type class is indented for 'io-sim', where one might want to access
357384
-- 'TVar' in the underlying 'ST' monad.
@@ -511,6 +538,7 @@ instance MonadSTM IO where
511538
type TMVar IO = STM.TMVar
512539
type TQueue IO = STM.TQueue
513540
type TBQueue IO = STM.TBQueue
541+
type TArray IO = STM.TArray
514542

515543
newTVar = STM.newTVar
516544
readTVar = STM.readTVar
@@ -566,11 +594,13 @@ instance MonadLabelledSTM IO where
566594
labelTMVar = \_ _ -> return ()
567595
labelTQueue = \_ _ -> return ()
568596
labelTBQueue = \_ _ -> return ()
597+
labelTArray = \_ _ -> return ()
569598

570599
labelTVarIO = \_ _ -> return ()
571600
labelTMVarIO = \_ _ -> return ()
572601
labelTQueueIO = \_ _ -> return ()
573602
labelTBQueueIO = \_ _ -> return ()
603+
labelTArrayIO = \_ _ -> return ()
574604

575605
-- | noop instance
576606
--
@@ -910,6 +940,47 @@ unGetTBQueueDefault (TBQueue rsize read wsize _write _size) a = do
910940
writeTVar read (a:xs)
911941

912942

943+
--
944+
-- Default `TArray` implementation
945+
--
946+
947+
-- | Default implementation of 'TArray'.
948+
--
949+
data TArrayDefault m i e = TArray (Array i (TVar m e))
950+
deriving Typeable
951+
952+
deriving instance (Eq (TVar m e), Ix i) => Eq (TArrayDefault m i e)
953+
954+
instance (Monad stm, MonadSTM m, stm ~ STM m)
955+
=> MArray (TArrayDefault m) e stm where
956+
getBounds (TArray a) = return (bounds a)
957+
newArray b e = do
958+
a <- rep (rangeSize b) (newTVar e)
959+
return $ TArray (listArray b a)
960+
newArray_ b = do
961+
a <- rep (rangeSize b) (newTVar arrEleBottom)
962+
return $ TArray (listArray b a)
963+
unsafeRead (TArray a) i = readTVar $ unsafeAt a i
964+
unsafeWrite (TArray a) i e = writeTVar (unsafeAt a i) e
965+
getNumElements (TArray a) = return (numElements a)
966+
967+
rep :: Monad m => Int -> m a -> m [a]
968+
rep n m = go n []
969+
where
970+
go 0 xs = return xs
971+
go i xs = do
972+
x <- m
973+
go (i-1) (x:xs)
974+
975+
labelTArrayDefault :: ( MonadLabelledSTM m
976+
, Ix i
977+
, Show i
978+
)
979+
=> TArrayDefault m i e -> String -> STM m ()
980+
labelTArrayDefault (TArray arr) name = do
981+
let as = Array.assocs arr
982+
traverse_ (\(i, v) -> labelTVar v (name ++ ":" ++ show i)) as
983+
913984
-- | 'throwIO' specialised to @stm@ monad.
914985
--
915986
throwSTM :: (MonadSTM m, MonadThrow.MonadThrow (STM m), Exception e)
@@ -1021,6 +1092,8 @@ instance MonadSTM m => MonadSTM (ContT r m) where
10211092
isFullTBQueue = WrappedSTM . isFullTBQueue
10221093
unGetTBQueue = WrappedSTM .: unGetTBQueue
10231094

1095+
type TArray (ContT r m) = TArray m
1096+
10241097

10251098
instance MonadSTM m => MonadSTM (ReaderT r m) where
10261099
type STM (ReaderT r m) = WrappedSTM Reader r m
@@ -1074,6 +1147,8 @@ instance MonadSTM m => MonadSTM (ReaderT r m) where
10741147
isFullTBQueue = WrappedSTM . isFullTBQueue
10751148
unGetTBQueue = WrappedSTM .: unGetTBQueue
10761149

1150+
type TArray (ReaderT r m) = TArray m
1151+
10771152

10781153
instance (Monoid w, MonadSTM m) => MonadSTM (WriterT w m) where
10791154
type STM (WriterT w m) = WrappedSTM Writer w m
@@ -1127,6 +1202,8 @@ instance (Monoid w, MonadSTM m) => MonadSTM (WriterT w m) where
11271202
isFullTBQueue = WrappedSTM . isFullTBQueue
11281203
unGetTBQueue = WrappedSTM .: unGetTBQueue
11291204

1205+
type TArray (WriterT w m) = TArray m
1206+
11301207

11311208
instance MonadSTM m => MonadSTM (StateT s m) where
11321209
type STM (StateT s m) = WrappedSTM State s m
@@ -1180,6 +1257,8 @@ instance MonadSTM m => MonadSTM (StateT s m) where
11801257
isFullTBQueue = WrappedSTM . isFullTBQueue
11811258
unGetTBQueue = WrappedSTM .: unGetTBQueue
11821259

1260+
type TArray (StateT s m) = TArray m
1261+
11831262

11841263
instance MonadSTM m => MonadSTM (ExceptT e m) where
11851264
type STM (ExceptT e m) = WrappedSTM Except e m
@@ -1233,6 +1312,8 @@ instance MonadSTM m => MonadSTM (ExceptT e m) where
12331312
isFullTBQueue = WrappedSTM . isFullTBQueue
12341313
unGetTBQueue = WrappedSTM .: unGetTBQueue
12351314

1315+
type TArray (ExceptT e m) = TArray m
1316+
12361317

12371318
instance (Monoid w, MonadSTM m) => MonadSTM (RWST r w s m) where
12381319
type STM (RWST r w s m) = WrappedSTM RWS (r, w, s) m
@@ -1286,6 +1367,8 @@ instance (Monoid w, MonadSTM m) => MonadSTM (RWST r w s m) where
12861367
isFullTBQueue = WrappedSTM . isFullTBQueue
12871368
unGetTBQueue = WrappedSTM .: unGetTBQueue
12881369

1370+
type TArray (RWST r w s m) = TArray m
1371+
12891372

12901373
(.:) :: (c -> d) -> (a -> b -> c) -> (a -> b -> d)
12911374
(f .: g) x y = f (g x y)

io-sim/src/Control/Monad/IOSim/Types.hs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ import Control.Monad.Class.MonadMVar
7676
import Control.Monad.Class.MonadST
7777
import Control.Monad.Class.MonadSTM (MonadInspectSTM (..),
7878
MonadLabelledSTM (..), MonadSTM, MonadTraceSTM (..),
79-
TMVarDefault, TraceValue)
79+
TArrayDefault, TMVarDefault, TraceValue)
8080
import qualified Control.Monad.Class.MonadSTM as MonadSTM
8181
import Control.Monad.Class.MonadSay
8282
import Control.Monad.Class.MonadTest
@@ -398,6 +398,7 @@ instance MonadSTM (IOSim s) where
398398
type TMVar (IOSim s) = TMVarDefault (IOSim s)
399399
type TQueue (IOSim s) = TQueueDefault (IOSim s)
400400
type TBQueue (IOSim s) = TBQueueDefault (IOSim s)
401+
type TArray (IOSim s) = TArrayDefault (IOSim s)
401402

402403
atomically action = IOSim $ oneShot $ \k -> Atomically action k
403404

strict-stm/src/Control/Monad/Class/MonadSTM/Strict.hs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
{-# LANGUAGE CPP #-}
33
{-# LANGUAGE DuplicateRecordFields #-}
44
{-# LANGUAGE FlexibleContexts #-}
5+
{-# LANGUAGE FlexibleInstances #-}
6+
{-# LANGUAGE MultiParamTypeClasses #-}
57
{-# LANGUAGE NamedFieldPuns #-}
68
{-# LANGUAGE TypeFamilies #-}
79
{-# LANGUAGE TypeOperators #-}
10+
{-# LANGUAGE UndecidableInstances #-}
811

912
-- to preserve 'HasCallstack' constraint on 'checkInvariant'
1013
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
@@ -13,6 +16,9 @@ module Control.Monad.Class.MonadSTM.Strict
1316
( module X
1417
, LazyTVar
1518
, LazyTMVar
19+
, LazyTQueue
20+
, LazyTBQueue
21+
, LazyTArray
1622
-- * 'StrictTVar'
1723
, StrictTVar
1824
, labelTVar
@@ -89,6 +95,10 @@ module Control.Monad.Class.MonadSTM.Strict
8995
, isEmptyTBQueue
9096
, isFullTBQueue
9197
, unGetTBQueue
98+
-- * 'StrictTArray'
99+
, StrictTArray
100+
, toLazyTArray
101+
, fromLazyTArray
92102
-- ** Low-level API
93103
, checkInvariant
94104
-- * Deprecated API
@@ -115,6 +125,7 @@ import Control.Monad.Class.MonadSTM as X hiding (LazyTMVar, LazyTVar,
115125
tryReadTMVar, tryReadTQueue, tryTakeTMVar, unGetTBQueue,
116126
unGetTQueue, writeTBQueue, writeTQueue, writeTVar)
117127
import qualified Control.Monad.Class.MonadSTM as Lazy
128+
import Data.Array.Base (MArray (..))
118129
import GHC.Stack
119130
import Numeric.Natural (Natural)
120131

@@ -126,6 +137,7 @@ type LazyTVar m = Lazy.TVar m
126137
type LazyTMVar m = Lazy.TMVar m
127138
type LazyTQueue m = Lazy.TQueue m
128139
type LazyTBQueue m = Lazy.TBQueue m
140+
type LazyTArray m = Lazy.TArray m
129141

130142
{-------------------------------------------------------------------------------
131143
Strict TVar
@@ -453,6 +465,25 @@ isFullTBQueue = Lazy.isFullTBQueue . toLazyTBQueue
453465
unGetTBQueue :: MonadSTM m => StrictTBQueue m a -> a -> STM m ()
454466
unGetTBQueue (StrictTBQueue queue) !a = Lazy.unGetTBQueue queue a
455467

468+
{-------------------------------------------------------------------------------
469+
StrictTArray
470+
-------------------------------------------------------------------------------}
471+
472+
newtype StrictTArray m i e = StrictTArray { toLazyTArray :: LazyTArray m i e }
473+
474+
fromLazyTArray :: LazyTArray m i e -> StrictTArray m i e
475+
fromLazyTArray = StrictTArray
476+
477+
instance ( MArray (Lazy.TArray m) e stm
478+
, Monad stm
479+
)
480+
=> MArray (StrictTArray m) e stm where
481+
getBounds (StrictTArray arr) = getBounds arr
482+
newArray b !e = StrictTArray <$> newArray b e
483+
newArray_ b = StrictTArray <$> newArray_ b
484+
unsafeRead (StrictTArray arr) i = unsafeRead arr i
485+
unsafeWrite (StrictTArray arr) i !e = unsafeWrite arr i e
486+
getNumElements (StrictTArray arr) = getNumElements arr
456487

457488
{-------------------------------------------------------------------------------
458489
Dealing with invariants

strict-stm/strict-stm.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ library
3737
exposed-modules: Control.Monad.Class.MonadSTM.Strict
3838
default-language: Haskell2010
3939
build-depends: base >=4.9 && <4.18,
40+
array,
4041
stm >=2.5 && <2.6,
4142
io-classes
4243
ghc-options: -Wall

0 commit comments

Comments
 (0)