Skip to content

Commit d5b435b

Browse files
committed
Implement exception-safe withTMVar
1 parent 1170a41 commit d5b435b

File tree

3 files changed

+67
-1
lines changed
  • io-classes

3 files changed

+67
-1
lines changed

io-classes/src/Control/Concurrent/Class/MonadSTM/TMVar.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ module Control.Concurrent.Class.MonadSTM.TMVar
1818
, swapTMVar
1919
, writeTMVar
2020
, isEmptyTMVar
21+
, withTMVar
22+
, withTMVarAnd
2123
-- * MonadLabelledSTM
2224
, labelTMVar
2325
, labelTMVarIO

io-classes/src/Control/Monad/Class/MonadSTM/Internal.hs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
{-# LANGUAGE FlexibleInstances #-}
77
{-# LANGUAGE GADTs #-}
88
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
9+
{-# LANGUAGE LambdaCase #-}
910
{-# LANGUAGE MultiParamTypeClasses #-}
1011
{-# LANGUAGE NamedFieldPuns #-}
1112
{-# LANGUAGE PatternSynonyms #-}
@@ -99,6 +100,9 @@ module Control.Monad.Class.MonadSTM.Internal
99100
, isEmptyTChanDefault
100101
, cloneTChanDefault
101102
, labelTChanDefault
103+
-- * WithTMVar
104+
, withTMVar
105+
, withTMVarAnd
102106
) where
103107

104108
import Prelude hiding (read)
@@ -1257,3 +1261,40 @@ instance MonadSTM m => MonadSTM (ReaderT r m) where
12571261
writeTMVar' :: STM.TMVar a -> a -> STM.STM ()
12581262
writeTMVar' t new = STM.tryTakeTMVar t >> STM.putTMVar t new
12591263
#endif
1264+
1265+
1266+
-- | Apply @f@ with the content of @tv@ as state, restoring the original value when an
1267+
-- exception occurs
1268+
withTMVar ::
1269+
(MonadSTM m, MonadThrow.MonadCatch m)
1270+
=> TMVar m a
1271+
-> (a -> m (c, a))
1272+
-> m c
1273+
withTMVar tv f = withTMVarAnd tv (const $ pure ()) (\a -> const $ f a)
1274+
1275+
-- | Apply @f@ with the content of @tv@ as state, restoring the original value
1276+
-- when an exception occurs. Additionally run a @STM@ action when acquiring the
1277+
-- value.
1278+
withTMVarAnd ::
1279+
(MonadSTM m, MonadThrow.MonadCatch m)
1280+
=> TMVar m a
1281+
-> (a -> STM m b) -- ^ Additional STM action to run in the same atomically
1282+
-- block as the TMVar is acquired
1283+
-> (a -> b -> m (c, a)) -- ^ Action
1284+
-> m c
1285+
withTMVarAnd tv guard f =
1286+
fst . fst <$> MonadThrow.generalBracket
1287+
(atomically $ do
1288+
istate <- takeTMVar tv
1289+
guarded <- guard istate
1290+
pure (istate, guarded)
1291+
)
1292+
(\(origState, _) -> \case
1293+
MonadThrow.ExitCaseSuccess (_, newState)
1294+
-> atomically $ putTMVar tv newState
1295+
MonadThrow.ExitCaseException _
1296+
-> atomically $ putTMVar tv origState
1297+
MonadThrow.ExitCaseAbort
1298+
-> atomically $ putTMVar tv origState
1299+
)
1300+
(uncurry f)

io-classes/strict-stm/Control/Concurrent/Class/MonadSTM/Strict/TMVar.hs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ module Control.Concurrent.Class.MonadSTM.Strict.TMVar
2525
, swapTMVar
2626
, writeTMVar
2727
, isEmptyTMVar
28+
, withTMVar
29+
, withTMVarAnd
2830
-- * MonadLabelledSTM
2931
, labelTMVar
3032
, labelTMVarIO
@@ -36,7 +38,7 @@ module Control.Concurrent.Class.MonadSTM.Strict.TMVar
3638

3739
import Control.Concurrent.Class.MonadSTM.TMVar qualified as Lazy
3840
import Control.Monad.Class.MonadSTM hiding (traceTMVar, traceTMVarIO)
39-
41+
import Control.Monad.Class.MonadThrow
4042

4143
type LazyTMVar m = Lazy.TMVar m
4244

@@ -107,3 +109,24 @@ writeTMVar (StrictTMVar tmvar) !a = Lazy.writeTMVar tmvar a
107109

108110
isEmptyTMVar :: MonadSTM m => StrictTMVar m a -> STM m Bool
109111
isEmptyTMVar (StrictTMVar tmvar) = Lazy.isEmptyTMVar tmvar
112+
113+
withTMVar :: (MonadSTM m, MonadCatch m)
114+
=> StrictTMVar m a
115+
-> (a -> m (c, a))
116+
-> m c
117+
withTMVar (StrictTMVar tmvar) f =
118+
Lazy.withTMVar tmvar (\x -> do
119+
!(!c, !a) <- f x
120+
pure $! (c, a)
121+
)
122+
123+
withTMVarAnd :: (MonadSTM m, MonadCatch m)
124+
=> StrictTMVar m a
125+
-> (a -> STM m b)
126+
-> (a -> b -> m (c, a))
127+
-> m c
128+
withTMVarAnd (StrictTMVar tmvar) f g =
129+
Lazy.withTMVarAnd tmvar f (\x y -> do
130+
!(!c, !a) <- g x y
131+
pure $! (c, a)
132+
)

0 commit comments

Comments
 (0)