Skip to content

Commit 11e82a4

Browse files
bolt12coot
authored andcommitted
Refactored timeout into better IOSim primitives
1 parent a1dd35e commit 11e82a4

File tree

3 files changed

+229
-53
lines changed

3 files changed

+229
-53
lines changed

io-sim/src/Control/Monad/IOSim/Internal.hs

Lines changed: 186 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ module Control.Monad.IOSim.Internal
5454
import Prelude hiding (read)
5555

5656
import Data.Dynamic
57-
import Data.Foldable (toList, traverse_)
57+
import Data.Foldable (toList, traverse_, foldlM)
58+
import Deque.Strict (Deque)
59+
import qualified Deque.Strict as Deque
5860
import qualified Data.List as List
5961
import qualified Data.List.Trace as Trace
6062
import Data.Map.Strict (Map)
@@ -65,16 +67,13 @@ import qualified Data.OrdPSQ as PSQ
6567
import Data.Set (Set)
6668
import qualified Data.Set as Set
6769
import Data.Time (UTCTime (..), fromGregorian)
68-
import Deque.Strict (Deque)
69-
import qualified Deque.Strict as Deque
7070

7171
import GHC.Exts (fromList)
7272
import GHC.Conc (ThreadStatus(..), BlockReason(..))
7373

74-
import Control.Exception (NonTermination (..), assert, throw)
75-
import Control.Monad (join)
76-
77-
import Control.Monad (when)
74+
import Control.Exception
75+
(NonTermination (..), assert, throw, AsyncException (..))
76+
import Control.Monad (join, when)
7877
import Control.Monad.ST.Lazy
7978
import Control.Monad.ST.Lazy.Unsafe (unsafeIOToST, unsafeInterleaveST)
8079
import Data.STRef.Lazy
@@ -126,6 +125,7 @@ data TimerCompletionInfo s =
126125
Timer !(TVar s TimeoutState)
127126
| TimerRegisterDelay !(TVar s Bool)
128127
| TimerThreadDelay !ThreadId
128+
| TimerTimeout !ThreadId !TimeoutId !(STRef s IsLocked)
129129

130130
-- | Internal state.
131131
--
@@ -138,7 +138,7 @@ data SimState s a = SimState {
138138
finished :: !(Map ThreadId FinishedReason),
139139
-- | current time
140140
curTime :: !Time,
141-
-- | ordered list of timers
141+
-- | ordered list of timers and timeouts
142142
timers :: !(OrdPSQ TimeoutId Time (TimerCompletionInfo s)),
143143
-- | list of clocks
144144
clocks :: !(Map ClockId UTCTime),
@@ -235,8 +235,53 @@ schedule !thread@Thread{
235235
let thread' = thread { threadControl = ThreadControl (k x) ctl' }
236236
schedule thread' simstate
237237

238+
TimeoutFrame tmid isLockedRef k ctl' -> do
239+
-- There is a possible race between timeout action and the timeout expiration.
240+
-- We use a lock to solve the race.
241+
--
242+
-- The lock starts 'NotLocked' and when the timeout fires the lock is
243+
-- locked and asynchronously an assassin thread is coming to interrupt
244+
-- it. If the lock is locked when the timeout is fired then nothing
245+
-- happens.
246+
--
247+
-- Knowing this, if we reached this point in the code and the lock is
248+
-- 'Locked', then it means that this thread still hasn't received the
249+
-- 'TimeoutException', so we need to kill the thread that is responsible
250+
-- for doing that (the assassin thread, we need to defend ourselves!)
251+
-- and run our continuation successfully and peacefully. We will do that
252+
-- by uninterruptibly-masking ourselves so we can not receive any
253+
-- exception and kill the assassin thread behind its back.
254+
-- If the lock is 'NotLocked' then it means we can just acquire it and
255+
-- carry on with the success case.
256+
locked <- readSTRef isLockedRef
257+
case locked of
258+
Locked etid -> do
259+
let -- Kill the assassin throwing thread and carry on the
260+
-- continuation
261+
thread' =
262+
thread { threadControl =
263+
ThreadControl (ThrowTo (toException ThreadKilled)
264+
etid
265+
(k (Just x)))
266+
ctl'
267+
, threadMasking = MaskedUninterruptible
268+
}
269+
schedule thread' simstate
270+
271+
NotLocked -> do
272+
-- Acquire lock
273+
writeSTRef isLockedRef (Locked tid)
274+
275+
-- Remove the timer from the queue
276+
let timers' = PSQ.delete tmid timers
277+
-- Run the continuation
278+
thread' = thread { threadControl = ThreadControl (k (Just x)) ctl' }
279+
280+
schedule thread' simstate { timers = timers'
281+
}
238282
Throw thrower e -> {-# SCC "schedule.Throw" #-}
239283
case unwindControlStack e thread of
284+
-- Found a CatchFrame
240285
Right thread'@Thread { threadMasking = maskst' } -> do
241286
-- We found a suitable exception handler, continue with that
242287
trace <- schedule thread' simstate
@@ -360,6 +405,23 @@ schedule !thread@Thread{
360405
, nextTmid = succ nextTmid }
361406
return (SimTrace time tid tlbl (EventTimerCreated nextTmid nextVid expiry) trace)
362407

