Skip to content

Commit 8e05156

Browse files
bolt12coot
authored andcommitted
Added IOSimPOR counterpart
1 parent 11e82a4 commit 8e05156

File tree

1 file changed

+256
-32
lines changed

1 file changed

+256
-32
lines changed

io-sim/src/Control/Monad/IOSimPOR/Internal.hs

Lines changed: 256 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ module Control.Monad.IOSimPOR.Internal
5757
import Prelude hiding (read)
5858

5959
import Data.Dynamic
60-
import Data.Foldable (traverse_)
60+
import Data.Foldable (traverse_, foldlM)
6161
import qualified Data.List as List
6262
import qualified Data.List.Trace as Trace
6363
import Data.Map.Strict (Map)
@@ -70,10 +70,9 @@ import Data.Set (Set)
7070
import qualified Data.Set as Set
7171
import Data.Time (UTCTime (..), fromGregorian)
7272

73-
import Control.Exception (NonTermination (..), assert, throw)
74-
import Control.Monad (join)
75-
76-
import Control.Monad (when)
73+
import Control.Exception
74+
(NonTermination (..), assert, throw, AsyncException (..))
75+
import Control.Monad ( join, when )
7776
import Control.Monad.ST.Lazy
7877
import Control.Monad.ST.Lazy.Unsafe (unsafeIOToST, unsafeInterleaveST)
7978
import Data.STRef.Lazy
@@ -161,9 +160,13 @@ labelledThreads threadMap =
161160

162161

163162
-- | Timers mutable variables. First one supports 'newTimeout' api, the second
164-
-- one 'registerDelay'.
163+
-- one 'registerDelay', the third one 'threadDelay', the fourth one 'timeout'.
165164
--
166-
data TimerVars s = TimerVars !(TVar s TimeoutState) !(TVar s Bool)
165+
data TimerCompletionInfo s =
166+
Timer !(TVar s TimeoutState)
167+
| TimerRegisterDelay !(TVar s Bool)
168+
| TimerThreadDelay !ThreadId
169+
| TimerTimeout !ThreadId !TimeoutId !(STRef s IsLocked)
167170

168171
type RunQueue = OrdPSQ (Down ThreadId) (Down ThreadId) ()
169172

