Skip to content

Commit 8777fcd

Browse files
iohk-bors[bot]coot
andauthored
Merge #3207
3207: Use newtype for StrictTVar if checktvarinvariant cabal flag is not set r=coot a=coot - renamed cabal.project.ci.windows file - io-sim-classes: StrictTVar representation - Updated cabal.project.local.ci Co-authored-by: Marcin Szamotulski <[email protected]>
2 parents 4d57942 + 5c05f93 commit 8777fcd

File tree

1 file changed

+30
-11
lines changed
  • io-sim-classes/src/Control/Monad/Class/MonadSTM

1 file changed

+30
-11
lines changed

io-sim-classes/src/Control/Monad/Class/MonadSTM/Strict.hs

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,17 @@ type LazyTMVar m = Lazy.TMVar m
7676
Strict TVar
7777
-------------------------------------------------------------------------------}
7878

79+
#if CHECK_TVAR_INVARIANT
7980
data StrictTVar m a = StrictTVar
8081
{ invariant :: !(a -> Maybe String)
8182
-- ^ Invariant checked whenever updating the 'StrictTVar'.
8283
, tvar :: !(LazyTVar m a)
8384
}
85+
#else
86+
newtype StrictTVar m a = StrictTVar
87+
{ tvar :: LazyTVar m a
88+
}
89+
#endif
8490

8591
labelTVar :: MonadLabelledSTM m => StrictTVar m a -> String -> STM m ()
8692
labelTVar StrictTVar { tvar } = Lazy.labelTVar tvar
@@ -90,7 +96,8 @@ labelTVarIO v = atomically . labelTVar v
9096

9197
castStrictTVar :: LazyTVar m ~ LazyTVar n
9298
=> StrictTVar m a -> StrictTVar n a
93-
castStrictTVar StrictTVar{invariant, tvar} = StrictTVar{invariant, tvar}
99+
castStrictTVar v@StrictTVar {tvar} =
100+
mkStrictTVar (getInvariant v) tvar
94101

95102
-- | Get the underlying @TVar@
96103
--
@@ -100,7 +107,8 @@ toLazyTVar :: StrictTVar m a -> LazyTVar m a
100107
toLazyTVar StrictTVar { tvar } = tvar
101108

102109
newTVar :: MonadSTM m => a -> STM m (StrictTVar m a)
103-
newTVar !a = StrictTVar (const Nothing) <$> Lazy.newTVar a
110+
newTVar !a = (\tvar -> mkStrictTVar (const Nothing) tvar)
111+
<$> Lazy.newTVar a
104112

105113
newTVarIO :: MonadSTM m => a -> m (StrictTVar m a)
106114
newTVarIO = newTVarWithInvariantIO (const Nothing)
@@ -113,9 +121,10 @@ newTVarWithInvariantIO :: (MonadSTM m, HasCallStack)
113121
=> (a -> Maybe String) -- ^ Invariant (expect 'Nothing')
114122
-> a
115123
-> m (StrictTVar m a)
116-
newTVarWithInvariantIO invariant !a =
117-
checkInvariant (invariant a) $
118-
StrictTVar invariant <$> Lazy.newTVarIO a
124+
newTVarWithInvariantIO invariant !a =
125+
checkInvariant (invariant a) $
126+
(\tvar -> mkStrictTVar invariant tvar)
127+
<$> Lazy.newTVarIO a
119128

120129
newTVarWithInvariantM :: (MonadSTM m, HasCallStack)
121130
=> (a -> Maybe String) -- ^ Invariant (expect 'Nothing')
@@ -131,9 +140,9 @@ readTVarIO :: MonadSTM m => StrictTVar m a -> m a
131140
readTVarIO StrictTVar { tvar } = Lazy.readTVarIO tvar
132141

133142
writeTVar :: (MonadSTM m, HasCallStack) => StrictTVar m a -> a -> STM m ()
134-
writeTVar StrictTVar { tvar, invariant } !a =
135-
checkInvariant (invariant a) $
136-
Lazy.writeTVar tvar a
143+
writeTVar v !a =
144+
checkInvariant (getInvariant v a) $
145+
Lazy.writeTVar (tvar v) a
137146

138147
modifyTVar :: MonadSTM m => StrictTVar m a -> (a -> a) -> STM m ()
139148
modifyTVar v f = readTVar v >>= writeTVar v . f
@@ -225,6 +234,9 @@ isEmptyTMVar (StrictTMVar tmvar) = Lazy.isEmptyTMVar tmvar
225234
Dealing with invariants
226235
-------------------------------------------------------------------------------}
227236

237+
getInvariant :: StrictTVar m a -> a -> Maybe String
238+
mkStrictTVar :: (a -> Maybe String) -> Lazy.TVar m a -> StrictTVar m a
239+
228240
-- | Check invariant (if enabled) before continuing
229241
--
230242
-- @checkInvariant mErr x@ is equal to @x@ if @mErr == Nothing@, and throws
@@ -234,9 +246,16 @@ isEmptyTMVar (StrictTMVar tmvar) = Lazy.isEmptyTMVar tmvar
234246
-- invariants can reuse the same logic, rather than having to introduce new
235247
-- per-package flags.
236248
checkInvariant :: HasCallStack => Maybe String -> a -> a
249+
237250
#if CHECK_TVAR_INVARIANT
238-
checkInvariant Nothing k = k
239-
checkInvariant (Just err) _ = error $ "Invariant violation: " ++ err
251+
getInvariant StrictTVar {invariant} = invariant
252+
mkStrictTVar invariant tvar = StrictTVar {invariant, tvar}
253+
254+
checkInvariant Nothing k = k
255+
checkInvariant (Just err) _ = error $ "Invariant violation: " ++ err
240256
#else
241-
checkInvariant _err k = k
257+
getInvariant _ = \_ -> Nothing
258+
mkStrictTVar _invariant tvar = StrictTVar {tvar}
259+
260+
checkInvariant _err k = k
242261
#endif

0 commit comments

Comments
 (0)