408+
-- This case is guarded by checks in 'timeout' itself.
409+
StartTimeout d _ _ | d <= 0 ->
410+
error "schedule: StartTimeout: Impossible happened"
411+
412+
StartTimeout d action' k ->
413+
{-# SCC "schedule.StartTimeout" #-} do
414+
isLockedRef <- newSTRef NotLocked
415+
let !expiry = d `addTime` time
416+
!timers' = PSQ.insert nextTmid expiry (TimerTimeout tid nextTmid isLockedRef) timers
417+
!thread' = thread { threadControl =
418+
ThreadControl action'
419+
(TimeoutFrame nextTmid isLockedRef k ctl)
420+
}
421+
!trace <- deschedule Yield thread' simstate { timers = timers'
422+
, nextTmid = succ nextTmid }
423+
return (SimTrace time tid tlbl (EventTimeoutCreated nextTmid tid expiry) trace)
424+
363425
RegisterDelay d k | d < 0 ->
364426
{-# SCC "schedule.NewRegisterDelay" #-} do
365427
!tvar <- execNewTVar nextVid
@@ -404,7 +466,6 @@ schedule !thread@Thread{
404466
, nextTmid = succ nextTmid }
405467
return (SimTrace time tid tlbl (EventThreadDelay expiry) trace)
406468

407-
408469
-- we do not follow `GHC.Event` behaviour here; updating a timer to the past
409470
-- effectively cancels it.
410471
UpdateTimeout (Timeout _tvar tmid) d k | d < 0 ->
@@ -777,8 +838,23 @@ reschedule !simstate@SimState{ threads, timers, curTime = time } =
777838
wakeup = wakeupThreadDelay ++ wakeupSTM
778839
(_, !simstate') = unblockThreads wakeup simstate
779840

780-
!trace <- reschedule simstate' { curTime = time'
781-
, timers = timers' }
841+
-- For each 'timeout' action where the timeout has fired, start a
842+
-- new thread to execute throwTo to interrupt the action.
843+
!timeoutExpired = [ (tid, tmid, isLockedRef)
844+
| TimerTimeout tid tmid isLockedRef <- fired ]
845+
846+
-- Get the isLockedRef values
847+
!timeoutExpired' <- traverse (\(tid, tmid, isLockedRef) -> do
848+
locked <- readSTRef isLockedRef
849+
return (tid, tmid, isLockedRef, locked)
850+
)
851+
timeoutExpired
852+
853+
!simstate'' <- forkTimeoutInterruptThreads timeoutExpired' simstate'
854+
855+
!trace <- reschedule simstate'' { curTime = time'
856+
, timers = timers' }
857+
782858
return $
783859
traceMany ([ ( time', ThreadId [-1], Just "timer"
784860
, EventTimerFired tmid)
@@ -792,7 +868,13 @@ reschedule !simstate@SimState{ threads, timers, curTime = time } =
792868
, let Just vids = Set.toList <$> Map.lookup tid' wokeby ]
793869
++ [ ( time', tid, Just "thread delay timer"
794870
, EventThreadDelayFired)
795-
| tid <- wakeupThreadDelay ])
871+
| tid <- wakeupThreadDelay ]
872+
++ [ ( time', tid, Just "timeout timer"
873+
, EventTimeoutFired tmid)
874+
| (tid, tmid, _, _) <- timeoutExpired' ]
875+
++ [ ( time', tid, Just "thread forked"
876+
, EventThreadForked tid)
877+
| (tid, _, _, _) <- timeoutExpired' ])
796878
trace
797879
where
798880
timeoutSTMAction (Timer var) = do
@@ -804,7 +886,8 @@ reschedule !simstate@SimState{ threads, timers, curTime = time } =
804886
timeoutSTMAction (TimerRegisterDelay var) = writeTVar var True
805887
-- Note that 'threadDelay' is not handled via STM style wakeup, but rather
806888
-- it's handled directly above with 'wakeupThreadDelay' and 'unblockThreads'
807-
timeoutSTMAction (TimerThreadDelay _) = return ()
889+
timeoutSTMAction TimerThreadDelay{} = return ()
890+
timeoutSTMAction TimerTimeout{} = return ()
808891

809892
unblockThreads :: [ThreadId] -> SimState s a -> ([ThreadId], SimState s a)
810893
unblockThreads !wakeup !simstate@SimState {runqueue, threads} =
@@ -825,7 +908,76 @@ unblockThreads !wakeup !simstate@SimState {runqueue, threads} =
825908
-- and in which case we mark them as now running
826909
!threads' = List.foldl'
827910
(flip (Map.adjust (\t -> t { threadBlocked = False })))
828-
threads unblocked
911+
threads
912+
unblocked
913+
914+
-- | This function receives a list of TimerTimeout values that represent threads
915+
-- for which the timeout expired and kills the running thread if needed.
916+
--
917+
-- This function is responsible for the second part of the race condition issue
918+
-- and relates to the 'schedule's 'TimeoutFrame' locking explanation (here is
919+
-- where the assassin threads are launched. So, as explained previously, at this
920+
-- point in code, the timeout expired so we need to interrupt the running
921+
-- thread. If the running thread finished at the same time the timeout expired
922+
-- we have a race condition. To deal with this race condition what we do is
923+
-- look at the lock value. If it is 'Locked' this means that the running thread
924+
-- already finished (or won the race) so we can safely do nothing. Otherwise, if
925+
-- the lock value is 'NotLocked' we need to acquire the lock and launch an
926+
-- assassin thread that is going to interrupt the running one. Note that we
927+
-- should run this interrupting thread in an unmasked state since it might
928+
-- receive a 'ThreadKilled' exception.
929+
--
930+
forkTimeoutInterruptThreads :: [(ThreadId, TimeoutId, STRef s IsLocked, IsLocked)]
931+
-> SimState s a
932+
-> ST s (SimState s a)
933+
forkTimeoutInterruptThreads timeoutExpired simState@SimState {threads} =
934+
foldlM (\st@SimState{ runqueue = runqueue,
935+
threads = threads'
936+
}
937+
(t, isLockedRef)
938+
-> do
939+
let tid' = threadId t
940+
threads'' = Map.insert tid' t threads'
941+
runqueue' = Deque.snoc tid' runqueue
942+
943+
writeSTRef isLockedRef (Locked tid')
944+
945+
return st { runqueue = runqueue',
946+
threads = threads''
947+
})
948+
simState
949+
throwToThread
950+
951+
where
952+
-- can only throw exception if the thread exists and if the mutually
953+
-- exclusive lock exists and is still 'NotLocked'
954+
toThrow = [ (tid, tmid, ref, t)
955+
| (tid, tmid, ref, locked) <- timeoutExpired
956+
, Just t <- [Map.lookup tid threads]
957+
, NotLocked <- [locked]
958+
]
959+
-- we launch a thread responsible for throwing an AsyncCancelled exception
960+
-- to the thread which timeout expired
961+
throwToThread =
962+
[ let nextId = threadNextTId t
963+
tid' = childThreadId tid nextId
964+
in ( Thread { threadId = tid',
965+
threadControl =
966+
ThreadControl
967+
(ThrowTo (toException (TimeoutException tmid))
968+
tid
969+
(Return ()))
970+
ForkFrame,
971+
threadBlocked = False,
972+
threadMasking = Unmasked,
973+
threadThrowTo = [],
974+
threadClockId = threadClockId t,
975+
threadLabel = Just "timeout-forked-thread",
976+
threadNextTId = 1
977+
}
978+
, ref )
979+
| (tid, tmid, ref, t) <- toThrow
980+
]
829981

830982

831983
-- | Iterate through the control stack to find an enclosing exception handler
@@ -843,7 +995,8 @@ unwindControlStack e thread =
843995
ThreadControl _ ctl -> unwind (threadMasking thread) ctl
844996
where
845997
unwind :: forall s' c. MaskingState
846-
-> ControlStack s' c a -> Either Bool (Thread s' a)
998+
-> ControlStack s' c a
999+
-> Either Bool (Thread s' a)
8471000
unwind _ MainFrame = Left True
8481001
unwind _ ForkFrame = Left False
8491002
unwind _ (MaskFrame _k maskst' ctl) = unwind maskst' ctl
@@ -855,12 +1008,28 @@ unwindControlStack e thread =
8551008

8561009
-- Ok! We will be able to continue the thread with the handler
8571010
-- followed by the continuation after the catch
858-
Just e' -> Right thread {
859-
-- As per async exception rules, the handler is run masked
1011+
Just e' -> Right ( thread {
1012+
-- As per async exception rules, the handler is run
1013+
-- masked
8601014
threadControl = ThreadControl (handler e')
8611015
(MaskFrame k maskst ctl),
8621016
threadMasking = atLeastInterruptibleMask maskst
8631017
}
1018+
)
1019+
1020+
-- Either Timeout fired or the action threw an exception.
1021+
-- - If Timeout fired, then it was possibly during this thread's execution
1022+
-- so we need to run the continuation with a Nothing value.
1023+
-- - If the timeout action threw an exception we need to keep unwinding the
1024+
-- control stack looking for a handler to this exception.
1025+
unwind maskst (TimeoutFrame tmid _ k ctl) =
1026+
case fromException e of
1027+
-- Exception came from timeout expiring
1028+
Just (TimeoutException tmid') ->
1029+
assert (tmid == tmid')
1030+
Right thread { threadControl = ThreadControl (k Nothing) ctl }
1031+
-- Exception came from a different exception
1032+
_ -> unwind maskst ctl
8641033

8651034
atLeastInterruptibleMask :: MaskingState -> MaskingState
8661035
atLeastInterruptibleMask Unmasked = MaskedInterruptible
Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
1-
{-# LANGUAGE GADTs #-}
2-
{-# LANGUAGE StandaloneDeriving #-}
1+
{-# LANGUAGE GADTs #-}
2+
{-# LANGUAGE StandaloneDeriving #-}
3+
{-# LANGUAGE RankNTypes #-}
4+
{-# LANGUAGE ScopedTypeVariables #-}
35

46
-- | Internal types shared between `IOSim` and `IOSimPOR`.
57
--
68
module Control.Monad.IOSim.InternalTypes
79
( ThreadControl (..)
810
, ControlStack (..)
11+
, IsLocked (..)
912
) where
1013

14+
import Data.STRef.Lazy (STRef)
1115
import Control.Exception (Exception)
1216
import Control.Monad.Class.MonadThrow (MaskingState (..))
1317

14-
import Control.Monad.IOSim.Types (SimA)
18+
import Control.Monad.IOSim.Types (SimA, ThreadId, TimeoutId)
1519

1620
-- We hide the type @b@ here, so it's useful to bundle these two parts together,
1721
-- rather than having Thread have an existential type, which makes record
@@ -25,29 +29,43 @@ instance Show (ThreadControl s a) where
2529
show _ = "..."
2630

2731
data ControlStack s b a where
28-
MainFrame :: ControlStack s a a
29-
ForkFrame :: ControlStack s () a
30-
MaskFrame :: (b -> SimA s c) -- subsequent continuation
31-
-> !MaskingState -- thread local state to restore
32+
MainFrame :: ControlStack s a a
33+
ForkFrame :: ControlStack s () a
34+
MaskFrame :: (b -> SimA s c) -- subsequent continuation
35+
-> MaskingState -- thread local state to restore
3236
-> !(ControlStack s c a)
33-
-> ControlStack s b a
34-
CatchFrame :: Exception e
35-
=> (e -> SimA s b) -- exception continuation
36-
-> (b -> SimA s c) -- subsequent continuation
37+
-> ControlStack s b a
38+
CatchFrame :: Exception e
39+
=> (e -> SimA s b) -- exception continuation
40+
-> (b -> SimA s c) -- subsequent continuation
3741
-> !(ControlStack s c a)
38-
-> ControlStack s b a
42+
-> ControlStack s b a
43+
TimeoutFrame :: TimeoutId
44+
-> STRef s IsLocked
45+
-> (Maybe b -> SimA s c)
46+
-> !(ControlStack s c a)
47+
-> ControlStack s b a
3948

4049
instance Show (ControlStack s b a) where
4150
show = show . dash
42-
where dash :: ControlStack s' b' a' -> ControlStackDash
43-
dash MainFrame = MainFrame'
44-
dash ForkFrame = ForkFrame'
45-
dash (MaskFrame _ m s) = MaskFrame' m (dash s)
46-
dash (CatchFrame _ _ s) = CatchFrame' (dash s)
51+
where
52+
dash :: ControlStack s b' a -> ControlStackDash
53+
dash MainFrame = MainFrame'
54+
dash ForkFrame = ForkFrame'
55+
dash (MaskFrame _ m cs) = MaskFrame' m (dash cs)
56+
dash (CatchFrame _ _ cs) = CatchFrame' (dash cs)
57+
dash (TimeoutFrame tmid _ _ cs) = TimeoutFrame' tmid (dash cs)
4758

4859
data ControlStackDash =
4960
MainFrame'
5061
| ForkFrame'
5162
| MaskFrame' MaskingState ControlStackDash
5263
| CatchFrame' ControlStackDash
64+
-- TODO: Figure out a better way to include IsLocked here
65+
| TimeoutFrame' TimeoutId ControlStackDash
66+
| ThreadDelayFrame' TimeoutId ControlStackDash
5367
deriving Show
68+
69+
data IsLocked = NotLocked | Locked !ThreadId
70+
deriving (Eq, Show)
71+

0 commit comments

Comments
 (0)