@@ -178,9 +181,10 @@ data SimState s a = SimState {
178181
finished :: !(Map ThreadId (FinishedReason, VectorClock)),
179182
-- | current time
180183
curTime :: !Time,
181-
-- | ordered list of timers
182-
timers :: !(OrdPSQ TimeoutId Time (TimerVars s)),
183-
-- | list of clocks
184+
-- | ordered list of timers and timeouts
185+
timers :: !(OrdPSQ TimeoutId Time (TimerCompletionInfo s)),
186+
-- | list of clocks; the timeout locks that synchronize the timeout handler
187+
-- and the main thread are stored in the 'TimerTimeout' entries of 'timers'
184188
clocks :: !(Map ClockId UTCTime),
185189
nextVid :: !TVarId, -- ^ next unused 'TVarId'
186190
nextTmid :: !TimeoutId, -- ^ next unused 'TimeoutId'
@@ -338,7 +342,53 @@ schedule thread@Thread{
338342
let thread' = thread { threadControl = ThreadControl (k x) ctl' }
339343
schedule thread' simstate
340344

345+
TimeoutFrame tmid isLockedRef k ctl' -> do
346+
-- It could happen that the timeout action finished at the same time
347+
-- as the timeout expired, which is a race condition. That's why
348+
-- we have the locks to solve this.
349+
--
350+
-- The lock starts 'NotLocked' and when the timeout fires the lock is
351+
-- locked and asynchronously an assassin thread is coming to interrupt
352+
-- this one. If the lock is locked when the timeout is fired then nothing
353+
-- happens.
354+
--
355+
-- Knowing this, if we reached this point in the code and the lock is
356+
-- 'Locked', then it means that this thread still hasn't received the
357+
-- 'TimeoutException', so we need to kill the thread that is responsible
358+
-- for doing that (the assassin one, we need to defend ourselves!)
359+
-- and run our continuation successfully and peacefully. We will do that
360+
-- by uninterruptibly-masking ourselves so we can not receive any
361+
-- exception and kill the assassin thread behind its back.
362+
-- If the lock is 'NotLocked' then it means we can just acquire it and
363+
-- carry on with the success case.
364+
locked <- readSTRef isLockedRef
365+
case locked of
366+
Locked etid -> do
367+
let -- Kill the exception throwing thread and carry on the
368+
-- continuation
369+
thread' =
370+
thread { threadControl =
371+
ThreadControl (ThrowTo (toException ThreadKilled)
372+
etid
373+
(k (Just x)))
374+
ctl'
375+
, threadMasking = MaskedUninterruptible
376+
}
377+
schedule thread' simstate
378+
379+
NotLocked -> do
380+
-- Acquire lock
381+
writeSTRef isLockedRef (Locked tid)
382+
383+
-- Remove the timer from the queue
384+
let timers' = PSQ.delete tmid timers
385+
-- Run the continuation successfully
386+
thread' = thread { threadControl = ThreadControl (k (Just x)) ctl' }
387+
388+
schedule thread' simstate { timers = timers'
389+
}
341390
Throw thrower e -> case unwindControlStack e thread of
391+
-- Found a CatchFrame
342392
Right thread0@Thread { threadMasking = maskst' } -> do
343393
-- We found a suitable exception handler, continue with that
344394
-- We record a step, in case there is no exception handler on replay.
@@ -452,19 +502,73 @@ schedule thread@Thread{
452502
(Just $ "<<timeout-state " ++ show (unTimeoutId nextTmid) ++ ">>")
453503
TimeoutPending
454504
modifySTRef (tvarVClock tvar) (leastUpperBoundVClock vClock)
455-
tvar' <- execNewTVar (succ nextVid)
456-
(Just $ "<<timeout " ++ show (unTimeoutId nextTmid) ++ ">>")
457-
False
458-
modifySTRef (tvarVClock tvar') (leastUpperBoundVClock vClock)
459505
let expiry = d `addTime` time
460506
t = Timeout tvar nextTmid
461-
timers' = PSQ.insert nextTmid expiry (TimerVars tvar tvar') timers
507+
timers' = PSQ.insert nextTmid expiry (Timer tvar) timers
462508
thread' = thread { threadControl = ThreadControl (k t) ctl }
463-
!trace <- schedule thread' simstate { timers = timers'
509+
trace <- schedule thread' simstate { timers = timers'
464510
, nextVid = succ (succ nextVid)
465511
, nextTmid = succ nextTmid }
466512
return (SimPORTrace time tid tstep tlbl (EventTimerCreated nextTmid nextVid expiry) trace)
467513

514+
-- This case is guarded by checks in 'timeout' itself.
515+
StartTimeout d _ _ | d <= 0 ->
516+
error "schedule: StartTimeout: Impossible happened"
517+
518+
StartTimeout d action' k -> do
519+
isLockedRef <- newSTRef NotLocked
520+
let expiry = d `addTime` time
521+
timers' = PSQ.insert nextTmid expiry (TimerTimeout tid nextTmid isLockedRef) timers
522+
thread' = thread { threadControl =
523+
ThreadControl action'
524+
(TimeoutFrame nextTmid isLockedRef k ctl)
525+
}
526+
trace <- deschedule Yield thread' simstate { timers = timers'
527+
, nextTmid = succ nextTmid }
528+
return (SimPORTrace time tid tstep tlbl (EventTimeoutCreated nextTmid tid expiry) trace)
529+
530+
RegisterDelay d k | d < 0 -> do
531+
tvar <- execNewTVar nextVid
532+
(Just $ "<<timeout " ++ show (unTimeoutId nextTmid) ++ ">>")
533+
True
534+
modifySTRef (tvarVClock tvar) (leastUpperBoundVClock vClock)
535+
let !expiry = d `addTime` time
536+
!thread' = thread { threadControl = ThreadControl (k tvar) ctl }
537+
trace <- schedule thread' simstate { nextVid = succ nextVid }
538+
return (SimPORTrace time tid tstep tlbl (EventRegisterDelayCreated nextTmid nextVid expiry) $
539+
SimPORTrace time tid tstep tlbl (EventRegisterDelayFired nextTmid) $
540+
trace)
541+
542+
RegisterDelay d k -> do
543+
tvar <- execNewTVar nextVid
544+
(Just $ "<<timeout " ++ show (unTimeoutId nextTmid) ++ ">>")
545+
False
546+
modifySTRef (tvarVClock tvar) (leastUpperBoundVClock vClock)
547+
let !expiry = d `addTime` time
548+
!timers' = PSQ.insert nextTmid expiry (TimerRegisterDelay tvar) timers
549+
!thread' = thread { threadControl = ThreadControl (k tvar) ctl }
550+
trace <- schedule thread' simstate { timers = timers'
551+
, nextVid = succ nextVid
552+
, nextTmid = succ nextTmid }
553+
return (SimPORTrace time tid tstep tlbl
554+
(EventRegisterDelayCreated nextTmid nextVid expiry) trace)
555+
556+
ThreadDelay d k | d < 0 -> do
557+
let expiry = d `addTime` time
558+
thread' = thread { threadControl = ThreadControl k ctl }
559+
trace <- schedule thread' simstate
560+
return (SimPORTrace time tid tstep tlbl (EventThreadDelay expiry) $
561+
SimPORTrace time tid tstep tlbl EventThreadDelayFired $
562+
trace)
563+
564+
ThreadDelay d k -> do
565+
let expiry = d `addTime` time
566+
timers' = PSQ.insert nextTmid expiry (TimerThreadDelay tid) timers
567+
thread' = thread { threadControl = ThreadControl k ctl }
568+
trace <- deschedule Blocked thread' simstate { timers = timers'
569+
, nextTmid = succ nextTmid }
570+
return (SimPORTrace time tid tstep tlbl (EventThreadDelay expiry) trace)
571+
468572
-- we do not follow `GHC.Event` behaviour here; updating a timer to the past
469573
-- effectively cancels it.
470574
UpdateTimeout (Timeout _tvar tmid) d k | d < 0 -> do
@@ -935,34 +1039,65 @@ reschedule simstate@SimState{ threads, timers, curTime = time, races } =
9351039
-- Reuse the STM functionality here to write all the timer TVars.
9361040
-- Simplify to a special case that only reads and writes TVars.
9371041
written <- execAtomically' (runSTM $ mapM_ timeoutAction fired)
938-
(wakeup, wokeby) <- threadsUnblockedByWrites written
1042+
(wakeupSTM, wokeby) <- threadsUnblockedByWrites written
9391043
mapM_ (\(SomeTVar tvar) -> unblockAllThreadsFromTVar tvar) written
9401044

