Skip to content

Commit 75e5352

Browse files
committed
Provide a exception-safe withTMVar combinator
1 parent fc24459 commit 75e5352

File tree

4 files changed

+98
-1
lines changed

4 files changed

+98
-1
lines changed

io-classes/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
* Added `threadLabel` to `MonadThread`
88
* Added `MonadLabelledMVar` class.
9+
* Added `withTMVar` and `withTMVarAnd` functions.
910

1011
### 1.7.0.0
1112

io-classes/src/Control/Concurrent/Class/MonadSTM/TMVar.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ module Control.Concurrent.Class.MonadSTM.TMVar
1818
, swapTMVar
1919
, writeTMVar
2020
, isEmptyTMVar
21+
, withTMVar
22+
, withTMVarAnd
2123
-- * MonadLabelledSTM
2224
, labelTMVar
2325
, labelTMVarIO

io-classes/src/Control/Monad/Class/MonadSTM/Internal.hs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
{-# LANGUAGE FlexibleInstances #-}
77
{-# LANGUAGE GADTs #-}
88
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
9+
{-# LANGUAGE LambdaCase #-}
910
{-# LANGUAGE MultiParamTypeClasses #-}
1011
{-# LANGUAGE NamedFieldPuns #-}
1112
{-# LANGUAGE PatternSynonyms #-}
@@ -99,6 +100,9 @@ module Control.Monad.Class.MonadSTM.Internal
99100
, isEmptyTChanDefault
100101
, cloneTChanDefault
101102
, labelTChanDefault
103+
-- * WithTMVar
104+
, withTMVar
105+
, withTMVarAnd
102106
) where
103107

104108
import Prelude hiding (read)
@@ -1257,3 +1261,40 @@ instance MonadSTM m => MonadSTM (ReaderT r m) where
12571261
writeTMVar' :: STM.TMVar a -> a -> STM.STM ()
12581262
writeTMVar' t new = STM.tryTakeTMVar t >> STM.putTMVar t new
12591263
#endif
1264+
1265+
1266+
-- | Apply @f@ with the content of @tv@ as state, restoring the original value when an
1267+
-- exception occurs
1268+
withTMVar ::
1269+
(MonadSTM m, MonadThrow.MonadCatch m)
1270+
=> TMVar m a
1271+
-> (a -> m (c, a))
1272+
-> m c
1273+
withTMVar tv f = withTMVarAnd tv (const $ pure ()) (\a -> const $ f a)
1274+
1275+
-- | Apply @f@ with the content of @tv@ as state, restoring the original value
1276+
-- when an exception occurs. Additionally run a @STM@ action when acquiring the
1277+
-- value.
1278+
withTMVarAnd ::
1279+
(MonadSTM m, MonadThrow.MonadCatch m)
1280+
=> TMVar m a
1281+
-> (a -> STM m b) -- ^ Additional STM action to run in the same atomically
1282+
-- block as the TMVar is acquired
1283+
-> (a -> b -> m (c, a)) -- ^ Action
1284+
-> m c
1285+
withTMVarAnd tv guard f =
1286+
fst . fst <$> MonadThrow.generalBracket
1287+
(atomically $ do
1288+
istate <- takeTMVar tv
1289+
guarded <- guard istate
1290+
pure (istate, guarded)
1291+
)
1292+
(\(origState, _) -> \case
1293+
MonadThrow.ExitCaseSuccess (_, newState)
1294+
-> atomically $ putTMVar tv newState
1295+
MonadThrow.ExitCaseException _
1296+
-> atomically $ putTMVar tv origState
1297+
MonadThrow.ExitCaseAbort
1298+
-> atomically $ putTMVar tv origState
1299+
)
1300+
(uncurry f)

io-classes/strict-stm/Control/Concurrent/Class/MonadSTM/Strict/TMVar.hs

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{-# LANGUAGE BangPatterns #-}
22
{-# LANGUAGE ExplicitNamespaces #-}
33
{-# LANGUAGE GADTs #-}
4+
{-# LANGUAGE LambdaCase #-}
45
{-# LANGUAGE TypeOperators #-}
56

67
-- | This module corresponds to `Control.Concurrent.STM.TMVar` in "stm" package
@@ -25,18 +26,22 @@ module Control.Concurrent.Class.MonadSTM.Strict.TMVar
2526
, swapTMVar
2627
, writeTMVar
2728
, isEmptyTMVar
29+
, withTMVar
30+
, withTMVarAnd
2831
-- * MonadLabelledSTM
2932
, labelTMVar
3033
, labelTMVarIO
3134
-- * MonadTraceSTM
3235
, traceTMVar
3336
, traceTMVarIO
37+
, traceTMVarShow
38+
, traceTMVarShowIO
3439
) where
3540

3641

3742
import Control.Concurrent.Class.MonadSTM.TMVar qualified as Lazy
3843
import Control.Monad.Class.MonadSTM hiding (traceTMVar, traceTMVarIO)
39-
44+
import Control.Monad.Class.MonadThrow
4045

4146
type LazyTMVar m = Lazy.TMVar m
4247

@@ -59,12 +64,39 @@ traceTMVar :: MonadTraceSTM m
5964
-> STM m ()
6065
traceTMVar p (StrictTMVar var) = Lazy.traceTMVar p var
6166

67+
traceTMVarShow :: (MonadTraceSTM m, Show a)
68+
=> proxy m
69+
-> StrictTMVar m a
70+
-> STM m ()
71+
traceTMVarShow p tmvar =
72+
traceTMVar p tmvar (\pv v -> pure $ TraceString $ case (pv, v) of
73+
(Nothing, Nothing) -> "Created empty"
74+
(Nothing, Just st') -> "Created full: " <> show st'
75+
(Just Nothing, Just st') -> "Put: " <> show st'
76+
(Just Nothing, Nothing) -> "Remains empty"
77+
(Just Just{}, Nothing) -> "Take"
78+
(Just (Just st'), Just st'') -> "Modified: " <> show st' <> " -> " <> show st''
79+
)
80+
6281
traceTMVarIO :: MonadTraceSTM m
6382
=> StrictTMVar m a
6483
-> (Maybe (Maybe a) -> (Maybe a) -> InspectMonad m TraceValue)
6584
-> m ()
6685
traceTMVarIO (StrictTMVar var) = Lazy.traceTMVarIO var
6786

87+
traceTMVarShowIO :: (Show a, MonadTraceSTM m)
88+
=> StrictTMVar m a
89+
-> m ()
90+
traceTMVarShowIO tmvar =
91+
traceTMVarIO tmvar (\pv v -> pure $ TraceString $ case (pv, v) of
92+
(Nothing, Nothing) -> "Created empty"
93+
(Nothing, Just st') -> "Created full: " <> show st'
94+
(Just Nothing, Just st') -> "Put: " <> show st'
95+
(Just Nothing, Nothing) -> "Remains empty"
96+
(Just Just{}, Nothing) -> "Take"
97+
(Just (Just st'), Just st'') -> "Modified: " <> show st' <> " -> " <> show st''
98+
)
99+
68100
castStrictTMVar :: LazyTMVar m ~ LazyTMVar n
69101
=> StrictTMVar m a -> StrictTMVar n a
70102
castStrictTMVar (StrictTMVar var) = StrictTMVar var
@@ -107,3 +139,24 @@ writeTMVar (StrictTMVar tmvar) !a = Lazy.writeTMVar tmvar a
107139

108140
isEmptyTMVar :: MonadSTM m => StrictTMVar m a -> STM m Bool
109141
isEmptyTMVar (StrictTMVar tmvar) = Lazy.isEmptyTMVar tmvar
142+
143+
withTMVar :: (MonadSTM m, MonadCatch m)
144+
=> StrictTMVar m a
145+
-> (a -> m (c, a))
146+
-> m c
147+
withTMVar (StrictTMVar tmvar) f =
148+
Lazy.withTMVar tmvar (\x -> do
149+
!(!c, !a) <- f x
150+
pure $! (c, a)
151+
)
152+
153+
withTMVarAnd :: (MonadSTM m, MonadCatch m)
154+
=> StrictTMVar m a
155+
-> (a -> STM m b)
156+
-> (a -> b -> m (c, a))
157+
-> m c
158+
withTMVarAnd (StrictTMVar tmvar) f g =
159+
Lazy.withTMVarAnd tmvar f (\x y -> do
160+
!(!c, !a) <- g x y
161+
pure $! (c, a)
162+
)

0 commit comments

Comments
 (0)