@@ -54,7 +54,9 @@ module Control.Monad.IOSim.Internal
54
54
import Prelude hiding (read )
55
55
56
56
import Data.Dynamic
57
- import Data.Foldable (toList , traverse_ )
57
+ import Data.Foldable (toList , traverse_ , foldlM )
58
+ import Deque.Strict (Deque )
59
+ import qualified Deque.Strict as Deque
58
60
import qualified Data.List as List
59
61
import qualified Data.List.Trace as Trace
60
62
import Data.Map.Strict (Map )
@@ -65,16 +67,13 @@ import qualified Data.OrdPSQ as PSQ
65
67
import Data.Set (Set )
66
68
import qualified Data.Set as Set
67
69
import Data.Time (UTCTime (.. ), fromGregorian )
68
- import Deque.Strict (Deque )
69
- import qualified Deque.Strict as Deque
70
70
71
71
import GHC.Exts (fromList )
72
72
import GHC.Conc (ThreadStatus (.. ), BlockReason (.. ))
73
73
74
- import Control.Exception (NonTermination (.. ), assert , throw )
75
- import Control.Monad (join )
76
-
77
- import Control.Monad (when )
74
+ import Control.Exception
75
+ (NonTermination (.. ), assert , throw , AsyncException (.. ))
76
+ import Control.Monad (join , when )
78
77
import Control.Monad.ST.Lazy
79
78
import Control.Monad.ST.Lazy.Unsafe (unsafeIOToST , unsafeInterleaveST )
80
79
import Data.STRef.Lazy
@@ -126,6 +125,7 @@ data TimerCompletionInfo s =
126
125
Timer ! (TVar s TimeoutState )
127
126
| TimerRegisterDelay ! (TVar s Bool )
128
127
| TimerThreadDelay ! ThreadId
128
+ | TimerTimeout ! ThreadId ! TimeoutId ! (STRef s IsLocked )
129
129
130
130
-- | Internal state.
131
131
--
@@ -138,7 +138,7 @@ data SimState s a = SimState {
138
138
finished :: ! (Map ThreadId FinishedReason ),
139
139
-- | current time
140
140
curTime :: ! Time ,
141
- -- | ordered list of timers
141
+ -- | ordered list of timers and timeouts
142
142
timers :: ! (OrdPSQ TimeoutId Time (TimerCompletionInfo s )),
143
143
-- | list of clocks
144
144
clocks :: ! (Map ClockId UTCTime ),
@@ -235,8 +235,53 @@ schedule !thread@Thread{
235
235
let thread' = thread { threadControl = ThreadControl (k x) ctl' }
236
236
schedule thread' simstate
237
237
238
+ TimeoutFrame tmid isLockedRef k ctl' -> do
239
+ -- There is a possible race between timeout action and the timeout expiration.
240
+ -- We use a lock to solve the race.
241
+ --
242
+ -- The lock starts 'NotLocked' and when the timeout fires the lock is
243
+ -- locked and asynchronously an assassin thread is coming to interrupt
244
+ -- it. If the lock is locked when the timeout is fired then nothing
245
+ -- happens.
246
+ --
247
+ -- Knowing this, if we reached this point in the code and the lock is
248
+ -- 'Locked', then it means that this thread still hasn't received the
249
+ -- 'TimeoutException', so we need to kill the thread that is responsible
250
+ -- for doing that (the assassin thread, we need to defend ourselves!)
251
+ -- and run our continuation successfully and peacefully. We will do that
252
+ -- by uninterruptibly-masking ourselves so we can not receive any
253
+ -- exception and kill the assassin thread behind its back.
254
+ -- If the lock is 'NotLocked' then it means we can just acquire it and
255
+ -- carry on with the success case.
256
+ locked <- readSTRef isLockedRef
257
+ case locked of
258
+ Locked etid -> do
259
+ let -- Kill the assassin throwing thread and carry on the
260
+ -- continuation
261
+ thread' =
262
+ thread { threadControl =
263
+ ThreadControl (ThrowTo (toException ThreadKilled )
264
+ etid
265
+ (k (Just x)))
266
+ ctl'
267
+ , threadMasking = MaskedUninterruptible
268
+ }
269
+ schedule thread' simstate
270
+
271
+ NotLocked -> do
272
+ -- Acquire lock
273
+ writeSTRef isLockedRef (Locked tid)
274
+
275
+ -- Remove the timer from the queue
276
+ let timers' = PSQ. delete tmid timers
277
+ -- Run the continuation
278
+ thread' = thread { threadControl = ThreadControl (k (Just x)) ctl' }
279
+
280
+ schedule thread' simstate { timers = timers'
281
+ }
238
282
Throw thrower e -> {-# SCC "schedule.Throw" #-}
239
283
case unwindControlStack e thread of
284
+ -- Found a CatchFrame
240
285
Right thread'@ Thread { threadMasking = maskst' } -> do
241
286
-- We found a suitable exception handler, continue with that
242
287
trace <- schedule thread' simstate
@@ -360,6 +405,23 @@ schedule !thread@Thread{
360
405
, nextTmid = succ nextTmid }
361
406
return (SimTrace time tid tlbl (EventTimerCreated nextTmid nextVid expiry) trace)
362
407
408
+ -- This case is guarded by checks in 'timeout' itself.
409
+ StartTimeout d _ _ | d <= 0 ->
410
+ error " schedule: StartTimeout: Impossible happened"
411
+
412
+ StartTimeout d action' k ->
413
+ {-# SCC "schedule.StartTimeout" #-} do
414
+ isLockedRef <- newSTRef NotLocked
415
+ let ! expiry = d `addTime` time
416
+ ! timers' = PSQ. insert nextTmid expiry (TimerTimeout tid nextTmid isLockedRef) timers
417
+ ! thread' = thread { threadControl =
418
+ ThreadControl action'
419
+ (TimeoutFrame nextTmid isLockedRef k ctl)
420
+ }
421
+ ! trace <- deschedule Yield thread' simstate { timers = timers'
422
+ , nextTmid = succ nextTmid }
423
+ return (SimTrace time tid tlbl (EventTimeoutCreated nextTmid tid expiry) trace)
424
+
363
425
RegisterDelay d k | d < 0 ->
364
426
{-# SCC "schedule.NewRegisterDelay" #-} do
365
427
! tvar <- execNewTVar nextVid
@@ -404,7 +466,6 @@ schedule !thread@Thread{
404
466
, nextTmid = succ nextTmid }
405
467
return (SimTrace time tid tlbl (EventThreadDelay expiry) trace)
406
468
407
-
408
469
-- we do not follow `GHC.Event` behaviour here; updating a timer to the past
409
470
-- effectively cancels it.
410
471
UpdateTimeout (Timeout _tvar tmid) d k | d < 0 ->
@@ -777,8 +838,23 @@ reschedule !simstate@SimState{ threads, timers, curTime = time } =
777
838
wakeup = wakeupThreadDelay ++ wakeupSTM
778
839
(_, ! simstate') = unblockThreads wakeup simstate
779
840
780
- ! trace <- reschedule simstate' { curTime = time'
781
- , timers = timers' }
841
+ -- For each 'timeout' action where the timeout has fired, start a
842
+ -- new thread to execute throwTo to interrupt the action.
843
+ ! timeoutExpired = [ (tid, tmid, isLockedRef)
844
+ | TimerTimeout tid tmid isLockedRef <- fired ]
845
+
846
+ -- Get the isLockedRef values
847
+ ! timeoutExpired' <- traverse (\ (tid, tmid, isLockedRef) -> do
848
+ locked <- readSTRef isLockedRef
849
+ return (tid, tmid, isLockedRef, locked)
850
+ )
851
+ timeoutExpired
852
+
853
+ ! simstate'' <- forkTimeoutInterruptThreads timeoutExpired' simstate'
854
+
855
+ ! trace <- reschedule simstate'' { curTime = time'
856
+ , timers = timers' }
857
+
782
858
return $
783
859
traceMany ([ ( time', ThreadId [- 1 ], Just " timer"
784
860
, EventTimerFired tmid)
@@ -792,7 +868,13 @@ reschedule !simstate@SimState{ threads, timers, curTime = time } =
792
868
, let Just vids = Set. toList <$> Map. lookup tid' wokeby ]
793
869
++ [ ( time', tid, Just " thread delay timer"
794
870
, EventThreadDelayFired )
795
- | tid <- wakeupThreadDelay ])
871
+ | tid <- wakeupThreadDelay ]
872
+ ++ [ ( time', tid, Just " timeout timer"
873
+ , EventTimeoutFired tmid)
874
+ | (tid, tmid, _, _) <- timeoutExpired' ]
875
+ ++ [ ( time', tid, Just " thread forked"
876
+ , EventThreadForked tid)
877
+ | (tid, _, _, _) <- timeoutExpired' ])
796
878
trace
797
879
where
798
880
timeoutSTMAction (Timer var) = do
@@ -804,7 +886,8 @@ reschedule !simstate@SimState{ threads, timers, curTime = time } =
804
886
timeoutSTMAction (TimerRegisterDelay var) = writeTVar var True
805
887
-- Note that 'threadDelay' is not handled via STM style wakeup, but rather
806
888
-- it's handled directly above with 'wakeupThreadDelay' and 'unblockThreads'
807
- timeoutSTMAction (TimerThreadDelay _) = return ()
889
+ timeoutSTMAction TimerThreadDelay {} = return ()
890
+ timeoutSTMAction TimerTimeout {} = return ()
808
891
809
892
unblockThreads :: [ThreadId ] -> SimState s a -> ([ThreadId ], SimState s a )
810
893
unblockThreads ! wakeup ! simstate@ SimState {runqueue, threads} =
@@ -825,7 +908,76 @@ unblockThreads !wakeup !simstate@SimState {runqueue, threads} =
825
908
-- and in which case we mark them as now running
826
909
! threads' = List. foldl'
827
910
(flip (Map. adjust (\ t -> t { threadBlocked = False })))
828
- threads unblocked
911
+ threads
912
+ unblocked
913
+
914
+ -- | This function receives a list of TimerTimeout values that represent threads
915
+ -- for which the timeout expired and kills the running thread if needed.
916
+ --
917
+ -- This function is responsible for the second part of the race condition issue
918
+ -- and relates to the 'schedule's 'TimeoutFrame' locking explanation (here is
919
+ -- where the assassin threads are launched. So, as explained previously, at this
920
+ -- point in code, the timeout expired so we need to interrupt the running
921
+ -- thread. If the running thread finished at the same time the timeout expired
922
+ -- we have a race condition. To deal with this race condition what we do is
923
+ -- look at the lock value. If it is 'Locked' this means that the running thread
924
+ -- already finished (or won the race) so we can safely do nothing. Otherwise, if
925
+ -- the lock value is 'NotLocked' we need to acquire the lock and launch an
926
+ -- assassin thread that is going to interrupt the running one. Note that we
927
+ -- should run this interrupting thread in an unmasked state since it might
928
+ -- receive a 'ThreadKilled' exception.
929
+ --
930
+ forkTimeoutInterruptThreads :: [(ThreadId , TimeoutId , STRef s IsLocked , IsLocked )]
931
+ -> SimState s a
932
+ -> ST s (SimState s a )
933
+ forkTimeoutInterruptThreads timeoutExpired simState@ SimState {threads} =
934
+ foldlM (\ st@ SimState { runqueue = runqueue,
935
+ threads = threads'
936
+ }
937
+ (t, isLockedRef)
938
+ -> do
939
+ let tid' = threadId t
940
+ threads'' = Map. insert tid' t threads'
941
+ runqueue' = Deque. snoc tid' runqueue
942
+
943
+ writeSTRef isLockedRef (Locked tid')
944
+
945
+ return st { runqueue = runqueue',
946
+ threads = threads''
947
+ })
948
+ simState
949
+ throwToThread
950
+
951
+ where
952
+ -- can only throw exception if the thread exists and if the mutually
953
+ -- exclusive lock exists and is still 'NotLocked'
954
+ toThrow = [ (tid, tmid, ref, t)
955
+ | (tid, tmid, ref, locked) <- timeoutExpired
956
+ , Just t <- [Map. lookup tid threads]
957
+ , NotLocked <- [locked]
958
+ ]
959
+ -- we launch a thread responsible for throwing an AsyncCancelled exception
960
+ -- to the thread which timeout expired
961
+ throwToThread =
962
+ [ let nextId = threadNextTId t
963
+ tid' = childThreadId tid nextId
964
+ in ( Thread { threadId = tid',
965
+ threadControl =
966
+ ThreadControl
967
+ (ThrowTo (toException (TimeoutException tmid))
968
+ tid
969
+ (Return () ))
970
+ ForkFrame ,
971
+ threadBlocked = False ,
972
+ threadMasking = Unmasked ,
973
+ threadThrowTo = [] ,
974
+ threadClockId = threadClockId t,
975
+ threadLabel = Just " timeout-forked-thread" ,
976
+ threadNextTId = 1
977
+ }
978
+ , ref )
979
+ | (tid, tmid, ref, t) <- toThrow
980
+ ]
829
981
830
982
831
983
-- | Iterate through the control stack to find an enclosing exception handler
@@ -843,7 +995,8 @@ unwindControlStack e thread =
843
995
ThreadControl _ ctl -> unwind (threadMasking thread) ctl
844
996
where
845
997
unwind :: forall s' c . MaskingState
846
- -> ControlStack s' c a -> Either Bool (Thread s' a )
998
+ -> ControlStack s' c a
999
+ -> Either Bool (Thread s' a )
847
1000
unwind _ MainFrame = Left True
848
1001
unwind _ ForkFrame = Left False
849
1002
unwind _ (MaskFrame _k maskst' ctl) = unwind maskst' ctl
@@ -855,12 +1008,28 @@ unwindControlStack e thread =
855
1008
856
1009
-- Ok! We will be able to continue the thread with the handler
857
1010
-- followed by the continuation after the catch
858
- Just e' -> Right thread {
859
- -- As per async exception rules, the handler is run masked
1011
+ Just e' -> Right ( thread {
1012
+ -- As per async exception rules, the handler is run
1013
+ -- masked
860
1014
threadControl = ThreadControl (handler e')
861
1015
(MaskFrame k maskst ctl),
862
1016
threadMasking = atLeastInterruptibleMask maskst
863
1017
}
1018
+ )
1019
+
1020
+ -- Either Timeout fired or the action threw an exception.
1021
+ -- - If Timeout fired, then it was possibly during this thread's execution
1022
+ -- so we need to run the continuation with a Nothing value.
1023
+ -- - If the timeout action threw an exception we need to keep unwinding the
1024
+ -- control stack looking for a handler to this exception.
1025
+ unwind maskst (TimeoutFrame tmid _ k ctl) =
1026
+ case fromException e of
1027
+ -- Exception came from timeout expiring
1028
+ Just (TimeoutException tmid') ->
1029
+ assert (tmid == tmid')
1030
+ Right thread { threadControl = ThreadControl (k Nothing ) ctl }
1031
+ -- Exception came from a different exception
1032
+ _ -> unwind maskst ctl
864
1033
865
1034
atLeastInterruptibleMask :: MaskingState -> MaskingState
866
1035
atLeastInterruptibleMask Unmasked = MaskedInterruptible
0 commit comments