941-
-- TODO: the vector clock below cannot be right, can it?
942-
let (unblocked,
943-
simstate') = unblockThreads bottomVClock wakeup simstate
944-
-- all open races will be completed and reported at this time
945-
simstate'' = simstate'{ races = noRaces }
1045+
let wakeupThreadDelay = [ tid | TimerThreadDelay tid <- fired ]
1046+
wakeup = wakeupThreadDelay ++ wakeupSTM
1047+
-- TODO: the vector clock below cannot be right, can it?
1048+
(_, !simstate') = unblockThreads bottomVClock wakeup simstate
1049+
1050+
-- For each 'timeout' action where the timeout has fired, start a
1051+
-- new thread to execute throwTo to interrupt the action.
1052+
!timeoutExpired = [ (tid, tmid, isLockedRef)
1053+
| TimerTimeout tid tmid isLockedRef <- fired ]
1054+
1055+
-- Get the isLockedRef values
1056+
!timeoutExpired' <- traverse (\(tid, tmid, isLockedRef) -> do
1057+
locked <- readSTRef isLockedRef
1058+
return (tid, tmid, isLockedRef, locked)
1059+
)
1060+
timeoutExpired
1061+
1062+
-- all open races will be completed and reported at this time
1063+
!simstate'' <- forkTimeoutInterruptThreads timeoutExpired'
1064+
simstate' { races = noRaces }
9461065
!trace <- reschedule simstate'' { curTime = time'
9471066
, timers = timers' }
9481067
let traceEntries =
949-
[ (time', ThreadId [-1], (-1), Just "timer", EventTimerFired tmid)
950-
| tmid <- tmids ]
951-
++ [ (time', tid', (-1), tlbl', EventTxWakeup vids)
952-
| tid' <- unblocked
1068+
[ ( time', ThreadId [-1], -1, Just "timer"
1069+
, EventTimerFired tmid)
1070+
| (tmid, Timer _) <- zip tmids fired ]
1071+
++ [ ( time', ThreadId [-1], -1, Just "register delay timer"
1072+
, EventRegisterDelayFired tmid)
1073+
| (tmid, TimerRegisterDelay _) <- zip tmids fired ]
1074+
++ [ (time', tid', -1, tlbl', EventTxWakeup vids)
1075+
| tid' <- wakeupSTM
9531076
, let tlbl' = lookupThreadLabel tid' threads
9541077
, let Just vids = Set.toList <$> Map.lookup tid' wokeby ]
1078+
++ [ ( time', tid, -1, Just "thread delay timer"
1079+
, EventThreadDelayFired)
1080+
| tid <- wakeupThreadDelay ]
1081+
++ [ ( time', tid, -1, Just "timeout timer"
1082+
, EventTimeoutFired tmid)
1083+
| (tid, tmid, _, _) <- timeoutExpired' ]
1084+
++ [ ( time', tid, -1, Just "forked thread"
1085+
, EventThreadForked tid)
1086+
| (tid, _, _, _) <- timeoutExpired' ]
1087+
9551088
return $
9561089
traceFinalRacesFound simstate $
9571090
traceMany traceEntries trace
9581091
where
959-
timeoutAction (TimerVars var bvar) = do
1092+
timeoutAction (Timer var) = do
9601093
x <- readTVar var
9611094
case x of
962-
TimeoutPending -> writeTVar var TimeoutFired
963-
>> writeTVar bvar True
1095+
TimeoutPending -> writeTVar var TimeoutFired
9641096
TimeoutFired -> error "MonadTimer(Sim): invariant violation"
9651097
TimeoutCancelled -> return ()
1098+
timeoutAction (TimerRegisterDelay var) = writeTVar var True
1099+
timeoutAction (TimerThreadDelay _) = return ()
1100+
timeoutAction (TimerTimeout _ _ _) = return ()
9661101

9671102
unblockThreads :: forall s a.
9681103
VectorClock
@@ -998,6 +1133,78 @@ unblockThreads vClock wakeup simstate@SimState {runqueue, threads} =
9981133
threadVClock = vClock `leastUpperBoundVClock` threadVClock t })))
9991134
threads unblockedIds
10001135

1136+
-- | This function receives a list of TimerTimeout values that represent threads
1137+
-- for which the timeout expired and kills the running thread if needed.
1138+
--
1139+
-- This function is responsible for the second part of the race condition issue
1140+
-- and relates to the 'schedule's 'TimeoutFrame' locking explanation (here is
1141+
-- where the assassin threads are launched). So, as explained previously, at this
1142+
-- point in code, the timeout expired so we need to interrupt the running
1143+
-- thread. If the running thread finished at the same time the timeout expired
1144+
-- we have a race condition. To deal with this race condition what we do is
1145+
-- look at the lock value. If it is 'Locked' this means that the running thread
1146+
-- already finished (or won the race) so we can safely do nothing. Otherwise, if
1147+
-- the lock value is 'NotLocked' we need to acquire the lock and launch an
1148+
-- assassin thread that is going to interrupt the running one. Note that we
1149+
-- should run this interrupting thread in an unmasked state since it might
1150+
-- receive a 'ThreadKilled' exception.
1151+
--
1152+
forkTimeoutInterruptThreads :: [(ThreadId, TimeoutId, STRef s IsLocked, IsLocked)]
1153+
-> SimState s a
1154+
-> ST s (SimState s a)
1155+
forkTimeoutInterruptThreads timeoutExpired simState@SimState {threads} =
1156+
foldlM (\st@SimState{ runqueue = runqueue,
1157+
threads = threads'
1158+
}
1159+
(t, isLockedRef)
1160+
-> do
1161+
let tid' = threadId t
1162+
threads'' = Map.insert tid' t threads'
1163+
runqueue' = insertThread t runqueue
1164+
writeSTRef isLockedRef (Locked tid')
1165+
1166+
return st { runqueue = runqueue',
1167+
threads = threads''
1168+
})
1169+
simState
1170+
throwToThread
1171+
1172+
where
1173+
-- can only throw exception if the thread exists and if the mutually
1174+
-- exclusive lock exists and is still 'NotLocked'
1175+
toThrow = [ (tid, tmid, ref, t)
1176+
| (tid, tmid, ref, locked) <- timeoutExpired
1177+
, Just t <- [Map.lookup tid threads]
1178+
, NotLocked <- [locked]
1179+
]
1180+
-- we launch a thread responsible for throwing a 'TimeoutException' to the
1181+
-- thread whose timeout expired
1182+
throwToThread =
1183+
[ let nextId = threadNextTId t
1184+
tid' = childThreadId tid nextId
1185+
in ( Thread { threadId = tid',
1186+
threadControl =
1187+
ThreadControl
1188+
(ThrowTo (toException (TimeoutException tmid))
1189+
tid
1190+
(Return ()))
1191+
ForkFrame,
1192+
threadBlocked = False,
1193+
threadDone = False,
1194+
threadMasking = Unmasked,
1195+
threadThrowTo = [],
1196+
threadClockId = threadClockId t,
1197+
threadLabel = Just "timeout-forked-thread",
1198+
threadNextTId = 1,
1199+
threadStep = 0,
1200+
threadVClock = insertVClock tid' 0
1201+
$ threadVClock t,
1202+
threadEffect = mempty,
1203+
threadRacy = threadRacy t
1204+
}
1205+
, ref)
1206+
| (tid, tmid, ref, t) <- toThrow
1207+
]
10011208

10021209
-- | Iterate through the control stack to find an enclosing exception handler
10031210
-- of the right type, or unwind all the way to the top level for the thread.
@@ -1014,7 +1221,8 @@ unwindControlStack e thread =
10141221
ThreadControl _ ctl -> unwind (threadMasking thread) ctl
10151222
where
10161223
unwind :: forall s' c. MaskingState
1017-
-> ControlStack s' c a -> Either Bool (Thread s' a)
1224+
-> ControlStack s' c a
1225+
-> Either Bool (Thread s' a)
10181226
unwind _ MainFrame = Left True
10191227
unwind _ ForkFrame = Left False
10201228
unwind _ (MaskFrame _k maskst' ctl) = unwind maskst' ctl
@@ -1026,12 +1234,28 @@ unwindControlStack e thread =
10261234

10271235
-- Ok! We will be able to continue the thread with the handler
10281236
-- followed by the continuation after the catch
1029-
Just e' -> Right thread {
1030-
-- As per async exception rules, the handler is run masked
1237+
Just e' -> Right ( thread {
1238+
-- As per async exception rules, the handler is run
1239+
-- masked
10311240
threadControl = ThreadControl (handler e')
10321241
(MaskFrame k maskst ctl),
10331242
threadMasking = atLeastInterruptibleMask maskst
10341243
}
1244+
)
1245+
1246+
-- Either Timeout fired or the action threw an exception.
1247+
-- - If Timeout fired, then it was possibly during this thread's execution
1248+
-- so we need to run the continuation with a Nothing value.
1249+
-- - If the timeout action threw an exception we need to keep unwinding the
1250+
-- control stack looking for a handler to this exception.
1251+
unwind maskst (TimeoutFrame tmid isLockedRef k ctl) =
1252+
case fromException e of
1253+
-- Exception came from timeout expiring
1254+
Just (TimeoutException tmid') ->
1255+
assert (tmid == tmid')
1256+
Right thread { threadControl = ThreadControl (k Nothing) ctl }
1257+
-- Any other exception: keep unwinding past this frame
1258+
_ -> unwind maskst ctl
10351259

10361260
atLeastInterruptibleMask :: MaskingState -> MaskingState
10371261
atLeastInterruptibleMask Unmasked = MaskedInterruptible

0 commit comments

Comments
 (0)