Skip to content

Commit 882ec3b

Browse files
iohk-bors[bot]coot
andauthored
Merge #3437
3437: MonadMaskingState r=coot a=coot - io-sim: do not arbitrary set masking state in `Throw` or `ThrowTo` - io-classes: `MonadMaskingState` - io-sim: introduce various test for the masking state Related to #3436 Co-authored-by: Marcin Szamotulski <[email protected]>
2 parents 64dc2ee + ebdccf4 commit 882ec3b

File tree

3 files changed

+317
-13
lines changed

3 files changed

+317
-13
lines changed

io-classes/src/Control/Monad/Class/MonadThrow.hs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ module Control.Monad.Class.MonadThrow
1010
( MonadThrow(..)
1111
, MonadCatch(..)
1212
, MonadMask(..)
13+
, MonadMaskingState(..)
1314
, MonadEvaluate(..)
15+
, MaskingState(..)
1416
, Exception(..)
1517
, SomeException
1618
, ExitCase(..)
@@ -20,7 +22,7 @@ module Control.Monad.Class.MonadThrow
2022
, throwM
2123
) where
2224

23-
import Control.Exception (Exception (..), SomeException)
25+
import Control.Exception (Exception (..), MaskingState, SomeException)
2426
import qualified Control.Exception as IO
2527
import Control.Monad (liftM)
2628
import Control.Monad.Except (ExceptT (..), lift, runExceptT)
@@ -183,6 +185,9 @@ class MonadCatch m => MonadMask m where
183185
uninterruptibleMask_ action = uninterruptibleMask $ \_ -> action
184186

185187

188+
class MonadMask m => MonadMaskingState m where
189+
getMaskingState :: m MaskingState
190+
186191
-- | Monads which can 'evaluate'.
187192
--
188193
class MonadThrow m => MonadEvaluate m where
@@ -223,6 +228,9 @@ instance MonadMask IO where
223228
uninterruptibleMask = IO.uninterruptibleMask
224229
uninterruptibleMask_ = IO.uninterruptibleMask_
225230

231+
instance MonadMaskingState IO where
232+
getMaskingState = IO.getMaskingState
233+
226234
instance MonadEvaluate IO where
227235
evaluate = IO.evaluate
228236

io-sim/src/Control/Monad/IOSim/Internal.hs

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ import Control.Monad.Class.MonadSay
9797
import Control.Monad.Class.MonadST
9898
import Control.Monad.Class.MonadSTM hiding (STM, TVar)
9999
import qualified Control.Monad.Class.MonadSTM as MonadSTM
100-
import Control.Monad.Class.MonadThrow as MonadThrow
100+
import Control.Monad.Class.MonadThrow hiding (getMaskingState)
101+
import qualified Control.Monad.Class.MonadThrow as MonadThrow
101102
import Control.Monad.Class.MonadTime
102103
import Control.Monad.Class.MonadTimer
103104

@@ -174,9 +175,6 @@ type STMSim = STM
174175
type SimSTM = STM
175176
{-# DEPRECATED SimSTM "Use STMSim" #-}
176177

177-
data MaskingState = Unmasked | MaskedInterruptible | MaskedUninterruptible
178-
deriving (Eq, Ord, Show)
179-
180178
--
181179
-- Monad class instances
182180
--
@@ -300,19 +298,22 @@ instance Exceptions.MonadCatch (IOSim s) where
300298

301299
instance MonadMask (IOSim s) where
302300
mask action = do
303-
b <- getMaskingState
301+
b <- getMaskingStateImpl
304302
case b of
305303
Unmasked -> block $ action unblock
306304
MaskedInterruptible -> action block
307305
MaskedUninterruptible -> action blockUninterruptible
308306

309307
uninterruptibleMask action = do
310-
b <- getMaskingState
308+
b <- getMaskingStateImpl
311309
case b of
312310
Unmasked -> blockUninterruptible $ action unblock
313311
MaskedInterruptible -> blockUninterruptible $ action block
314312
MaskedUninterruptible -> action blockUninterruptible
315313

314+
instance MonadMaskingState (IOSim s) where
315+
getMaskingState = getMaskingStateImpl
316+
316317
instance Exceptions.MonadMask (IOSim s) where
317318
mask = MonadThrow.mask
318319
uninterruptibleMask = MonadThrow.uninterruptibleMask
@@ -327,10 +328,10 @@ instance Exceptions.MonadMask (IOSim s) where
327328
return (b, c)
328329

329330

330-
getMaskingState :: IOSim s MaskingState
331+
getMaskingStateImpl :: IOSim s MaskingState
331332
unblock, block, blockUninterruptible :: IOSim s a -> IOSim s a
332333

333-
getMaskingState = IOSim GetMaskState
334+
getMaskingStateImpl = IOSim GetMaskState
334335
unblock a = IOSim (SetMaskState Unmasked a)
335336
block a = IOSim (SetMaskState MaskedInterruptible a)
336337
blockUninterruptible a = IOSim (SetMaskState MaskedUninterruptible a)
@@ -1064,8 +1065,7 @@ schedule thread@Thread{
10641065
ThrowTo e tid' _ | tid' == tid -> do
10651066
-- Throw to ourself is equivalent to a synchronous throw,
10661067
-- and works irrespective of masking state since it does not block.
1067-
let thread' = thread { threadControl = ThreadControl (Throw e) ctl
1068-
, threadMasking = MaskedInterruptible }
1068+
let thread' = thread { threadControl = ThreadControl (Throw e) ctl }
10691069
trace <- schedule thread' simstate
10701070
return (SimTrace time tid tlbl (EventThrowTo e tid) trace)
10711071

@@ -1096,7 +1096,7 @@ schedule thread@Thread{
10961096
let adjustTarget t@Thread{ threadControl = ThreadControl _ ctl' } =
10971097
t { threadControl = ThreadControl (Throw e) ctl'
10981098
, threadBlocked = False
1099-
, threadMasking = MaskedInterruptible }
1099+
}
11001100
simstate'@SimState { threads = threads' }
11011101
= snd (unblockThreads [tid'] simstate)
11021102
threads'' = Map.adjust adjustTarget tid' threads'
@@ -1290,9 +1290,13 @@ unwindControlStack e thread =
12901290
-- As per async exception rules, the handler is run masked
12911291
threadControl = ThreadControl (handler e')
12921292
(MaskFrame k maskst ctl),
1293-
threadMasking = max maskst MaskedInterruptible
1293+
threadMasking = atLeastInterruptibleMask maskst
12941294
}
12951295

1296+
atLeastInterruptibleMask :: MaskingState -> MaskingState
1297+
atLeastInterruptibleMask Unmasked = MaskedInterruptible
1298+
atLeastInterruptibleMask ms = ms
1299+
12961300

12971301
removeMinimums :: (Ord k, Ord p)
12981302
=> OrdPSQ k p a

0 commit comments

Comments
 (0)